├── .gitignore ├── Appendix.pdf ├── LICENSE ├── README.md ├── figures └── main_method.png └── src ├── clevr_hans ├── cnn │ ├── data_xil.py │ ├── model.py │ ├── runs │ │ └── conf_3 │ │ │ ├── resnet-clevr-hans-17-conf_3_seed0 │ │ │ ├── args.txt │ │ │ └── model_epoch56_bestvalloss_0.0132.pth │ │ │ └── resnet-clevr-hans-lexi-17-conf_3_seed0 │ │ │ ├── args.txt │ │ │ └── model_epoch86_bestvalloss_0.0122.pth │ ├── scripts │ │ ├── clevr-hans-resnet-xil.sh │ │ ├── clevr-hans-resnet-xil_plot.sh │ │ ├── clevr-hans-resnet-xil_test.sh │ │ ├── clevr-hans-resnet.sh │ │ ├── clevr-hans-resnet_plot.sh │ │ └── clevr-hans-resnet_test.sh │ ├── train_resnet34.py │ ├── train_resnet34_xil.py │ └── utils.py └── nesy_xil │ ├── data_clevr_hans.py │ ├── logs │ └── slot-attention-clevr-state-3_final │ ├── runs │ └── conf_3 │ │ ├── concept-learner-17-conf_3_seed0 │ │ ├── args.txt │ │ └── model_epoch52_bestvalloss_0.0167.pth │ │ ├── concept-learner-xil-17-conf_3_seed0 │ │ ├── args.txt │ │ └── model_epoch47_bestvalloss_0.7580.pth │ │ └── concept-learner-xil-global-rule-17-conf_3_seed1 │ │ ├── args.txt │ │ └── model_epoch37_bestvalloss_0.1620.pth │ ├── scripts │ ├── clevr-hans-concept-learner-xil_global.sh │ ├── clevr-hans-concept-learner-xil_plot_imgs_global.sh │ ├── clevr-hans-concept-learner-xil_test_global.sh │ ├── clevr-hans-concept-learner_default.sh │ ├── clevr-hans-concept-learner_default_plot_imgs.sh │ ├── clevr-hans-concept-learner_default_test.sh │ ├── clevr-hans-concept-learner_xil.sh │ ├── clevr-hans-concept-learner_xil_plot_imgs.sh │ └── clevr-hans-concept-learner_xil_test.sh │ ├── train_clevr_hans_concept_learner.py │ ├── train_clevr_hans_concept_learner_xil.py │ ├── train_clevr_hans_concept_learner_xil_globalrule.py │ └── xil_losses.py ├── color_mnist ├── data.py ├── data │ └── generate_color_mnist.py ├── runs │ └── default_9c9a152c-13a5-11eb-bbe2-0242ac150002 │ │ ├── events.out.tfevents.1603288758.2b005b7ead9e.36225.0 │ │ ├── figures │ │ ├── test_expl_0.png │ │ ├── test_expl_1.png │ │ ├── test_expl_10.png │ │ ├── test_expl_11.png │ │ ├── test_expl_12.png │ │ ├── test_expl_13.npy │ │ ├── test_expl_13.png │ │ ├── test_expl_13_true_0_pred_9.npy │ │ ├── test_expl_14.png │ │ ├── test_expl_15.png │ │ ├── test_expl_2.png │ │ ├── test_expl_3.png │ │ ├── test_expl_4.png │ │ ├── test_expl_5.png │ │ ├── test_expl_6.png │ │ ├── test_expl_7.png │ │ ├── test_expl_8.png │ │ ├── test_expl_9.png │ │ ├── test_img_0_true_7_pred_3.png │ │ ├── test_img_10_true_0_pred_0.png │ │ ├── test_img_11_true_6_pred_4.png │ │ ├── test_img_12_true_9_pred_0.png │ │ ├── test_img_13_true_0_pred_9.png │ │ ├── test_img_14_true_1_pred_8.png │ │ ├── test_img_15_true_5_pred_2.png │ │ ├── test_img_1_true_2_pred_8.png │ │ ├── test_img_2_true_1_pred_1.png │ │ ├── test_img_3_true_0_pred_6.png │ │ ├── test_img_4_true_4_pred_2.png │ │ ├── test_img_5_true_1_pred_4.png │ │ ├── test_img_6_true_4_pred_5.png │ │ ├── test_img_7_true_9_pred_2.png │ │ ├── test_img_8_true_5_pred_6.png │ │ ├── test_img_9_true_9_pred_3.png │ │ ├── test_imgs.png │ │ └── train_imgs.png │ │ └── mnist_cnn.pt ├── scripts │ ├── color_mnist_plot.sh │ └── color_mnist_train.sh ├── train.py └── utils.py └── docker ├── Dockerfile ├── captum_xil ├── .circleci │ └── config.yml ├── .conda │ └── meta.yaml ├── .gitattributes ├── .github │ └── ISSUE_TEMPLATE │ │ ├── ---bug-report.md │ │ ├── ---documentation.md │ │ ├── ---feature-request.md │ │ └── --questions-and-help.md ├── .gitignore ├── .isort.cfg ├── CITATION ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── captum │ 
├── __init__.py │ ├── attr │ │ ├── __init__.py │ │ ├── _core │ │ │ ├── __init__.py │ │ │ ├── deep_lift.py │ │ │ ├── feature_ablation.py │ │ │ ├── feature_permutation.py │ │ │ ├── gradient_shap.py │ │ │ ├── guided_backprop_deconvnet.py │ │ │ ├── guided_grad_cam.py │ │ │ ├── input_x_gradient.py │ │ │ ├── integrated_gradients.py │ │ │ ├── layer │ │ │ │ ├── __init__.py │ │ │ │ ├── grad_cam.py │ │ │ │ ├── internal_influence.py │ │ │ │ ├── layer_activation.py │ │ │ │ ├── layer_conductance.py │ │ │ │ ├── layer_deep_lift.py │ │ │ │ ├── layer_feature_ablation.py │ │ │ │ ├── layer_gradient_shap.py │ │ │ │ ├── layer_gradient_x_activation.py │ │ │ │ └── layer_integrated_gradients.py │ │ │ ├── neuron │ │ │ │ ├── __init__.py │ │ │ │ ├── neuron_conductance.py │ │ │ │ ├── neuron_deep_lift.py │ │ │ │ ├── neuron_feature_ablation.py │ │ │ │ ├── neuron_gradient.py │ │ │ │ ├── neuron_gradient_shap.py │ │ │ │ ├── neuron_guided_backprop_deconvnet.py │ │ │ │ └── neuron_integrated_gradients.py │ │ │ ├── noise_tunnel.py │ │ │ ├── occlusion.py │ │ │ ├── saliency.py │ │ │ └── shapley_value.py │ │ ├── _models │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytext.py │ │ └── _utils │ │ │ ├── __init__.py │ │ │ ├── approximation_methods.py │ │ │ ├── attribution.py │ │ │ ├── batching.py │ │ │ ├── common.py │ │ │ ├── gradient.py │ │ │ ├── stat.py │ │ │ ├── summarizer.py │ │ │ ├── typing.py │ │ │ └── visualization.py │ └── insights │ │ ├── __init__.py │ │ ├── _utils │ │ ├── __init__.py │ │ └── transforms.py │ │ ├── api.py │ │ ├── config.py │ │ ├── example.py │ │ ├── features.py │ │ ├── models │ │ └── cifar_torchvision.pt │ │ ├── server.py │ │ └── widget │ │ ├── __init__.py │ │ ├── _version.py │ │ └── widget.py ├── docs │ ├── README.md │ ├── algorithms.md │ ├── algorithms_comparison_matrix.md │ ├── captum_insights.md │ ├── extension │ │ └── integrated_gradients.md │ ├── getting_started.md │ ├── introduction.md │ ├── overview.md │ └── presentations │ │ └── Captum_NeurIPS_2019_final.key ├── environment.yml ├── scripts │ ├── build_docs.sh │ ├── build_insights.sh │ ├── install_via_conda.sh │ ├── install_via_pip.sh │ ├── parse_sphinx.py │ ├── parse_tutorials.py │ ├── publish_site.sh │ ├── run_mypy.sh │ ├── update_versions_html.py │ └── versions.js ├── setup.cfg ├── setup.py ├── sphinx │ ├── Makefile │ ├── README.md │ ├── make.bat │ └── source │ │ ├── approximation_methods.rst │ │ ├── attribution.rst │ │ ├── base_classes.rst │ │ ├── common.rst │ │ ├── conf.py │ │ ├── deconvolution.rst │ │ ├── deep_lift.rst │ │ ├── deep_lift_shap.rst │ │ ├── feature_ablation.rst │ │ ├── feature_permutation.rst │ │ ├── gradient_shap.rst │ │ ├── guided_backprop.rst │ │ ├── guided_grad_cam.rst │ │ ├── index.rst │ │ ├── input_x_gradient.rst │ │ ├── insights.rst │ │ ├── integrated_gradients.rst │ │ ├── layer.rst │ │ ├── neuron.rst │ │ ├── noise_tunnel.rst │ │ ├── occlusion.rst │ │ ├── pytext.rst │ │ ├── saliency.rst │ │ ├── shapley_value_sampling.rst │ │ └── utilities.rst ├── tests │ ├── __init__.py │ ├── attr │ │ ├── __init__.py │ │ ├── helpers │ │ │ ├── __init__.py │ │ │ ├── basic_models.py │ │ │ ├── classification_models.py │ │ │ ├── conductance_reference.py │ │ │ └── utils.py │ │ ├── layer │ │ │ ├── __init__.py │ │ │ ├── test_grad_cam.py │ │ │ ├── test_internal_influence.py │ │ │ ├── test_layer_ablation.py │ │ │ ├── test_layer_activation.py │ │ │ ├── test_layer_conductance.py │ │ │ ├── test_layer_deeplift_basic.py │ │ │ ├── test_layer_gradient_shap.py │ │ │ ├── test_layer_gradient_x_activation.py │ │ │ └── test_layer_integrated_gradients.py │ │ 
├── models │ │ │ ├── __init__.py │ │ │ ├── test_base.py │ │ │ └── test_pytext.py │ │ ├── neuron │ │ │ ├── __init__.py │ │ │ ├── test_neuron_ablation.py │ │ │ ├── test_neuron_conductance.py │ │ │ ├── test_neuron_deeplift_basic.py │ │ │ ├── test_neuron_gradient.py │ │ │ ├── test_neuron_gradient_shap.py │ │ │ └── test_neuron_integrated_gradients.py │ │ ├── test_approximation_methods.py │ │ ├── test_common.py │ │ ├── test_data_parallel.py │ │ ├── test_deconvolution.py │ │ ├── test_deeplift_basic.py │ │ ├── test_deeplift_classification.py │ │ ├── test_feature_ablation.py │ │ ├── test_feature_permutation.py │ │ ├── test_gradient.py │ │ ├── test_gradient_shap.py │ │ ├── test_guided_backprop.py │ │ ├── test_guided_grad_cam.py │ │ ├── test_input_x_gradient.py │ │ ├── test_integrated_gradients_basic.py │ │ ├── test_integrated_gradients_classification.py │ │ ├── test_occlusion.py │ │ ├── test_saliency.py │ │ ├── test_shapley.py │ │ ├── test_stat.py │ │ ├── test_summarizer.py │ │ ├── test_targets.py │ │ └── test_utils_batching.py │ └── insights │ │ └── test_contribution.py ├── tutorials │ ├── Bert_SQUAD_Interpret.ipynb │ ├── CIFAR_TorchVision_Captum_Insights.ipynb │ ├── CIFAR_TorchVision_Interpret.ipynb │ ├── House_Prices_Regression_Interpret.ipynb │ ├── IMDB_TorchText_Interpret.ipynb │ ├── Multimodal_VQA_Captum_Insights.ipynb │ ├── Multimodal_VQA_Interpret.ipynb │ ├── README.md │ ├── Resnet_TorchVision_Ablation.ipynb │ ├── Resnet_TorchVision_Interpret.ipynb │ ├── Titanic_Basic_Interpret.ipynb │ ├── img │ │ ├── bert │ │ │ └── visuals_of_start_end_predictions.png │ │ ├── captum_insights.png │ │ ├── captum_insights_vqa.png │ │ ├── resnet │ │ │ ├── swan-3299528_1280.jpg │ │ │ └── swan-3299528_1280_attribution.jpg │ │ ├── sentiment_analysis.png │ │ └── vqa │ │ │ ├── elephant.jpg │ │ │ ├── elephant_attribution.jpg │ │ │ ├── siamese.jpg │ │ │ ├── siamese_attribution.jpg │ │ │ ├── zebra.jpg │ │ │ └── zebra_attribution.jpg │ └── models │ │ ├── boston_model.pt │ │ ├── cifar_torchvision.pt │ │ ├── imdb-model-cnn.pt │ │ └── titanic_model.pt └── website │ ├── core │ ├── Footer.js │ ├── Tutorial.js │ └── TutorialSidebar.js │ ├── package.json │ ├── pages │ ├── en │ │ └── index.js │ └── tutorials │ │ └── index.js │ ├── sidebars.json │ ├── siteConfig.js │ ├── static │ ├── .nojekyll │ ├── css │ │ ├── alabaster.css │ │ ├── basic.css │ │ ├── code_block_buttons.css │ │ └── custom.css │ ├── img │ │ ├── IG_eq1.png │ │ ├── captum-icon.png │ │ ├── captum.ico │ │ ├── captum_insights_screenshot.png │ │ ├── captum_insights_screenshot_vqa.png │ │ ├── captum_logo.png │ │ ├── captum_logo.svg │ │ ├── conductance_eq_1.png │ │ ├── conductance_eq_2.png │ │ ├── deepLIFT_multipliers_eq1.png │ │ ├── expanding_arrows.svg │ │ ├── landing-background.jpg │ │ ├── multi-modal.png │ │ ├── oss_logo.png │ │ └── pytorch_logo.svg │ ├── js │ │ ├── code_block_buttons.js │ │ └── mathjax.js │ └── pygments.css │ └── tutorials.json └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from 
a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # DS_Store 132 | *.DS_Store 133 | 134 | src/clevr_hans/nesy_xil/NeSyConceptLearner 135 | -------------------------------------------------------------------------------- /Appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/Appendix.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ml-research@TUDarmstadt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeSyXIL (Neuro-Symbolic Explanatory Interactive Learning) 2 | 3 | This is the official repository of the article: [Right for the Right Concept: Revising Neuro-Symbolic Concepts by 4 | Interacting with their Explanations](https://arxiv.org/pdf/2011.12854.pdf) by Wolfgang Stammer, Patrick Schramowski, 5 | Kristian Kersting, published at CVPR 2021. 6 | 7 | This repository contains all source code required to reproduce the experiments of the paper, including the ColorMNIST 8 | experiments and the CNN experiments. 9 | 10 | In case you are only looking for the NeSy Concept Learner, please visit the 11 | separate [repository](https://github.com/ml-research/NeSyConceptLearner). 12 | 13 | ![Concept Learner with NeSy XIL](./figures/main_method.png) 14 | 15 | Included is a Dockerfile as well as a version of captum ([official repo](https://captum.ai/)) that we 16 | modified by a single line such that the gradients of the explanations are kept. Building the Docker image will 17 | automatically install this captum_xil version. (ref. line 105 in ```NeSyXIL/src/docker/captum_xil/captum/attr/_utils/gradient.py```) 18 | 19 | ## How to Run: 20 | 21 | ### CLEVR-Hans 22 | 23 | #### Dataset 24 | 25 | Please visit the [CLEVR-Hans](https://github.com/ml-research/CLEVR-Hans) repository for instructions on how to download 26 | the CLEVR-Hans dataset. 27 | 28 | ### NeSy Concept Learner 29 | 30 | We have moved the source code for the NeSy Concept Learner to a separate 31 | [repository](https://github.com/ml-research/NeSyConceptLearner) in case someone is only interested in the Concept 32 | Learner for their work. In order to reproduce our experiments, please clone that repository first: 33 | 34 | 1. ```cd src/clevr_hans/nesy_xil/``` 35 | 36 | 2. ```git clone https://github.com/ml-research/NeSyConceptLearner.git``` 37 | 38 | ### Docker 39 | 40 | We have attached a Dockerfile to make reproduction easier. We further recommend building your own docker-compose file 41 | based on the Dockerfile. To run without a docker-compose file: 42 | 43 | 1. ```cd src/docker/``` 44 | 45 | 2. ```docker build -t nesy-xil -f Dockerfile .``` 46 | 47 | 3. ```docker run -it -v /pathto/NeSyXIL:/workspace/repositories/NeSyXIL -v /pathto/CLEVR-Hans3:/workspace/datasets/CLEVR-Hans3 --name nesy-xil --entrypoint='/bin/bash' --runtime nvidia nesy-xil``` 48 | 49 | ### NeSy XIL experiments 50 | 51 | We have included the trained NeSy Concept Learner files for seed 0 (or seed 1 for the global rule). To reproduce the NeSy 52 | XIL results, please run the desired script, e.g. 53 | 54 | ```cd src/clevr_hans/nesy_xil/``` 55 | 56 | ```./scripts/clevr-hans-concept-learner-xil_global.sh 0 100 /workspace/datasets/CLEVR-Hans3/``` 57 | 58 | This trains with the RRRLoss on CLEVR-Hans3 using GPU 0 and run number 100, which corresponds to the setting of our 59 | experiments on the global correction rule. 60 | 61 | ### CNN experiments 62 | 63 | We have included the trained CNN model files for seed 0. To reproduce the CNN experiments, please run the desired 64 | script, e.g.
65 | 66 | ```cd src/clevr_hans/cnn/``` 67 | 68 | ```./scripts/clevr-hans-resnet-xil.sh 0 0 /workspace/datasets/CLEVR-Hans3/``` 69 | 70 | This trains with the HINTLoss on CLEVR-Hans3 using GPU 0 and run number 0. 71 | 72 | ### ColorMNIST 73 | 74 | To run the ColorMNIST experiment with the provided Dockerfile: 75 | 76 | 1. ```cd src/color_mnist/data/``` 77 | 78 | 2. ```python generate_color_mnist.py``` 79 | 80 | 3. ```cd ..``` 81 | 82 | 4. ```./scripts/color_mnist_train.sh 0``` to run the default experiment on GPU 0. 83 | 84 | ## Citation 85 | If you find this code useful in your research, please consider citing: 86 | 87 | > @InProceedings{Stammer_2021_CVPR, 88 | author = {Stammer, Wolfgang and Schramowski, Patrick and Kersting, Kristian}, 89 | title = {Right for the Right Concept: Revising Neuro-Symbolic Concepts by Interacting With Their Explanations}, 90 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 91 | month = {June}, 92 | year = {2021}, 93 | pages = {3619-3629} 94 | } 95 | 96 | -------------------------------------------------------------------------------- /figures/main_method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/figures/main_method.png -------------------------------------------------------------------------------- /src/clevr_hans/cnn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class ResNet34Small(nn.Module): 7 | def __init__(self, num_classes): 8 | super(ResNet34Small, self).__init__() 9 | original_model = models.resnet34(pretrained=True) 10 | self.features = nn.Sequential(*list(original_model.children())[:-3])  # keep ResNet-34 up to layer3 (drops layer4, avgpool, fc); layer3 outputs 256 channels 11 | 12 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 13 | self.fc = nn.Linear(256, num_classes) 14 | 15 | def forward(self, x): 16 | x = self.features(x) 17 | x = self.avgpool(x) 18 | x = torch.flatten(x, 1) 19 | x = self.fc(x) 20 | return x 21 | -------------------------------------------------------------------------------- /src/clevr_hans/cnn/runs/conf_3/resnet-clevr-hans-17-conf_3_seed0/args.txt: -------------------------------------------------------------------------------- 1 | 2 | name: resnet-clevr-hans-17-conf_3 3 | seed: 0 4 | resume: None 5 | mode: train 6 | data_dir: /workspace/datasets/Clevr_Hans/conf_3/ 7 | fp_ckpt: None 8 | epochs: 100 9 | lr: 0.0001 10 | batch_size: 64 11 | num_workers: 4 12 | dataset: clevr-hans-state 13 | no_cuda: False 14 | train_only: False 15 | eval_only: False 16 | conf_version: conf_3 17 | device: cuda 18 | n_imgclasses: 3 19 | classes: [0 1 2] 20 | category_ids: [ 3 6 8 10 18] -------------------------------------------------------------------------------- /src/clevr_hans/cnn/runs/conf_3/resnet-clevr-hans-17-conf_3_seed0/model_epoch56_bestvalloss_0.0132.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/clevr_hans/cnn/runs/conf_3/resnet-clevr-hans-17-conf_3_seed0/model_epoch56_bestvalloss_0.0132.pth -------------------------------------------------------------------------------- /src/clevr_hans/cnn/runs/conf_3/resnet-clevr-hans-lexi-17-conf_3_seed0/args.txt: -------------------------------------------------------------------------------- 1 | 2 |
name: resnet-clevr-hans-lexi-17-conf_3 3 | seed: 0 4 | resume: None 5 | mode: train 6 | data_dir: /workspace/datasets/Clevr_Hans/conf_3/ 7 | fp_ckpt: None 8 | epochs: 100 9 | lr: 0.0001 10 | l2_grads: 10.0 11 | batch_size: 64 12 | num_workers: 4 13 | dataset: clevr-hans-state 14 | no_cuda: False 15 | train_only: False 16 | eval_only: False 17 | conf_version: conf_3 18 | device: cuda 19 | n_imgclasses: 3 20 | class_weights: tensor([0.3333, 0.3333, 0.3333]) 21 | classes: [0 1 2] 22 | category_ids: [ 3 6 8 10 18] -------------------------------------------------------------------------------- /src/clevr_hans/cnn/runs/conf_3/resnet-clevr-hans-lexi-17-conf_3_seed0/model_epoch86_bestvalloss_0.0122.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/clevr_hans/cnn/runs/conf_3/resnet-clevr-hans-lexi-17-conf_3_seed0/model_epoch86_bestvalloss_0.0122.pth -------------------------------------------------------------------------------- /src/clevr_hans/cnn/scripts/clevr-hans-resnet-xil.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="resnet-clevr-hans-lexi-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # Train on CLEVR_Hans with resnet model 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_resnet34_xil.py --data-dir $DATA --dataset $DATASET --epochs 100 \ 14 | --name $MODEL --lr 0.0001 --batch-size 64 --l2_grads 10 --seed 0 --mode train 15 | -------------------------------------------------------------------------------- /src/clevr_hans/cnn/scripts/clevr-hans-resnet-xil_plot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="resnet-clevr-hans-lexi-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # Train on CLEVR_Hans with resnet model 12 | 13 | # without obj coordinates 14 | CUDA_VISIBLE_DEVICES=$DEVICE python train_resnet34_xil.py --data-dir $DATA --dataset $DATASET --epochs 100 \ 15 | --name $MODEL --lr 0.0001 --batch-size 500 --l2_grads 10 --seed 0 --mode plot \ 16 | --fp-ckpt runs/conf_3/resnet-clevr-hans-lexi-17-conf_3_seed0/model_epoch86_bestvalloss_0.0122.pth 17 | -------------------------------------------------------------------------------- /src/clevr_hans/cnn/scripts/clevr-hans-resnet-xil_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="resnet-clevr-hans-lexi-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR_Hans3 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_resnet34_xil.py --data-dir $DATA --dataset $DATASET --epochs 100 \ 14 | --name $MODEL --lr 0.0001 --batch-size 64 --l2_grads 10 --seed 0 --mode test \ 15 | --fp-ckpt runs/conf_3/resnet-clevr-hans-lexi-17-conf_3_seed0/model_epoch86_bestvalloss_0.0122.pth 16 | -------------------------------------------------------------------------------- /src/clevr_hans/cnn/scripts/clevr-hans-resnet.sh: -------------------------------------------------------------------------------- 1 |
#!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="resnet-clevr-hans-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # Train on CLEVR_Hans with resnet model 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_resnet34.py --data-dir $DATA --dataset $DATASET --epochs 100 --name $MODEL \ 14 | --lr 0.0001 --batch-size 64 --seed 10 --mode train -------------------------------------------------------------------------------- /src/clevr_hans/cnn/scripts/clevr-hans-resnet_plot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="resnet-clevr-hans-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # Train on CLEVR_Hans with resnet model 12 | 13 | # without obj coordinates 14 | CUDA_VISIBLE_DEVICES=$DEVICE python train_resnet34.py --data-dir $DATA --dataset $DATASET --epochs 100 --name $MODEL \ 15 | --lr 0.0001 --batch-size 500 --seed 0 --mode plot \ 16 | --fp-ckpt runs/conf_3/resnet-clevr-hans-17-conf_3_seed0/model_epoch56_bestvalloss_0.0132.pth 17 | -------------------------------------------------------------------------------- /src/clevr_hans/cnn/scripts/clevr-hans-resnet_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="resnet-clevr-hans-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR-Hans3 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_resnet34.py --data-dir $DATA --dataset $DATASET --epochs 100 --name $MODEL \ 14 | --lr 0.0001 --batch-size 64 --seed 0 --mode test \ 15 | --fp-ckpt runs/conf_3/resnet-clevr-hans-17-conf_3_seed0/model_epoch56_bestvalloss_0.0132.pth 16 | -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/logs/slot-attention-clevr-state-3_final: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/clevr_hans/nesy_xil/logs/slot-attention-clevr-state-3_final -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-17-conf_3_seed0/args.txt: -------------------------------------------------------------------------------- 1 | 2 | name: slot-attention-clevr-state-set-transformer-17-conf_3 3 | mode: train 4 | resume: None 5 | seed: 0 6 | epochs: 100 7 | lr: 0.0001 8 | batch_size: 128 9 | num_workers: 4 10 | dataset: clevr-hans-state 11 | no_cuda: False 12 | data_dir: /workspace/datasets/Clevr_Hans/conf_3/ 13 | fp_ckpt: None 14 | n_slots: 10 15 | n_iters_slot_att: 3 16 | n_attr: 18 17 | n_heads: 4 18 | set_transf_hidden: 128 19 | conf_version: conf_3 20 | device: cuda 21 | conf_num: conf_3 22 | n_imgclasses: 3 23 | classes: [0 1 2] 24 | category_ids: [ 3 6 8 10 18] -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-17-conf_3_seed0/model_epoch52_bestvalloss_0.0167.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-17-conf_3_seed0/model_epoch52_bestvalloss_0.0167.pth -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-xil-17-conf_3_seed0/args.txt: -------------------------------------------------------------------------------- 1 | 2 | name: slot-attention-clevr-state-set-transformer-xil-17-conf_3 3 | mode: train 4 | resume: None 5 | seed: 0 6 | epochs: 50 7 | lr: 0.001 8 | l2_grads: 1000.0 9 | batch_size: 128 10 | num_workers: 4 11 | dataset: clevr-hans-state 12 | no_cuda: False 13 | train_only: False 14 | eval_only: False 15 | multi_gpu: False 16 | data_dir: /workspace/datasets/Clevr_Hans/conf_3/ 17 | fp_ckpt: None 18 | n_slots: 10 19 | n_iters_slot_att: 3 20 | n_attr: 18 21 | n_heads: 4 22 | set_transf_hidden: 128 23 | conf_version: conf_3 24 | device: cuda 25 | n_imgclasses: 3 26 | class_weights: tensor([0.3333, 0.3333, 0.3333]) 27 | classes: [0 1 2] 28 | category_ids: [ 3 6 8 10 18] -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-xil-17-conf_3_seed0/model_epoch47_bestvalloss_0.7580.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-xil-17-conf_3_seed0/model_epoch47_bestvalloss_0.7580.pth -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-xil-global-rule-17-conf_3_seed1/args.txt: -------------------------------------------------------------------------------- 1 | 2 | name: slot-attention-clevr-state-set-transformer-rrr-4-conf_3 3 | mode: train 4 | resume: None 5 | seed: 1 6 | epochs: 50 7 | lr: 0.001 8 | l2_grads: 20.0 9 | batch_size: 128 10 | num_workers: 4 11 | dataset: clevr-hans-state 12 | no_cuda: False 13 | train_only: False 14 | eval_only: False 15 | multi_gpu: False 16 | data_dir: /workspace/datasets/Clevr_Hans/conf_3/ 17 | fp_ckpt: None 18 | n_slots: 10 19 | n_iters_slot_att: 3 20 | n_attr: 18 21 | n_heads: 4 22 | set_transf_hidden: 128 23 | conf_version: conf_3 24 | device: cuda 25 | n_imgclasses: 3 26 | class_weights: tensor([0.3333, 0.3333, 0.3333]) 27 | classes: [0 1 2] 28 | category_ids: [ 3 6 8 10 18] -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-xil-global-rule-17-conf_3_seed1/model_epoch37_bestvalloss_0.1620.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/clevr_hans/nesy_xil/runs/conf_3/concept-learner-xil-global-rule-17-conf_3_seed1/model_epoch37_bestvalloss_0.1620.pth -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner-xil_global.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-xil-global-rule-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR-Hans3 12 | 13 | 
CUDA_VISIBLE_DEVICES=$DEVICE python train_clevr_hans_concept_learner_xil_globalrule.py --data-dir $DATA \ 14 | --dataset $DATASET --epochs 50 --name $MODEL --lr 0.001 --l2_grads 20 --batch-size 128 --n-slots 10 \ 15 | --n-iters-slot-att 3 --n-attr 18 --seed 1 --mode train 16 | -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner-xil_plot_imgs_global.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-xil-global-rule-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | 12 | CUDA_VISIBLE_DEVICES=$DEVICE python train_clevr_hans_concept_learner_xil_globalrule.py --data-dir $DATA \ 13 | --dataset $DATASET --epochs 100 --name $MODEL --lr 0.001 --l2_grads 20 --batch-size 128 --n-slots 10 \ 14 | --n-iters-slot-att 3 --n-attr 18 --seed 1 --mode plot \ 15 | --fp-ckpt runs/conf_3/concept-learner-xil-global-rule-17-conf_3_seed1/model_epoch37_bestvalloss_0.1620.pth -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner-xil_test_global.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-xil-global-rule-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR-Hans3 12 | 13 | python train_clevr_hans_concept_learner_xil_globalrule.py --data-dir $DATA --dataset $DATASET --epochs 50 \ 14 | --name $MODEL --lr 0.001 --l2_grads 20 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 --seed 0 \ 15 | --mode test \ 16 | --fp-ckpt runs/conf_3/concept-learner-xil-global-rule-17-conf_3_seed1/model_epoch37_bestvalloss_0.1620.pth 17 | -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner_default.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR-Hans3 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_clevr_hans_concept_learner.py --data-dir $DATA --dataset $DATASET \ 14 | --epochs 50 --name $MODEL --lr 0.0001 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 --seed 0 \ 15 | --mode train -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner_default_plot_imgs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | 12 | # without obj coordinates 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_clevr_hans_concept_learner.py --data-dir $DATA --dataset $DATASET \ 14 | --epochs 50 --name $MODEL --lr 0.0001 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 --seed 0 \ 15 | --mode plot \ 16 |
--fp-ckpt runs/conf_3/concept-learner-17-conf_3_seed0/model_epoch52_bestvalloss_0.0167.pth 17 | -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner_default_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR-Hans3 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_clevr_hans_concept_learner.py --data-dir $DATA --dataset $DATASET \ 14 | --epochs 50 --name $MODEL --lr 0.0001 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 --seed 0 \ 15 | --mode test \ 16 | --fp-ckpt runs/conf_3/concept-learner-17-conf_3_seed0/model_epoch52_bestvalloss_0.0167.pth 17 | -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner_xil.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-xil-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR-Hans3 12 | 13 | CUDA_VISIBLE_DEVICES=$DEVICE python train_clevr_hans_concept_learner_xil.py --data-dir $DATA --dataset $DATASET \ 14 | --epochs 50 --name $MODEL --lr 0.001 --l2_grads 1000 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 \ 15 | --seed 0 --mode train 16 | -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner_xil_plot_imgs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-xil-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | 12 | CUDA_VISIBLE_DEVICES=$DEVICE python train_clevr_hans_concept_learner_xil.py --data-dir $DATA --dataset $DATASET \ 13 | --epochs 100 --name $MODEL --lr 0.001 --l2_grads 1000 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 \ 14 | --seed 0 --mode plot \ 15 | --fp-ckpt runs/conf_3/concept-learner-xil-17-conf_3_seed0/model_epoch47_bestvalloss_0.7580.pth -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/scripts/clevr-hans-concept-learner_xil_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # CUDA DEVICE ID 4 | DEVICE=$1 5 | NUM=$2 6 | DATA=$3 7 | MODEL="concept-learner-xil-$NUM" 8 | DATASET=clevr-hans-state 9 | 10 | #-------------------------------------------------------------------------------# 11 | # CLEVR-Hans3 12 | 13 | python train_clevr_hans_concept_learner_xil.py --data-dir $DATA --dataset $DATASET --epochs 50 --name $MODEL \ 14 | --lr 0.001 --l2_grads 1000 --batch-size 128 --n-slots 10 --n-iters-slot-att 3 --n-attr 18 --seed 0 --mode test \ 15 | --fp-ckpt runs/conf_3/concept-learner-xil-17-conf_3_seed0/model_epoch47_bestvalloss_0.7580.pth 16 | 17 | -------------------------------------------------------------------------------- /src/clevr_hans/nesy_xil/xil_losses.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Code for rrr loss. 3 | @author: Wolfgang Stammer and Patrick Schramowski 4 | @date: 11.05.2020 5 | """ 6 | 7 | import torch 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import scipy.ndimage as ndimage 11 | from captum.attr import IntegratedGradients 12 | 13 | 14 | def resize_mask_to_convoutput(mask, output): 15 | return torch.tensor(ndimage.zoom(mask.cpu(), (1, 1, output.shape[-2]/mask.shape[-2], output.shape[-1]/mask.shape[-1]))) 16 | 17 | 18 | log_softmax = torch.nn.LogSoftmax(dim=1).cuda() 19 | 20 | 21 | def rrr_loss_function(A, X, y, logits, criterion, l2_grads=1000, reduce_func=torch.sum): 22 | log_prob_ys = log_softmax(logits) 23 | right_answer_loss = criterion(log_prob_ys, y) 24 | 25 | gradXes = torch.autograd.grad(log_prob_ys, X, torch.ones_like(log_prob_ys).cuda(), create_graph=True)[0] 26 | 27 | for _ in range(len(gradXes.shape) - len(A.shape)): 28 | A = A.unsqueeze(dim=1) 29 | expand_list = [-1]*len(A.shape) 30 | expand_list[-3] = gradXes.shape[-3] 31 | A = A.expand(expand_list) 32 | A_gradX = torch.mul(A, gradXes) ** 2 33 | 34 | # right_reason_loss = l2_grads * torch.sum(A_gradX) 35 | right_reason_loss = torch.sum(A_gradX, dim=list(range(1, len(A_gradX.shape)))) 36 | right_reason_loss = reduce_func(right_reason_loss) 37 | right_reason_loss *= l2_grads 38 | 39 | res = right_answer_loss + right_reason_loss 40 | return res, right_answer_loss, right_reason_loss 41 | 42 | 43 | class LexiLoss: 44 | def __init__(self, class_weights, args): 45 | self.criterion = torch.nn.MSELoss() 46 | self.args = args 47 | 48 | def norm_saliencies(self, saliencies): 49 | saliencies_norm = saliencies.clone() 50 | 51 | for i in range(saliencies.shape[0]): 52 | if len(torch.nonzero(saliencies[i], as_tuple=False)) == 0: 53 | saliencies_norm[i] = saliencies[i] 54 | else: 55 | saliencies_norm[i] = (saliencies[i] - torch.min(saliencies[i])) / \ 56 | (torch.max(saliencies[i]) - torch.min(saliencies[i])) 57 | 58 | return saliencies_norm 59 | 60 | def generate_intgrad_captum_table(self, net, input, labels): 61 | labels = labels.to("cuda") 62 | explainer = IntegratedGradients(net) 63 | saliencies = explainer.attribute(input, target=labels) 64 | # remove negative attributions 65 | saliencies[saliencies < 0] = 0. 
66 | return self.norm_saliencies(saliencies) 67 | 68 | def __call__(self, net, target_set, class_ids, masks, epoch, batch_id, writer=None, writer_prefix=""): 69 | 70 | masks = masks.squeeze(dim=1).double() 71 | saliencies = self.generate_intgrad_captum_table(net, target_set, class_ids).squeeze(dim=1) 72 | 73 | assert len(saliencies.shape) == 3 74 | assert masks.shape == saliencies.shape 75 | 76 | loss = self.criterion(saliencies, masks) 77 | 78 | return loss 79 | -------------------------------------------------------------------------------- /src/color_mnist/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | import torch.utils.data as utils 6 | 7 | torch.manual_seed(0) 8 | 9 | NUM_WORKERS = 2 10 | 11 | 12 | def get_dataset(root, dataset, batch_size): 13 | if dataset == 'mnist': 14 | trans = ([transforms.ToTensor()]) 15 | trans = transforms.Compose(trans) 16 | fulltrainset = torchvision.datasets.MNIST(root=root, train=True, transform=trans, download=True) 17 | 18 | train_set = fulltrainset 19 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, 20 | num_workers=NUM_WORKERS, pin_memory=True) 21 | #validloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, 22 | # num_workers=NUM_WORKERS, pin_memory=True) 23 | validloader = None 24 | test_set = torchvision.datasets.MNIST(root=root, train=False, transform=trans) 25 | testloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS) 26 | 27 | nb_classes = 10 28 | dim_inp = 28 * 28 29 | elif 'cmnist' in dataset: 30 | data_dir_cmnist = root + dataset + '/' 31 | data_x = np.load(data_dir_cmnist + 'train_x.npy') 32 | data_y = np.load(data_dir_cmnist + 'train_y.npy') 33 | 34 | data_x = torch.from_numpy(data_x).type('torch.FloatTensor') 35 | data_y = torch.from_numpy(data_y).type('torch.LongTensor') 36 | 37 | my_dataset = utils.TensorDataset(data_x, data_y) 38 | 39 | #train_set, valset = _split_train_val(my_dataset, val_fraction=0.1) 40 | train_set = my_dataset 41 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS) 42 | #validloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, 43 | # num_workers=NUM_WORKERS, pin_memory=True) 44 | validloader = None 45 | data_x = np.load(data_dir_cmnist + 'test_x.npy') 46 | data_y = np.load(data_dir_cmnist + 'test_y.npy') 47 | data_x = torch.from_numpy(data_x).type('torch.FloatTensor') 48 | data_y = torch.from_numpy(data_y).type('torch.LongTensor') 49 | my_dataset = utils.TensorDataset(data_x, data_y) 50 | testloader = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS) 51 | 52 | nb_classes = 10 53 | dim_inp = 28 * 28 * 3 54 | else: 55 | raise ValueError('unknown dataset') 56 | return trainloader, validloader, testloader, nb_classes, dim_inp 57 | 58 | 59 | # if __name__ == '__main__': 60 | # import matplotlib.pyplot as plt 61 | # #trainloader, validloader, testloader, nb_classes, dim_inp = get_dataset('./.data', 'fgbg_cmnist_cpr0.5-0.5', 1) 62 | # trainloader, validloader, testloader, nb_classes, dim_inp = get_dataset('./.data', 'fgbg_cmnist_cpr0.1-0.1-0.1-0.1-0.1-0.1-0.1-0.1-0.1-0.1', 1) 63 | # #trainloader, validloader, testloader, nb_classes, dim_inp = get_dataset('./.data', 'fgbg_cmnist_cpr1', 1) 64 | # 
65 | # for i in range(10): 66 | # cnt = 0 67 | # for sample, target in trainloader: 68 | # if target.item() == i: 69 | # plt.figure() 70 | # sample = np.transpose(sample.cpu().detach().numpy().squeeze(), (1,2,0)) 71 | # plt.imshow(sample) 72 | # cnt += 1 73 | # if cnt == 4: 74 | # cnt = 0 75 | # break 76 | # plt.show() 77 | -------------------------------------------------------------------------------- /src/color_mnist/data/generate_color_mnist.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | import numpy as np 4 | import torchvision.transforms as transforms 5 | import tqdm 6 | import os 7 | from colour import Color 8 | 9 | data_path = './data/' 10 | 11 | trans = ([transforms.ToTensor()]) 12 | trans = transforms.Compose(trans) 13 | fulltrainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=trans) 14 | trainloader = torch.utils.data.DataLoader(fulltrainset, batch_size=2000, shuffle=False, num_workers=2, pin_memory=True) 15 | test_set = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=trans) 16 | testloader = torch.utils.data.DataLoader(test_set, batch_size=2000, shuffle=False, num_workers=2, pin_memory=True) 17 | nb_classes = 10 18 | 19 | 20 | red = Color("red") 21 | COLORS = list(red.range_to(Color("purple"),10)) 22 | COLORS = [np.asarray(x.get_rgb())*255 for x in COLORS] 23 | COLORS = [x.astype('int') for x in COLORS] 24 | COLOR_NAMES = ['red', 'tangerine', 'lime', 'harlequin', 'malachite', 'persian green', 25 | 'allports', 'resolution blue', 'pigment indigo', 'purple'] 26 | 27 | 28 | def gen_fgcolor_data(loader, img_size=(3, 28, 28), split='train'): 29 | tot_iters = len(loader) 30 | for i in tqdm.tqdm(range(tot_iters), total=tot_iters): 31 | x, targets = next(iter(loader)) 32 | assert len( 33 | x.size()) == 4, 'Something is wrong, size of input x should be 4 dimensional (B x C x H x W; perhaps number of channels is degenerate? If so, it should be 1)' 34 | targets = targets.cpu().numpy() 35 | bs = targets.shape[0] 36 | 37 | x_rgb = torch.ones(x.size(0), 3, x.size()[2], x.size()[3]).type('torch.FloatTensor') 38 | x_rgb = x_rgb * x 39 | x_rgb_fg = 1. * x_rgb 40 | 41 | if split == 'test': 42 | color_choice = np.random.randint(0, 10, targets.shape[0]) 43 | elif split == 'train': 44 | color_choice = targets.astype('int') 45 | c = np.array([COLORS[ind] for ind in color_choice]) 46 | c = c.reshape(-1, 3, 1, 1) 47 | c = torch.from_numpy(c).type('torch.FloatTensor') 48 | x_rgb_fg[:, 0] = x_rgb_fg[:, 0] * c[:, 0] 49 | x_rgb_fg[:, 1] = x_rgb_fg[:, 1] * c[:, 1] 50 | x_rgb_fg[:, 2] = x_rgb_fg[:, 2] * c[:, 2] 51 | 52 | bg = (torch.zeros_like(x_rgb)) 53 | x_rgb = x_rgb_fg + bg 54 | x_rgb = torch.clamp(x_rgb, 0., 255.) 55 | 56 | if i == 0: 57 | color_data_x = np.zeros((bs * tot_iters, *img_size)) 58 | color_data_y = np.zeros((bs * tot_iters,)) 59 | 60 | color_data_x[i * bs: (i + 1) * bs] = x_rgb / 255.
61 | color_data_y[i * bs: (i + 1) * bs] = targets 62 | return color_data_x, color_data_y 63 | 64 | 65 | dir_name = data_path + '/cmnist/' 66 | print(dir_name) 67 | if not os.path.exists(data_path + 'cmnist/'): 68 | os.makedirs(data_path + 'cmnist/', exist_ok=True) 69 | if not os.path.exists(dir_name): 70 | os.makedirs(dir_name, exist_ok=True) 71 | 72 | color_data_x, color_data_y = gen_fgcolor_data(trainloader, img_size=(3, 28, 28), split='train') 73 | np.save(dir_name + 'train_x.npy', color_data_x) 74 | np.save(dir_name + 'train_y.npy', color_data_y) 75 | 76 | color_data_x, color_data_y = gen_fgcolor_data(testloader, img_size=(3, 28, 28), split='test') 77 | np.save(dir_name + 'test_x.npy', color_data_x) 78 | np.save(dir_name + 'test_y.npy', color_data_y) 79 | -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/events.out.tfevents.1603288758.2b005b7ead9e.36225.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/events.out.tfevents.1603288758.2b005b7ead9e.36225.0 -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_0.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_1.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_10.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_11.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_12.png -------------------------------------------------------------------------------- 
/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_13.npy -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_13.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_13_true_0_pred_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_13_true_0_pred_9.npy -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_14.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_15.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_2.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_3.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_4.png -------------------------------------------------------------------------------- 
/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_5.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_6.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_7.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_8.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_expl_9.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_0_true_7_pred_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_0_true_7_pred_3.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_10_true_0_pred_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_10_true_0_pred_0.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_11_true_6_pred_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_11_true_6_pred_4.png 
-------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_12_true_9_pred_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_12_true_9_pred_0.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_13_true_0_pred_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_13_true_0_pred_9.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_14_true_1_pred_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_14_true_1_pred_8.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_15_true_5_pred_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_15_true_5_pred_2.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_1_true_2_pred_8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_1_true_2_pred_8.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_2_true_1_pred_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_2_true_1_pred_1.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_3_true_0_pred_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_3_true_0_pred_6.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_4_true_4_pred_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_4_true_4_pred_2.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_5_true_1_pred_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_5_true_1_pred_4.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_6_true_4_pred_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_6_true_4_pred_5.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_7_true_9_pred_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_7_true_9_pred_2.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_8_true_5_pred_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_8_true_5_pred_6.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_9_true_9_pred_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_img_9_true_9_pred_3.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_imgs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/test_imgs.png -------------------------------------------------------------------------------- /src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/train_imgs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/figures/train_imgs.png -------------------------------------------------------------------------------- 
/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/mnist_cnn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/color_mnist/runs/default_9c9a152c-13a5-11eb-bbe2-0242ac150002/mnist_cnn.pt -------------------------------------------------------------------------------- /src/color_mnist/scripts/color_mnist_plot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # to be run as: bash color_mnist_plot.sh 0 4 | # (for cuda device 0) 5 | 6 | # CUDA DEVICE ID 7 | DEVICE=$1 8 | 9 | #-------------------------------------------------------------------------------# 10 | # Create plot results from trained model 11 | CUDA_VISIBLE_DEVICES=$DEVICE python train.py --mode plot --fp-ckpt runs//mnist_cnn.pt --batch-size 128 -------------------------------------------------------------------------------- /src/color_mnist/scripts/color_mnist_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # to be run as: bash color_mnist_train.sh 0 4 | # (for cuda device 0) 5 | 6 | # CUDA DEVICE ID 7 | DEVICE=$1 8 | 9 | #-------------------------------------------------------------------------------# 10 | # Train on ColorMNIST 11 | CUDA_VISIBLE_DEVICES=$DEVICE python train.py --mode train -------------------------------------------------------------------------------- /src/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.03-py3 2 | WORKDIR /workspace 3 | COPY requirements.txt requirements.txt 4 | RUN pip install -r requirements.txt 5 | RUN ["apt-get", "update"] 6 | RUN ["apt-get", "install", "-y", "zsh"] 7 | RUN wget https://github.com/robbyrussell/oh-my-zsh/raw/master/tools/install.sh -O - | zsh || true 8 | COPY captum_xil /workspace/captum_xil 9 | RUN cd captum_xil/ && pip install -e . && cd .. 10 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data(setup_file="../setup.py", from_recipe_dir=True) %} 2 | 3 | package: 4 | name: {{ data.get("name")|lower }} 5 | version: {{ data.get("version") }} 6 | 7 | source: 8 | path: ..
9 | 10 | build: 11 | noarch: python 12 | script: "$PYTHON setup.py install --single-version-externally-managed --record=record.txt" 13 | 14 | requirements: 15 | host: 16 | - python>=3.6 17 | run: 18 | - numpy 19 | - pytorch>=1.2 20 | - matplotlib 21 | 22 | test: 23 | imports: 24 | - captum 25 | 26 | about: 27 | home: https://captum.ai 28 | license: BSD 29 | license_file: LICENSE 30 | summary: Model interpretability for PyTorch 31 | doc_url: https://captum.ai 32 | dev_url: https://github.com/pytorch/captum 33 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.gitattributes: -------------------------------------------------------------------------------- 1 | tutorials/* linguist-documentation 2 | *.pt binary 3 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.github/ISSUE_TEMPLATE/---bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug report" 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🐛 Bug 11 | 12 | 13 | 14 | ## To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 18 | 1. 19 | 1. 20 | 1. 21 | 22 | 23 | 24 | ## Expected behavior 25 | 26 | 27 | 28 | ## Environment 29 | Describe the environment used for Captum 30 | 31 | ``` 32 | 33 | - Captum / PyTorch Version (e.g., 1.0 / 0.4.0): 34 | - OS (e.g., Linux): 35 | - How you installed Captum / PyTorch (`conda`, `pip`, source): 36 | - Build command you used (if compiling from source): 37 | - Python version: 38 | - CUDA/cuDNN version: 39 | - GPU models and configuration: 40 | - Any other relevant information: 41 | 42 | ## Additional context 43 | 44 | 45 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.github/ISSUE_TEMPLATE/---documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Documentation" 3 | about: Describe this issue template's purpose here. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 📚 Documentation 11 | 12 | 15 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.github/ISSUE_TEMPLATE/---feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🚀 Feature 11 | 12 | 13 | ## Motivation 14 | 15 | 16 | 17 | ## Pitch 18 | 19 | 20 | 21 | ## Alternatives 22 | 23 | 24 | 25 | ## Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.github/ISSUE_TEMPLATE/--questions-and-help.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓ Questions and Help" 3 | about: Describe this issue template's purpose here. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## ❓ Questions and Help 11 | 12 | We have a set of listed resources available on the website and FAQ: https://captum.ai/ and https://captum.ai/docs/faq . 
Feel free to open an issue here on the github or in our discussion forums: 13 | 14 | - [Discussion Forum](https://discuss.pytorch.org/c/captum) 15 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Downloaded data 7 | data/ 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | captum/insights/frontend/node_modules/ 19 | captum/insights/frontend/.pnp/ 20 | captum/insights/frontend/build/ 21 | captum/insights/widget/static/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # watchman 34 | # Watchman helper 35 | .watchmanconfig 36 | 37 | # OSX 38 | *.DS_Store 39 | 40 | # Atom plugin files and ctags 41 | .ftpconfig 42 | .ftpconfig.cson 43 | .ftpignore 44 | *.tags 45 | *.tags1 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | captum/insights/frontend/npm-debug.log* 57 | captum/insights/frontend/yarn-debug.log* 58 | captum/insights/frontend/yarn-error.log* 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | captum/insights/frontend/coverage 63 | .tox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # mypy 91 | .mypy_cache/ 92 | 93 | # vim 94 | *.swp 95 | 96 | # Sphinx documentation 97 | sphinx/build/ 98 | 99 | # Docusaurus 100 | website/build/ 101 | website/i18n/ 102 | website/node_modules/ 103 | 104 | ## Generated for tutorials 105 | website/_tutorials/ 106 | website/static/files/ 107 | website/pages/tutorials/* 108 | !website/pages/tutorials/index.js 109 | 110 | ## Generated for Sphinx 111 | website/pages/api/ 112 | website/static/js/* 113 | !website/static/js/mathjax.js 114 | !website/static/js/code_block_buttons.js 115 | website/static/_sphinx-sources/ 116 | node_modules 117 | -------------------------------------------------------------------------------- /src/docker/captum_xil/.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output=3 3 | include_trailing_comma=True 4 | force_grid_wrap=0 5 | use_parentheses=True 6 | line_length=88 7 | known_third_party=pytext,torchvision,bs4,torch 8 | -------------------------------------------------------------------------------- /src/docker/captum_xil/CITATION: -------------------------------------------------------------------------------- 1 | @misc{captum2019github, 2 | author = {Kokhlikyan, Narine and Miglani, Vivek and Martin, Miguel and Wang, Edward and Reynolds, Jonathan and Melnikov, Alexander and Lunova, Natalia and Reblitz-Richardson, Orion}, 3 | title = {PyTorch Captum}, 4 | year = {2019}, 5 | publisher = {GitHub}, 6 | journal = {GitHub repository}, 7 | howpublished = {\url{https://github.com/pytorch/captum}}, 8 | 
} 9 | -------------------------------------------------------------------------------- /src/docker/captum_xil/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /src/docker/captum_xil/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Captum 2 | 3 | We want to make contributing to Captum as easy and transparent as possible. 4 | 5 | 6 | ## Development installation 7 | 8 | To get the development installation with all the necessary dependencies for 9 | linting, testing, and building the documentation, run the following: 10 | ```bash 11 | git clone https://github.com/pytorch/captum.git 12 | cd captum 13 | pip install -e .[dev] 14 | ``` 15 | 16 | 17 | ## Our Development Process 18 | 19 | #### Code Style 20 | 21 | Captum uses [black](https://github.com/ambv/black) and [flake8](https://github.com/PyCQA/flake8) to 22 | enforce a common code style across the code base. black and flake8 are installed easily via 23 | pip using `pip install black flake8`, and run locally by calling 24 | ```bash 25 | black . 26 | flake8 . 27 | ``` 28 | from the repository root. No additional configuration should be needed (see the 29 | [black documentation](https://black.readthedocs.io/en/stable/installation_and_usage.html#usage) 30 | for advanced usage). 31 | 32 | Captum also uses [isort](https://github.com/timothycrosley/isort) to sort imports 33 | alphabetically and separate into sections. isort is installed easily via 34 | pip using `pip install isort`, and run locally by calling 35 | ```bash 36 | isort 37 | ``` 38 | from the repository root. Configuration for isort is located in .isort.cfg. 39 | 40 | We feel strongly that having a consistent code style is extremely important, so 41 | CircleCI will fail on your PR if it does not adhere to the black or flake8 formatting style or isort import ordering. 42 | 43 | 44 | #### Type Hints 45 | 46 | Captum is fully typed using python 3.6+ 47 | [type hints](https://www.python.org/dev/peps/pep-0484/). 48 | We expect any contributions to also use proper type annotations, and we enforce 49 | consistency of these in our continuous integration tests. 50 | 51 | To type check your code locally, install [mypy](https://github.com/python/mypy), 52 | which can be done with pip using `pip install "mypy>=0.760"` 53 | Then run this script from the repository root: 54 | ```bash 55 | ./scripts/run_mypy.sh 56 | ``` 57 | Note that we expect mypy to have version 0.760 or higher, and when type checking, use PyTorch 1.4 or 58 | higher due to fixes to PyTorch type hints available in 1.4. We also use the Literal feature which is 59 | available only in Python 3.8 or above. If type-checking using a previous version of Python, you will 60 | need to install the typing-extension package which can be done with pip using `pip install typing-extensions`. 61 | 62 | #### Unit Tests 63 | 64 | To run the unit tests, you can either use `pytest` (if installed): 65 | ```bash 66 | pytest -ra 67 | ``` 68 | or python's `unittest`: 69 | ```bash 70 | python -m unittest 71 | ``` 72 | 73 | To get coverage reports we recommend using the `pytest-cov` plugin: 74 | ```bash 75 | pytest -ra --cov=. 
--cov-report term-missing 76 | ``` 77 | 78 | 79 | #### Documentation 80 | 81 | Captum's website is also open source, and is part of this very repository (the 82 | code can be found in the [website](/website/) folder). 83 | It is built using [Docusaurus](https://docusaurus.io/), and consists of three 84 | main elements: 85 | 86 | 1. The documentation in Docusaurus itself (if you know Markdown, you can 87 | already contribute!). This lives in the [docs](/docs/). 88 | 2. The API reference, auto-generated from the docstrings using 89 | [Sphinx](http://www.sphinx-doc.org), and embedded into the Docusaurus website. 90 | The sphinx .rst source files for this live in [sphinx/source](/sphinx/source/). 91 | 3. The Jupyter notebook tutorials, parsed by `nbconvert`, and embedded into the 92 | Docusaurus website. These live in [tutorials](/tutorials/). 93 | 94 | To build the documentation you will need [Node](https://nodejs.org/en/) >= 8.x 95 | and [Yarn](https://yarnpkg.com/en/) >= 1.5. 96 | 97 | The following command will both build the docs and serve the site locally: 98 | ```bash 99 | ./scripts/build_docs.sh 100 | ``` 101 | 102 | ## Pull Requests 103 | We actively welcome your pull requests. 104 | 105 | 1. Fork the repo and create your branch from `master`. 106 | 2. If you have added code that should be tested, add unit tests. 107 | In other words, add unit tests. 108 | 3. If you have changed APIs, update the documentation. Make sure the 109 | documentation builds. 110 | 4. Ensure the test suite passes. 111 | 5. Make sure your code passes both `black` and `flake8` formatting checks. 112 | 113 | 114 | ## Issues 115 | 116 | We use GitHub issues to track public bugs. Please ensure your description is 117 | clear and has sufficient instructions to be able to reproduce the issue. 118 | 119 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 120 | disclosure of security bugs. In those cases, please go through the process 121 | outlined on that page and do not file a public issue. 122 | 123 | 124 | ## License 125 | 126 | By contributing to Captum, you agree that your contributions will be licensed 127 | under the LICENSE file in the root directory of this source tree. 128 | -------------------------------------------------------------------------------- /src/docker/captum_xil/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, PyTorch team 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | __version__ = "0.2.0" 4 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/attr/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from ._core.deep_lift import DeepLift, DeepLiftShap # noqa 4 | from ._core.feature_ablation import FeatureAblation # noqa 5 | from ._core.feature_permutation import FeaturePermutation # noqa 6 | from ._core.gradient_shap import GradientShap # noqa 7 | from ._core.guided_backprop_deconvnet import Deconvolution # noqa 8 | from ._core.guided_backprop_deconvnet import GuidedBackprop 9 | from ._core.guided_grad_cam import GuidedGradCam # noqa 10 | from ._core.input_x_gradient import InputXGradient # noqa 11 | from ._core.integrated_gradients import IntegratedGradients # noqa 12 | from ._core.layer.grad_cam import LayerGradCam # noqa 13 | from ._core.layer.internal_influence import InternalInfluence # noqa 14 | from ._core.layer.layer_activation import LayerActivation # noqa 15 | from ._core.layer.layer_conductance import LayerConductance # noqa 16 | from ._core.layer.layer_deep_lift import LayerDeepLift # noqa 17 | from ._core.layer.layer_deep_lift import LayerDeepLiftShap 18 | from ._core.layer.layer_gradient_shap import LayerGradientShap # noqa 19 | from ._core.layer.layer_gradient_x_activation import LayerGradientXActivation # noqa 20 | from ._core.layer.layer_integrated_gradients import LayerIntegratedGradients # noqa 21 | from ._core.neuron.neuron_conductance import NeuronConductance # noqa 22 | from ._core.neuron.neuron_deep_lift import NeuronDeepLift # noqa 23 | from ._core.neuron.neuron_deep_lift import NeuronDeepLiftShap 24 | from ._core.neuron.neuron_feature_ablation import NeuronFeatureAblation # noqa 25 | from ._core.neuron.neuron_gradient import NeuronGradient # noqa 26 | from ._core.neuron.neuron_gradient_shap import NeuronGradientShap # noqa 27 | from ._core.neuron.neuron_guided_backprop_deconvnet import ( # noqa 28 | NeuronDeconvolution, 29 | NeuronGuidedBackprop, 30 | ) 31 | from ._core.neuron.neuron_integrated_gradients import NeuronIntegratedGradients # noqa 32 | from ._core.noise_tunnel import NoiseTunnel # noqa 33 | from ._core.occlusion import Occlusion # noqa 34 | from ._core.saliency import Saliency # noqa 35 | from ._core.shapley_value import ShapleyValues, ShapleyValueSampling # noqa 36 | from ._models.base import InterpretableEmbeddingBase # noqa 37 | from ._models.base import ( 38 | TokenReferenceBase, 39 | configure_interpretable_embedding_layer, 40 | remove_interpretable_embedding_layer, 41 | ) 42 | from ._utils import visualization # noqa 43 | from ._utils.attribution import Attribution # noqa 44 | from ._utils.attribution 
import GradientAttribution # noqa 45 | from ._utils.attribution import LayerAttribution # noqa 46 | from ._utils.attribution import NeuronAttribution # noqa 47 | from ._utils.attribution import PerturbationAttribution # noqa 48 | from ._utils.stat import MSE, Count, Max, Mean, Min, StdDev, Sum, Var 49 | from ._utils.summarizer import CommonSummarizer, Summarizer 50 | 51 | __all__ = [ 52 | "Attribution", 53 | "GradientAttribution", 54 | "PerturbationAttribution", 55 | "NeuronAttribution", 56 | "LayerAttribution", 57 | "IntegratedGradients", 58 | "DeepLift", 59 | "DeepLiftShap", 60 | "InputXGradient", 61 | "Saliency", 62 | "GuidedBackprop", 63 | "Deconvolution", 64 | "GuidedGradCam", 65 | "FeatureAblation", 66 | "FeaturePermutation", 67 | "Occlusion", 68 | "ShapleyValueSampling", 69 | "ShapleyValues", 70 | "LayerConductance", 71 | "LayerGradientXActivation", 72 | "LayerActivation", 73 | "InternalInfluence", 74 | "LayerGradCam", 75 | "LayerDeepLift", 76 | "LayerDeepLiftShap", 77 | "LayerGradientShap", 78 | "LayerIntegratedGradients", 79 | "NeuronConductance", 80 | "NeuronFeatureAblation", 81 | "NeuronGradient", 82 | "NeuronIntegratedGradients", 83 | "NeuronDeepLift", 84 | "NeuronDeepLiftShap", 85 | "NeuronGradientShap", 86 | "NeuronDeconvolution", 87 | "NeuronGuidedBackprop", 88 | "NoiseTunnel", 89 | "GradientShap", 90 | "InterpretableEmbeddingBase", 91 | "TokenReferenceBase", 92 | "visualization", 93 | "configure_interpretable_embedding_layer", 94 | "remove_interpretable_embedding_layer", 95 | "Summarizer", 96 | "CommonSummarizer", 97 | "Mean", 98 | "StdDev", 99 | "MSE", 100 | "Var", 101 | "Min", 102 | "Max", 103 | "Sum", 104 | "Count", 105 | ] 106 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/attr/_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/captum/attr/_core/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/attr/_core/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/captum/attr/_core/layer/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/attr/_core/neuron/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/captum/attr/_core/neuron/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/attr/_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/captum/attr/_models/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/attr/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/captum/attr/_utils/__init__.py 
-------------------------------------------------------------------------------- /src/docker/captum_xil/captum/attr/_utils/typing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import TYPE_CHECKING, List, Tuple, TypeVar, Union 4 | 5 | from torch import Tensor 6 | 7 | if TYPE_CHECKING: 8 | import sys 9 | 10 | if sys.version_info >= (3, 8): 11 | from typing import Literal # noqa: F401 12 | else: 13 | from typing_extensions import Literal # noqa: F401 14 | else: 15 | Literal = {True: bool, False: bool, (True, False): bool} 16 | 17 | TensorOrTupleOfTensorsGeneric = TypeVar( 18 | "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] 19 | ) 20 | TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool) 21 | TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]] 22 | BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]] 23 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/__init__.py: -------------------------------------------------------------------------------- 1 | from .api import AttributionVisualizer, Batch, FilterConfig # noqa 2 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/captum/insights/_utils/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/_utils/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Callable, List, Optional, Union 4 | 5 | 6 | def format_transforms( 7 | transforms: Optional[Union[Callable, List[Callable]]] 8 | ) -> List[Callable]: 9 | if transforms is None: 10 | return [] 11 | if callable(transforms): 12 | return [transforms] 13 | return transforms 14 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from typing import List, NamedTuple, Optional, Tuple 3 | 4 | from captum.attr import ( 5 | Deconvolution, 6 | DeepLift, 7 | GuidedBackprop, 8 | InputXGradient, 9 | IntegratedGradients, 10 | Saliency, 11 | ) 12 | from captum.attr._utils.approximation_methods import SUPPORTED_METHODS 13 | 14 | 15 | class NumberConfig(NamedTuple): 16 | value: int = 1 17 | limit: Tuple[Optional[int], Optional[int]] = (None, None) 18 | type: str = "number" 19 | 20 | 21 | class StrEnumConfig(NamedTuple): 22 | value: str 23 | limit: List[str] 24 | type: str = "enum" 25 | 26 | 27 | SUPPORTED_ATTRIBUTION_METHODS = [ 28 | Deconvolution, 29 | DeepLift, 30 | GuidedBackprop, 31 | InputXGradient, 32 | IntegratedGradients, 33 | Saliency, 34 | ] 35 | 36 | ATTRIBUTION_NAMES_TO_METHODS = { 37 | # mypy bug - treating it as a type instead of a class 38 | cls.get_name(): cls # type: ignore 39 | for cls in SUPPORTED_ATTRIBUTION_METHODS 40 | } 41 | 42 | ATTRIBUTION_METHOD_CONFIG = { 43 | IntegratedGradients.get_name(): { 44 | "n_steps": NumberConfig(value=25, limit=(2, None)), 45 | "method": 
StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"), 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | 9 | from captum.insights import AttributionVisualizer, Batch 10 | from captum.insights.features import ImageFeature 11 | 12 | 13 | def get_classes(): 14 | classes = [ 15 | "Plane", 16 | "Car", 17 | "Bird", 18 | "Cat", 19 | "Deer", 20 | "Dog", 21 | "Frog", 22 | "Horse", 23 | "Ship", 24 | "Truck", 25 | ] 26 | return classes 27 | 28 | 29 | def get_pretrained_model(): 30 | class Net(nn.Module): 31 | def __init__(self): 32 | super(Net, self).__init__() 33 | self.conv1 = nn.Conv2d(3, 6, 5) 34 | self.pool1 = nn.MaxPool2d(2, 2) 35 | self.pool2 = nn.MaxPool2d(2, 2) 36 | self.conv2 = nn.Conv2d(6, 16, 5) 37 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 38 | self.fc2 = nn.Linear(120, 84) 39 | self.fc3 = nn.Linear(84, 10) 40 | self.relu1 = nn.ReLU() 41 | self.relu2 = nn.ReLU() 42 | self.relu3 = nn.ReLU() 43 | self.relu4 = nn.ReLU() 44 | 45 | def forward(self, x): 46 | x = self.pool1(self.relu1(self.conv1(x))) 47 | x = self.pool2(self.relu2(self.conv2(x))) 48 | x = x.view(-1, 16 * 5 * 5) 49 | x = self.relu3(self.fc1(x)) 50 | x = self.relu4(self.fc2(x)) 51 | x = self.fc3(x) 52 | return x 53 | 54 | net = Net() 55 | pt_path = os.path.abspath( 56 | os.path.dirname(__file__) + "/models/cifar_torchvision.pt" 57 | ) 58 | net.load_state_dict(torch.load(pt_path)) 59 | return net 60 | 61 | 62 | def baseline_func(input): 63 | return input * 0 64 | 65 | 66 | def formatted_data_iter(): 67 | dataset = torchvision.datasets.CIFAR10( 68 | root="data/test", train=False, download=True, transform=transforms.ToTensor() 69 | ) 70 | dataloader = iter( 71 | torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2) 72 | ) 73 | while True: 74 | images, labels = next(dataloader) 75 | yield Batch(inputs=images, labels=labels) 76 | 77 | 78 | if __name__ == "__main__": 79 | normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 80 | model = get_pretrained_model() 81 | visualizer = AttributionVisualizer( 82 | models=[model], 83 | score_func=lambda o: torch.nn.functional.softmax(o, 1), 84 | classes=get_classes(), 85 | features=[ 86 | ImageFeature( 87 | "Photo", 88 | baseline_transforms=[baseline_func], 89 | input_transforms=[normalize], 90 | ) 91 | ], 92 | dataset=formatted_data_iter(), 93 | ) 94 | 95 | visualizer.serve(debug=True) 96 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/models/cifar_torchvision.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/captum/insights/models/cifar_torchvision.pt -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import logging 3 | import os 4 | import socket 5 | import threading 6 | from time import sleep 7 | from typing import Optional 8 | 9 | from flask import Flask, jsonify, render_template, request 10 | 
from torch import Tensor 11 | 12 | app = Flask( 13 | __name__, static_folder="frontend/build/static", template_folder="frontend/build" 14 | ) 15 | visualizer = None 16 | port = None 17 | 18 | 19 | def namedtuple_to_dict(obj): 20 | if isinstance(obj, Tensor): 21 | return obj.item() 22 | if hasattr(obj, "_asdict"): # detect namedtuple 23 | return dict(zip(obj._fields, (namedtuple_to_dict(item) for item in obj))) 24 | elif isinstance(obj, str): # iterables - strings 25 | return obj 26 | elif hasattr(obj, "keys"): # iterables - mapping 27 | return dict( 28 | zip(obj.keys(), (namedtuple_to_dict(item) for item in obj.values())) 29 | ) 30 | elif hasattr(obj, "__iter__"): # iterables - sequence 31 | return type(obj)((namedtuple_to_dict(item) for item in obj)) 32 | else: # non-iterable cannot contain namedtuples 33 | return obj 34 | 35 | 36 | @app.route("/attribute", methods=["POST"]) 37 | def attribute(): 38 | # force=True needed for Colab notebooks, which doesn't use the correct 39 | # Content-Type header when forwarding requests through the Colab proxy 40 | r = request.get_json(force=True) 41 | return jsonify( 42 | namedtuple_to_dict( 43 | visualizer._calculate_attribution_from_cache(r["instance"], r["labelIndex"]) 44 | ) 45 | ) 46 | 47 | 48 | @app.route("/fetch", methods=["POST"]) 49 | def fetch(): 50 | # force=True needed, see comment for "/attribute" route above 51 | visualizer._update_config(request.get_json(force=True)) 52 | visualizer_output = visualizer.visualize() 53 | clean_output = namedtuple_to_dict(visualizer_output) 54 | return jsonify(clean_output) 55 | 56 | 57 | @app.route("/init") 58 | def init(): 59 | return jsonify(visualizer.get_insights_config()) 60 | 61 | 62 | @app.route("/") 63 | def index(id=0): 64 | return render_template("index.html") 65 | 66 | 67 | def get_free_tcp_port(): 68 | tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 69 | tcp.bind(("", 0)) 70 | addr, port = tcp.getsockname() 71 | tcp.close() 72 | return port 73 | 74 | 75 | def run_app(debug: bool = True): 76 | app.run(port=port, use_reloader=False, debug=debug) 77 | 78 | 79 | def start_server( 80 | _viz, blocking: bool = False, debug: bool = False, _port: Optional[int] = None 81 | ): 82 | global visualizer 83 | visualizer = _viz 84 | 85 | global port 86 | if port is None: 87 | os.environ["WERKZEUG_RUN_MAIN"] = "true" # hides starting message 88 | if not debug: 89 | log = logging.getLogger("werkzeug") 90 | log.disabled = True 91 | app.logger.disabled = True 92 | 93 | port = _port or get_free_tcp_port() 94 | # Start in a new thread to not block notebook execution 95 | t = threading.Thread(target=run_app, kwargs={"debug": debug}) 96 | t.start() 97 | sleep(0.01) # add a short delay to allow server to start up 98 | if blocking: 99 | t.join() 100 | 101 | print(f"\nFetch data and view Captum Insights at http://localhost:{port}/\n") 102 | return port 103 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/widget/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__, version_info # noqa 2 | from .widget import * # noqa 3 | 4 | 5 | def _jupyter_nbextension_paths(): 6 | return [ 7 | { 8 | "section": "notebook", 9 | "src": "static", 10 | "dest": "jupyter-captum-insights", 11 | "require": "jupyter-captum-insights/extension", 12 | } 13 | ] 14 | -------------------------------------------------------------------------------- 
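Editorial aside (not a file from the repository): the `namedtuple_to_dict` helper defined in `captum/insights/server.py` above is what converts the visualizer's namedtuple outputs into JSON-serializable dicts for the Flask routes. A minimal, self-contained sketch of its behavior follows; the `Row` namedtuple and its values are made up for illustration, and the snippet assumes only that `torch` and `captum` (with Flask, which Captum Insights uses) are installed.

```python
# Illustration only -- hypothetical data, not code from this repository.
from collections import namedtuple

import torch

from captum.insights.server import namedtuple_to_dict

# Stand-in for the namedtuples the visualizer returns (an assumption for this demo).
Row = namedtuple("Row", ["label", "score"])

batch = [
    Row(label="Cat", score=torch.tensor(0.83)),  # 0-dim tensors are converted via .item()
    Row(label="Dog", score=torch.tensor(0.17)),
]

# Namedtuples become dicts, the surrounding list is preserved, tensors become floats:
# [{'label': 'Cat', 'score': 0.83}, {'label': 'Dog', 'score': 0.17}] (up to float rounding)
print(namedtuple_to_dict(batch))
```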
/src/docker/captum_xil/captum/insights/widget/_version.py: -------------------------------------------------------------------------------- 1 | version_info = (0, 1, 0, "alpha", 0) 2 | 3 | _specifier_ = {"alpha": "a", "beta": "b", "candidate": "rc", "final": ""} 4 | 5 | __version__ = "%s.%s.%s%s" % ( 6 | version_info[0], 7 | version_info[1], 8 | version_info[2], 9 | "" 10 | if version_info[3] == "final" 11 | else _specifier_[version_info[3]] + str(version_info[4]), 12 | ) 13 | -------------------------------------------------------------------------------- /src/docker/captum_xil/captum/insights/widget/widget.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import ipywidgets as widgets 3 | from traitlets import Dict, Instance, List, Unicode, observe 4 | 5 | from captum.insights import AttributionVisualizer 6 | from captum.insights.server import namedtuple_to_dict 7 | 8 | 9 | @widgets.register 10 | class CaptumInsights(widgets.DOMWidget): 11 | """A widget for interacting with Captum Insights.""" 12 | 13 | _view_name = Unicode("CaptumInsightsView").tag(sync=True) 14 | _model_name = Unicode("CaptumInsightsModel").tag(sync=True) 15 | _view_module = Unicode("jupyter-captum-insights").tag(sync=True) 16 | _model_module = Unicode("jupyter-captum-insights").tag(sync=True) 17 | _view_module_version = Unicode("^0.1.0").tag(sync=True) 18 | _model_module_version = Unicode("^0.1.0").tag(sync=True) 19 | 20 | visualizer = Instance(klass=AttributionVisualizer) 21 | 22 | insights_config = Dict().tag(sync=True) 23 | label_details = Dict().tag(sync=True) 24 | attribution = Dict().tag(sync=True) 25 | config = Dict().tag(sync=True) 26 | output = List().tag(sync=True) 27 | 28 | def __init__(self, **kwargs): 29 | super(CaptumInsights, self).__init__(**kwargs) 30 | self.insights_config = self.visualizer.get_insights_config() 31 | self.out = widgets.Output() 32 | with self.out: 33 | print("Captum Insights widget created.") 34 | 35 | @observe("config") 36 | def _fetch_data(self, change): 37 | if not self.config: 38 | return 39 | with self.out: 40 | self.visualizer._update_config(self.config) 41 | self.output = namedtuple_to_dict(self.visualizer.visualize()) 42 | self.config = dict() 43 | 44 | @observe("label_details") 45 | def _fetch_attribution(self, change): 46 | if not self.label_details: 47 | return 48 | with self.out: 49 | self.attribution = namedtuple_to_dict( 50 | self.visualizer._calculate_attribution_from_cache( 51 | self.label_details["instance"], self.label_details["labelIndex"] 52 | ) 53 | ) 54 | self.label_details = dict() 55 | -------------------------------------------------------------------------------- /src/docker/captum_xil/docs/README.md: -------------------------------------------------------------------------------- 1 | This directory contains the source files for Captum's Docusaurus documentation. 2 | See the repo's [README](../README.md) for additional information. 3 | -------------------------------------------------------------------------------- /src/docker/captum_xil/docs/captum_insights.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: captum_insights 3 | title: Captum Insights 4 | --- 5 | 6 | Interpreting model output in complex models can be difficult. Even with interpretability libraries like Captum, it can be difficult to understand models without proper visualizations. 
Image and text input features can be especially difficult to understand without these visualizations. 7 | 8 | Captum Insights is an interpretability visualization widget built on top of Captum to facilitate model understanding. Captum Insights works across images, text, and other features to help users understand feature attribution. Some examples of the widget are below. 9 | 10 | Getting started with Captum Insights is easy. You can learn how to use Captum Insights with the [Getting started with Captum Insights](/tutorials/CIFAR_TorchVision_Captum_Insights) tutorial. Alternatively, to analyze a sample model on CIFAR10 via Captum Insights execute the line below and navigate to the URL specified in the output. 11 | 12 | ``` 13 | python -m captum.insights.example 14 | ``` 15 | 16 | 17 | Below are some sample screenshots of Captum Insights: 18 | 19 | ![screenshot1](/img/captum_insights_screenshot.png) 20 | 21 | ![screenshot2](/img/captum_insights_screenshot_vqa.png) 22 | -------------------------------------------------------------------------------- /src/docker/captum_xil/docs/getting_started.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: getting_started 3 | title: Getting Started 4 | --- 5 | 6 | This section shows you how to get up and running with Captum. 7 | 8 | 9 | ## Installing Captum 10 | 11 | #### Installation Requirements: 12 | 13 | - Python >= 3.6 14 | - PyTorch >= 1.2 15 | - numpy 16 | 17 | Captum is easily installed via 18 | [Anaconda](https://www.anaconda.com/distribution/#download-section) (recommended) 19 | or `pip`: 20 | 21 | 22 | 23 | ```bash 24 | conda install captum -c pytorch 25 | ``` 26 | 27 | ```bash 28 | pip install captum 29 | ``` 30 | 31 | 32 | For more detailed installation instructions, please see the 33 | [Project Readme](https://github.com/pytorch/captum/blob/master/README.md) 34 | on GitHub. 35 | 36 | 37 | ## Tutorials 38 | 39 | We have several tutorials to help get you off the ground with Captum. The tutorials are Jupyter notebooks and cover the basics along with demonstrating usage of Captum with models of different modalities. 40 | 41 | View the tutorials page [here](../tutorials/). 42 | 43 | 44 | ## API Reference 45 | 46 | For an in-depth reference of the various Captum internals, see our 47 | [API Reference](../api). 48 | 49 | 50 | ## Contributing 51 | 52 | You'd like to contribute to Captum? Great! Please see 53 | [here](https://github.com/pytorch/captum/blob/master/CONTRIBUTING.md) 54 | for how to help out. 55 | -------------------------------------------------------------------------------- /src/docker/captum_xil/docs/introduction.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: introduction 3 | title: Introduction 4 | --- 5 | Captum (“comprehension” in Latin) is an open source, extensible library for model interpretability built on PyTorch. 6 | 7 | With the increase in model complexity and the resulting lack of transparency, model interpretability methods have become increasingly important. Model understanding is both an active area of research as well as an area of focus for practical applications across industries using machine learning. Captum provides state-of-the-art algorithms, including Integrated Gradients, to provide researchers and developers with an easy way to understand which features are contributing to a model’s output. 
8 | 9 | Captum helps ML researchers more easily implement interpretability algorithms that can interact with PyTorch models. It also allows researchers to quickly benchmark their work against other existing algorithms available in the library. 10 | 11 | For model developers, Captum can be used to improve and troubleshoot models by facilitating the identification of different features that contribute to a model’s output in order to design better models and troubleshoot unexpected model outputs. 12 | 13 | ## Target Audience 14 | 15 | The primary audiences for Captum are model developers who are looking to improve their models and understand which features are important and interpretability researchers focused on identifying algorithms that can better interpret many types of models. 16 | 17 | Captum can also be used by application engineers who are using trained models in production. Captum provides easier troubleshooting through improved model interpretability, and the potential for delivering better explanations to end users on why they’re seeing a specific piece of content, such as a movie recommendation. 18 | -------------------------------------------------------------------------------- /src/docker/captum_xil/docs/overview.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: overview 3 | title: Overview 4 | --- 5 | 6 | This overview describes the basic components of Captum and how they work 7 | together. For a high-level view of what Captum tries to achieve in more 8 | abstract terms, please see the [Introduction](introduction). 9 | 10 | 11 | TODO: Add content here 12 | -------------------------------------------------------------------------------- /src/docker/captum_xil/docs/presentations/Captum_NeurIPS_2019_final.key: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/docs/presentations/Captum_NeurIPS_2019_final.key -------------------------------------------------------------------------------- /src/docker/captum_xil/environment.yml: -------------------------------------------------------------------------------- 1 | name: captum 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - numpy 6 | - pytorch>=1.2 7 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/build_docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run this script from the project root using `./scripts/build_docs.sh` 4 | 5 | usage() { 6 | echo "Usage: $0 [-b]" 7 | echo "" 8 | echo "Build Captum documentation." 9 | echo "" 10 | echo " -b Build static version of documentation (otherwise start server)" 11 | echo "" 12 | exit 1 13 | } 14 | 15 | BUILD_STATIC=false 16 | 17 | while getopts 'hb' flag; do 18 | case "${flag}" in 19 | h) 20 | usage 21 | ;; 22 | b) 23 | BUILD_STATIC=true 24 | ;; 25 | *) 26 | usage 27 | ;; 28 | esac 29 | done 30 | 31 | echo "-----------------------------------" 32 | echo "Generating API reference via Sphinx" 33 | echo "-----------------------------------" 34 | cd sphinx || exit 35 | make html 36 | cd .. 
|| exit 37 | 38 | echo "-----------------------------------" 39 | echo "Building Captum Docusaurus site" 40 | echo "-----------------------------------" 41 | cd website || exit 42 | yarn 43 | 44 | # run script to parse html generated by sphinx 45 | echo "--------------------------------------------" 46 | echo "Parsing Sphinx docs and moving to Docusaurus" 47 | echo "--------------------------------------------" 48 | cd .. 49 | mkdir -p "website/pages/api/" 50 | 51 | cwd=$(pwd) 52 | python scripts/parse_sphinx.py -i "${cwd}/sphinx/build/html/" -o "${cwd}/website/pages/api/" 53 | 54 | SPHINX_JS_DIR='sphinx/build/html/_static/' 55 | DOCUSAURUS_JS_DIR='website/static/js/' 56 | 57 | mkdir -p $DOCUSAURUS_JS_DIR 58 | 59 | # move JS files from /sphinx/build/html/_static/*: 60 | cp "${SPHINX_JS_DIR}documentation_options.js" "${DOCUSAURUS_JS_DIR}documentation_options.js" 61 | cp "${SPHINX_JS_DIR}jquery.js" "${DOCUSAURUS_JS_DIR}jquery.js" 62 | cp "${SPHINX_JS_DIR}underscore.js" "${DOCUSAURUS_JS_DIR}underscore.js" 63 | cp "${SPHINX_JS_DIR}doctools.js" "${DOCUSAURUS_JS_DIR}doctools.js" 64 | cp "${SPHINX_JS_DIR}language_data.js" "${DOCUSAURUS_JS_DIR}language_data.js" 65 | cp "${SPHINX_JS_DIR}searchtools.js" "${DOCUSAURUS_JS_DIR}searchtools.js" 66 | 67 | # searchindex.js is not static util 68 | cp "sphinx/build/html/searchindex.js" "${DOCUSAURUS_JS_DIR}searchindex.js" 69 | 70 | # copy module sources 71 | cp -r "sphinx/build/html/_sources/" "website/static/_sphinx-sources/" 72 | 73 | echo "-----------------------------------" 74 | echo "Generating tutorials" 75 | echo "-----------------------------------" 76 | mkdir -p "website/_tutorials" 77 | mkdir -p "website/static/files" 78 | python scripts/parse_tutorials.py -w "${cwd}" 79 | 80 | cd website || exit 81 | 82 | if [[ $BUILD_STATIC == true ]]; then 83 | echo "-----------------------------------" 84 | echo "Building static site" 85 | echo "-----------------------------------" 86 | yarn build 87 | else 88 | echo "-----------------------------------" 89 | echo "Starting local server" 90 | echo "-----------------------------------" 91 | yarn start 92 | fi 93 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/build_insights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This builds the Captum Insights React-based UI using yarn. 4 | # 5 | # Run this script from the project root using `./scripts/build_insights.sh` 6 | 7 | set -e 8 | 9 | usage() { 10 | echo "Usage: $0" 11 | echo "" 12 | echo "Build Captum Insights." 13 | echo "" 14 | exit 1 15 | } 16 | 17 | while getopts 'h' flag; do 18 | case "${flag}" in 19 | h) 20 | usage 21 | ;; 22 | *) 23 | usage 24 | ;; 25 | esac 26 | done 27 | 28 | # check if yarn is installed 29 | if [ ! -x "$(command -v yarn)" ]; then 30 | echo "" 31 | echo "Please install yarn with 'conda install -c conda-forge yarn' or equivalent." 
32 | echo "" 33 | exit 1 34 | fi 35 | 36 | # go into subdir 37 | pushd captum/insights/frontend || exit 38 | 39 | echo "-----------------------------------" 40 | echo "Install Dependencies" 41 | echo "-----------------------------------" 42 | yarn install 43 | 44 | echo "-----------------------------------" 45 | echo "Building Captum Insights" 46 | echo "-----------------------------------" 47 | yarn build 48 | 49 | pushd widget || exit 50 | 51 | echo "-----------------------------------" 52 | echo "Building Captum Insights widget" 53 | echo "-----------------------------------" 54 | 55 | ../node_modules/.bin/webpack-cli --config webpack.config.js 56 | 57 | # exit subdir 58 | popd || exit 59 | 60 | # exit subdir 61 | popd || exit 62 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/install_via_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | PYTORCH_NIGHTLY=false 6 | 7 | while getopts 'nf' flag; do 8 | case "${flag}" in 9 | n) PYTORCH_NIGHTLY=true ;; 10 | f) FRAMEWORKS=true ;; 11 | *) echo "usage: $0 [-n] [-f]" >&2 12 | exit 1 ;; 13 | esac 14 | done 15 | 16 | # update conda 17 | # removing due to setuptools error during update 18 | #conda update -y -n base -c defaults conda 19 | 20 | # required to use conda develop 21 | conda install -y conda-build 22 | 23 | # install yarn for insights build 24 | conda install -y -c conda-forge yarn 25 | 26 | # install other frameworks if asked for and make sure this is before pytorch 27 | if [[ $FRAMEWORKS == true ]]; then 28 | pip install pytext-nlp 29 | fi 30 | 31 | if [[ $PYTORCH_NIGHTLY == true ]]; then 32 | # install CPU version for much smaller download 33 | conda install -y pytorch cpuonly -c pytorch-nightly 34 | else 35 | # install CPU version for much smaller download 36 | conda install -y -c pytorch pytorch-cpu 37 | fi 38 | 39 | # install other deps 40 | conda install -y numpy sphinx pytest flake8 ipywidgets ipython 41 | conda install -y -c conda-forge black matplotlib pytest-cov sphinx-autodoc-typehints mypy flask isort 42 | 43 | # build insights and install captum 44 | # TODO: remove CI=false when we want React warnings treated as errors 45 | CI=false BUILD_INSIGHTS=1 python setup.py develop 46 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/install_via_pip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | PYTORCH_NIGHTLY=false 6 | DEPLOY=false 7 | CHOSEN_TORCH_VERSION=-1 8 | 9 | while getopts 'ndfv:' flag; do 10 | case "${flag}" in 11 | n) PYTORCH_NIGHTLY=true ;; 12 | d) DEPLOY=true ;; 13 | f) FRAMEWORKS=true ;; 14 | v) CHOSEN_TORCH_VERSION=${OPTARG};; 15 | *) echo "usage: $0 [-n] [-d] [-f] [-v version]" >&2 16 | exit 1 ;; 17 | esac 18 | done 19 | 20 | # NOTE: Only Debian variants are supported, since this script is only 21 | # used by our tests on CircleCI. In the future we might generalize, 22 | # but users should hopefully be using conda installs. 
23 | 24 | # install nodejs and yarn for insights build 25 | sudo apt install apt-transport-https ca-certificates 26 | curl -sL https://deb.nodesource.com/setup_10.x | sudo -E bash - 27 | curl -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - 28 | echo "deb https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list 29 | sudo apt update 30 | sudo apt install nodejs 31 | sudo apt install yarn 32 | 33 | # yarn needs terminal info 34 | export TERM=xterm 35 | 36 | # NOTE: All of the below installs use sudo, b/c otherwise pip will get 37 | # permission errors installing in the docker container. An alternative would be 38 | # to use a virtualenv, but that would lead to bifurcation of the CircleCI config 39 | # since we'd need to source the environemnt in each step. 40 | 41 | # upgrade pip 42 | sudo pip install --upgrade pip 43 | 44 | # install captum with dev deps 45 | sudo pip install -e .[dev] 46 | sudo BUILD_INSIGHTS=1 python setup.py develop 47 | 48 | # install other frameworks if asked for and make sure this is before pytorch 49 | if [[ $FRAMEWORKS == true ]]; then 50 | sudo pip install pytext-nlp 51 | fi 52 | 53 | # install pytorch nightly if asked for 54 | if [[ $PYTORCH_NIGHTLY == true ]]; then 55 | sudo pip install --upgrade --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html 56 | else 57 | # If no version specified, upgrade to latest release. 58 | if [[ $CHOSEN_TORCH_VERSION == -1 ]]; then 59 | sudo pip install --upgrade torch 60 | else 61 | sudo pip install torch==$CHOSEN_TORCH_VERSION 62 | fi 63 | fi 64 | 65 | # install deployment bits if asked for 66 | if [[ $DEPLOY == true ]]; then 67 | sudo pip install beautifulsoup4 ipython nbconvert 68 | fi 69 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/parse_sphinx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | 6 | from bs4 import BeautifulSoup 7 | 8 | js_scripts = """ 9 | 11 | 12 | 13 | 14 | 15 | 16 | """ # noqa: E501 17 | 18 | search_js_scripts = """ 19 | 22 | 23 | 24 | """ 25 | 26 | 27 | def parse_sphinx(input_dir, output_dir): 28 | for cur, _, files in os.walk(input_dir): 29 | for fname in files: 30 | if fname.endswith(".html"): 31 | with open(os.path.join(cur, fname), "r") as f: 32 | soup = BeautifulSoup(f.read(), "html.parser") 33 | doc = soup.find("div", {"class": "document"}) 34 | wrapped_doc = doc.wrap(soup.new_tag("div", **{"class": "sphinx"})) 35 | # add js 36 | if fname == "search.html": 37 | out = js_scripts + search_js_scripts + str(wrapped_doc) 38 | else: 39 | out = js_scripts + str(wrapped_doc) 40 | output_path = os.path.join(output_dir, os.path.relpath(cur, input_dir)) 41 | os.makedirs(output_path, exist_ok=True) 42 | with open(os.path.join(output_path, fname), "w") as fout: 43 | fout.write(out) 44 | 45 | # update reference in JS file 46 | with open(os.path.join(input_dir, "_static/searchtools.js"), "r") as js_file: 47 | js = js_file.read() 48 | js = js.replace( 49 | "DOCUMENTATION_OPTIONS.URL_ROOT + '_sources/'", "'_sphinx-sources/'" 50 | ) 51 | with open(os.path.join(input_dir, "_static/searchtools.js"), "w") as js_file: 52 | js_file.write(js) 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description="Strip HTML body from Sphinx docs.") 57 | parser.add_argument( 58 | "-i", 59 | "--input_dir", 60 | metavar="path", 61 | required=True, 62 | help="Input 
directory for Sphinx HTML.", 63 | ) 64 | parser.add_argument( 65 | "-o", 66 | "--output_dir", 67 | metavar="path", 68 | required=True, 69 | help="Output directory in Docusaurus.", 70 | ) 71 | args = parser.parse_args() 72 | parse_sphinx(args.input_dir, args.output_dir) 73 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/parse_tutorials.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import json 5 | import os 6 | 7 | import nbformat 8 | from bs4 import BeautifulSoup 9 | from nbconvert import HTMLExporter, ScriptExporter 10 | 11 | TEMPLATE = """const CWD = process.cwd(); 12 | 13 | const React = require('react'); 14 | const Tutorial = require(`${{CWD}}/core/Tutorial.js`); 15 | 16 | class TutorialPage extends React.Component {{ 17 | render() {{ 18 | const {{config: siteConfig}} = this.props; 19 | const {{baseUrl}} = siteConfig; 20 | return ; 21 | }} 22 | }} 23 | 24 | module.exports = TutorialPage; 25 | 26 | """ 27 | 28 | JS_SCRIPTS = """ 29 | 32 | 35 | """ # noqa: E501 36 | 37 | 38 | def gen_tutorials(repo_dir: str) -> None: 39 | """Generate HTML tutorials for captum Docusaurus site from Jupyter notebooks. 40 | 41 | Also create ipynb and py versions of tutorial in Docusaurus site for 42 | download. 43 | """ 44 | with open(os.path.join(repo_dir, "website", "tutorials.json"), "r") as infile: 45 | tutorial_config = json.loads(infile.read()) 46 | 47 | tutorial_ids = {x["id"] for v in tutorial_config.values() for x in v} 48 | 49 | for tid in tutorial_ids: 50 | print("Generating {} tutorial".format(tid)) 51 | 52 | # convert notebook to HTML 53 | ipynb_in_path = os.path.join(repo_dir, "tutorials", "{}.ipynb".format(tid)) 54 | with open(ipynb_in_path, "r") as infile: 55 | nb_str = infile.read() 56 | nb = nbformat.reads(nb_str, nbformat.NO_CONVERT) 57 | 58 | # displayname is absent from notebook metadata 59 | nb["metadata"]["kernelspec"]["display_name"] = "python3" 60 | 61 | exporter = HTMLExporter() 62 | html, meta = exporter.from_notebook_node(nb) 63 | 64 | # pull out html div for notebook 65 | soup = BeautifulSoup(html, "html.parser") 66 | nb_meat = soup.find("div", {"id": "notebook-container"}) 67 | del nb_meat.attrs["id"] 68 | nb_meat.attrs["class"] = ["notebook"] 69 | html_out = JS_SCRIPTS + str(nb_meat) 70 | 71 | # generate html file 72 | html_out_path = os.path.join( 73 | repo_dir, "website", "_tutorials", "{}.html".format(tid) 74 | ) 75 | with open(html_out_path, "w") as html_outfile: 76 | html_outfile.write(html_out) 77 | 78 | # generate JS file 79 | script = TEMPLATE.format(tid) 80 | js_out_path = os.path.join( 81 | repo_dir, "website", "pages", "tutorials", "{}.js".format(tid) 82 | ) 83 | with open(js_out_path, "w") as js_outfile: 84 | js_outfile.write(script) 85 | 86 | # output tutorial in both ipynb & py form 87 | ipynb_out_path = os.path.join( 88 | repo_dir, "website", "static", "files", "{}.ipynb".format(tid) 89 | ) 90 | with open(ipynb_out_path, "w") as ipynb_outfile: 91 | ipynb_outfile.write(nb_str) 92 | exporter = ScriptExporter() 93 | script, meta = exporter.from_notebook_node(nb) 94 | py_out_path = os.path.join( 95 | repo_dir, "website", "static", "files", "{}.py".format(tid) 96 | ) 97 | with open(py_out_path, "w") as py_outfile: 98 | py_outfile.write(script) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser( 103 | description="Generate JS, HTML, ipynb, and py files for tutorials." 
104 | ) 105 | parser.add_argument( 106 | "-w", "--repo_dir", metavar="path", required=True, help="captum repo directory." 107 | ) 108 | args = parser.parse_args() 109 | gen_tutorials(args.repo_dir) 110 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/run_mypy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # This runs mypy's static type checker on parts of Captum supporting type 5 | # hints. 6 | 7 | mypy -p captum.attr --ignore-missing-imports --allow-redefinition 8 | mypy -p captum.insights --ignore-missing-imports --allow-redefinition 9 | mypy -p tests --ignore-missing-imports --allow-redefinition 10 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/update_versions_html.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import json 5 | 6 | from bs4 import BeautifulSoup 7 | 8 | BASE_URL = "/" 9 | 10 | 11 | def updateVersionHTML(base_path, base_url=BASE_URL): 12 | with open(base_path + "/captum-master/website/_versions.json", "rb") as infile: 13 | versions = json.loads(infile.read()) 14 | 15 | with open(base_path + "/new-site/versions.html", "rb") as infile: 16 | html = infile.read() 17 | 18 | versions.append("latest") 19 | 20 | def prepend_url(a_tag, base_url, version): 21 | href = a_tag.attrs["href"] 22 | if href.startswith("https://") or href.startswith("http://"): 23 | return href 24 | else: 25 | return "{base_url}versions/{version}{original_url}".format( 26 | base_url=base_url, version=version, original_url=href 27 | ) 28 | 29 | for v in versions: 30 | soup = BeautifulSoup(html, "html.parser") 31 | 32 | # title 33 | title_link = soup.find("header").find("a") 34 | title_link.attrs["href"] = prepend_url(title_link, base_url, v) 35 | 36 | # nav 37 | nav_links = soup.find("nav").findAll("a") 38 | for l in nav_links: 39 | l.attrs["href"] = prepend_url(l, base_url, v) 40 | 41 | # version link 42 | t = soup.find("h2", {"class": "headerTitleWithLogo"}).find_next("a") 43 | t.string = v 44 | t.attrs["href"] = prepend_url(t, base_url, v) 45 | 46 | # output files 47 | with open( 48 | base_path + "/new-site/versions/{}/versions.html".format(v), "w" 49 | ) as outfile: 50 | outfile.write(str(soup)) 51 | with open( 52 | base_path + "/new-site/versions/{}/en/versions.html".format(v), "w" 53 | ) as outfile: 54 | outfile.write(str(soup)) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser( 59 | description=( 60 | "Fix links in version.html files for Docusaurus site." 61 | "This is used to ensure that the versions.js for older " 62 | "versions in versions subdirectory are up-to-date and " 63 | "will have a way to navigate back to newer versions." 64 | ) 65 | ) 66 | parser.add_argument( 67 | "-p", 68 | "--base_path", 69 | metavar="path", 70 | required=True, 71 | help="Input directory for rolling out new version of site.", 72 | ) 73 | args = parser.parse_args() 74 | updateVersionHTML(args.base_path) 75 | -------------------------------------------------------------------------------- /src/docker/captum_xil/scripts/versions.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2019-present, Facebook, Inc. 
3 | * 4 | * This source code is licensed under the BSD license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | * 7 | * @format 8 | */ 9 | 10 | const React = require('react'); 11 | 12 | const CompLibrary = require('../../core/CompLibrary'); 13 | 14 | const Container = CompLibrary.Container; 15 | 16 | const CWD = process.cwd(); 17 | 18 | const versions = require(`${CWD}/_versions.json`); 19 | versions.sort().reverse(); 20 | 21 | function Versions(props) { 22 | const {config: siteConfig} = props; 23 | const baseUrl = siteConfig.baseUrl; 24 | const latestVersion = versions[0]; 25 | return ( 26 |
27 | 28 |
29 |
30 |

{siteConfig.title} Versions

31 |
32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 45 | 48 | 49 | 50 | 54 | 59 | 62 | 63 | 64 |
VersionInstall withDocumentation
{`stable (${latestVersion})`} 43 | conda install -c pytorch captum 44 | 46 | stable 47 |
51 | {'latest'} 52 | {' (master)'} 53 | 55 | 56 | pip install git+ssh://git@github.com/pytorch/captum.git 57 | 58 | 60 | latest 61 |
65 | 66 |

Past Versions

67 | 68 | 69 | {versions.map( 70 | version => 71 | version !== latestVersion && ( 72 | 73 | 74 | 79 | 80 | ), 81 | )} 82 | 83 |
{version} 75 | 76 | Documentation 77 | 78 |
84 |
85 |
86 |
87 | ); 88 | } 89 | 90 | module.exports = Versions; 91 | -------------------------------------------------------------------------------- /src/docker/captum_xil/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E203: black and flake8 disagree on whitespace before ':' 3 | # W503: black and flake8 disagree on how to place operators 4 | ignore = E203, W503 5 | max-line-length = 88 6 | exclude = 7 | build, dist, tutorials, website 8 | 9 | [coverage:report] 10 | omit = 11 | test/* 12 | setup.py 13 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/README.md: -------------------------------------------------------------------------------- 1 | # sphinx API reference 2 | 3 | This file describes the sphinx setup for auto-generating the API reference. 4 | 5 | 6 | ## Installation 7 | 8 | **Requirements**: 9 | - sphinx >= 2.0 10 | - sphinx_autodoc_typehints 11 | 12 | You can install these via `pip install sphinx sphinx_autodoc_typehints`. 13 | 14 | 15 | ## Building 16 | 17 | From the `captum/sphinx` directory, run `make html`. 18 | 19 | Generated HTML output can be found in the `captum/sphinx/build` directory. The main index page is: `captum/sphinx/build/html/index.html` 20 | 21 | 22 | ## Structure 23 | 24 | `source/index.rst` contains the main index. The API reference for each module lives in its own file, e.g. `models.rst` for the `captum.models` module. 25 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/approximation_methods.rst: -------------------------------------------------------------------------------- 1 | Captum Approximation 2 | ==================== 3 | 4 | .. automodule:: captum.attr._utils.approximation_methods 5 | 6 | .. autoclass:: Riemann 7 | :members: 8 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/attribution.rst: -------------------------------------------------------------------------------- 1 | Attribution 2 | ==================== 3 | .. toctree:: 4 | 5 | integrated_gradients 6 | saliency 7 | deep_lift 8 | deep_lift_shap 9 | gradient_shap 10 | input_x_gradient 11 | guided_backprop 12 | guided_grad_cam 13 | deconvolution 14 | feature_ablation 15 | occlusion 16 | feature_permutation 17 | shapley_value_sampling 18 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/base_classes.rst: -------------------------------------------------------------------------------- 1 | Base Classes 2 | ========== 3 | 4 | Attribution 5 | ^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.attr.Attribution 8 | :members: 9 | 10 | Layer Attribution 11 | ^^^^^^^^^^^^^^^^^ 12 | 13 | .. autoclass:: captum.attr.LayerAttribution 14 | :members: 15 | 16 | Neuron Attribution 17 | ^^^^^^^^^^^^^^^^^^^ 18 | 19 | .. autoclass:: captum.attr.NeuronAttribution 20 | :members: 21 | 22 | Gradient Attribution 23 | ^^^^^^^^^^^^^^^^^^^^^ 24 | 25 | .. autoclass:: captum.attr.GradientAttribution 26 | :members: 27 | 28 | Perturbation Attribution 29 | ^^^^^^^^^^^^^^^^^^^^^ 30 | 31 | .. autoclass:: captum.attr.PerturbationAttribution 32 | :members: 33 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/common.rst: -------------------------------------------------------------------------------- 1 | Captum.Utils 2 | ============ 3 | 4 | .. automodule:: captum.attr._utils.common 5 | 6 | .. autofunction:: validate_input 7 | .. autofunction:: validate_noise_tunnel_type 8 | .. autofunction:: format_input 9 | .. autofunction:: _format_attributions 10 | .. autofunction:: zeros 11 | .. autofunction:: _reshape_and_sum 12 | .. autofunction:: _run_forward 13 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/deconvolution.rst: -------------------------------------------------------------------------------- 1 | Deconvolution 2 | ========= 3 | 4 | .. autoclass:: captum.attr.Deconvolution 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/deep_lift.rst: -------------------------------------------------------------------------------- 1 | DeepLift 2 | ========= 3 | 4 | .. 
autoclass:: captum.attr.DeepLift 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/deep_lift_shap.rst: -------------------------------------------------------------------------------- 1 | DeepLiftShap 2 | ============= 3 | 4 | .. autoclass:: captum.attr.DeepLiftShap 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/feature_ablation.rst: -------------------------------------------------------------------------------- 1 | Feature Ablation 2 | ========= 3 | 4 | .. autoclass:: captum.attr.FeatureAblation 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/feature_permutation.rst: -------------------------------------------------------------------------------- 1 | Feature Permutation 2 | ========= 3 | 4 | .. autoclass:: captum.attr.FeaturePermutation 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/gradient_shap.rst: -------------------------------------------------------------------------------- 1 | GradientShap 2 | ============ 3 | 4 | .. autoclass:: captum.attr.GradientShap 5 | :members: 6 | 7 | .. autoclass:: captum.attr.InputBaselineXGradient 8 | :members: 9 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/guided_backprop.rst: -------------------------------------------------------------------------------- 1 | Guided Backprop 2 | ========= 3 | 4 | .. autoclass:: captum.attr.GuidedBackprop 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/guided_grad_cam.rst: -------------------------------------------------------------------------------- 1 | Guided GradCAM 2 | ========= 3 | 4 | .. autoclass:: captum.attr.GuidedGradCam 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/pytorch/captum 2 | 3 | Captum API Reference 4 | ==================== 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: API Reference 13 | 14 | attribution 15 | noise_tunnel 16 | layer 17 | neuron 18 | utilities 19 | base_classes 20 | 21 | .. toctree:: 22 | :maxdepth: 2 23 | :caption: Insights API Reference 24 | 25 | insights 26 | 27 | 28 | Indices and Tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/input_x_gradient.rst: -------------------------------------------------------------------------------- 1 | Input X Gradient 2 | =============== 3 | 4 | .. autoclass:: captum.attr.InputXGradient 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/insights.rst: -------------------------------------------------------------------------------- 1 | Insights 2 | ============ 3 | 4 | Batch 5 | ^^^^^ 6 | 7 | .. 
autoclass:: captum.insights.api.Batch 8 | :members: 9 | 10 | AttributionVisualizer 11 | ^^^^^^^^^^^^^^^^^^^^^ 12 | .. autoclass:: captum.insights.api.AttributionVisualizer 13 | :members: 14 | 15 | 16 | Features 17 | ======== 18 | 19 | BaseFeature 20 | ^^^^^^^^^^^ 21 | .. autoclass:: captum.insights.features.BaseFeature 22 | :members: 23 | 24 | GeneralFeature 25 | ^^^^^^^^^^^^^^ 26 | .. autoclass:: captum.insights.features.GeneralFeature 27 | 28 | TextFeature 29 | ^^^^^^^^^^^ 30 | .. autoclass:: captum.insights.features.TextFeature 31 | 32 | ImageFeature 33 | ^^^^^^^^^^^^ 34 | .. autoclass:: captum.insights.features.ImageFeature 35 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/integrated_gradients.rst: -------------------------------------------------------------------------------- 1 | Integrated Gradients 2 | ==================== 3 | 4 | .. autoclass:: captum.attr.IntegratedGradients 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/layer.rst: -------------------------------------------------------------------------------- 1 | Layer Attribution 2 | ====== 3 | 4 | Layer Conductance 5 | ^^^^^^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.attr.LayerConductance 8 | :members: 9 | 10 | 11 | Layer Activation 12 | ^^^^^^^^^^^^^^^^ 13 | 14 | .. autoclass:: captum.attr.LayerActivation 15 | :members: 16 | 17 | Internal Influence 18 | ^^^^^^^^^^^^^^^^^ 19 | 20 | .. autoclass:: captum.attr.InternalInfluence 21 | :members: 22 | 23 | Layer Gradient X Activation 24 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 25 | 26 | .. autoclass:: captum.attr.LayerGradientXActivation 27 | :members: 28 | 29 | GradCAM 30 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 31 | 32 | .. autoclass:: captum.attr.LayerGradCam 33 | :members: 34 | 35 | Layer DeepLift 36 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 37 | 38 | .. autoclass:: captum.attr.LayerDeepLift 39 | :members: 40 | 41 | Layer DeepLiftShap 42 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 43 | 44 | .. autoclass:: captum.attr.LayerDeepLiftShap 45 | :members: 46 | 47 | Layer GradientShap 48 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 49 | 50 | .. autoclass:: captum.attr.LayerGradientShap 51 | :members: 52 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/neuron.rst: -------------------------------------------------------------------------------- 1 | Neuron Attribution 2 | ======= 3 | 4 | Neuron Gradient 5 | ^^^^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.attr.NeuronGradient 8 | :members: 9 | 10 | Neuron Integrated Gradients 11 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 12 | 13 | .. autoclass:: captum.attr.NeuronIntegratedGradients 14 | :members: 15 | 16 | Neuron Conductance 17 | ^^^^^^^^^^^^^^^^^^^^^^^^ 18 | 19 | .. autoclass:: captum.attr.NeuronConductance 20 | :members: 21 | 22 | Neuron DeepLift 23 | ^^^^^^^^^^^^^^^ 24 | 25 | .. autoclass:: captum.attr.NeuronDeepLift 26 | :members: 27 | 28 | Neuron DeepLiftShap 29 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 30 | 31 | .. autoclass:: captum.attr.NeuronDeepLiftShap 32 | :members: 33 | 34 | Neuron GradientShap 35 | ^^^^^^^^^^^^^^^^^^^^^^^^ 36 | 37 | .. autoclass:: captum.attr.NeuronGradientShap 38 | :members: 39 | 40 | Neuron Guided Backprop 41 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 42 | 43 | .. autoclass:: captum.attr.NeuronGuidedBackprop 44 | :members: 45 | 46 | Neuron Deconvolution 47 | ^^^^^^^^^^^^^^^^^^^^^^^^ 48 | 49 | .. 
autoclass:: captum.attr.NeuronDeconvolution 50 | :members: 51 | 52 | Neuron Feature Ablation 53 | ^^^^^^^^^^^^^^^^^^^^^^^^ 54 | 55 | .. autoclass:: captum.attr.NeuronFeatureAblation 56 | :members: 57 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/noise_tunnel.rst: -------------------------------------------------------------------------------- 1 | NoiseTunnel 2 | ============ 3 | 4 | .. autoclass:: captum.attr.NoiseTunnel 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/occlusion.rst: -------------------------------------------------------------------------------- 1 | Occlusion 2 | ========= 3 | 4 | .. autoclass:: captum.attr.Occlusion 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/pytext.rst: -------------------------------------------------------------------------------- 1 | Captum.Models 2 | ========================== 3 | 4 | .. automodule:: captum.attr._models.pytext 5 | 6 | .. autoclass:: PyTextInterpretableEmbedding 7 | :members: 8 | 9 | 10 | .. autoclass:: BaselineGenerator 11 | :members: 12 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/saliency.rst: -------------------------------------------------------------------------------- 1 | Saliency 2 | =============== 3 | 4 | .. autoclass:: captum.attr.Saliency 5 | :members: 6 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/shapley_value_sampling.rst: -------------------------------------------------------------------------------- 1 | Shapley Value Sampling 2 | ========= 3 | 4 | .. autoclass:: captum.attr.ShapleyValueSampling 5 | :members: 6 | .. autoclass:: captum.attr.ShapleyValues 7 | :members: 8 | -------------------------------------------------------------------------------- /src/docker/captum_xil/sphinx/source/utilities.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========== 3 | 4 | Visualization 5 | ^^^^^^^^^^^^^^ 6 | 7 | .. autofunction:: captum.attr.visualization.visualize_image_attr 8 | 9 | .. autofunction:: captum.attr.visualization.visualize_image_attr_multiple 10 | 11 | 12 | Interpretable Embeddings 13 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 14 | 15 | .. autoclass:: captum.attr.InterpretableEmbeddingBase 16 | :members: 17 | 18 | .. autofunction:: captum.attr.configure_interpretable_embedding_layer 19 | .. autofunction:: captum.attr.remove_interpretable_embedding_layer 20 | 21 | 22 | Token Reference Base 23 | ^^^^^^^^^^^^^^^^^^^^^ 24 | 25 | .. 
autoclass:: captum.attr.TokenReferenceBase 26 | :members: 27 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tests/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tests/attr/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tests/attr/helpers/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/helpers/classification_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SigmoidModel(nn.Module): 7 | """ 8 | Model architecture from: 9 | https://medium.com/coinmonks/create-a-neural-network-in 10 | -pytorch-and-make-your-life-simpler-ec5367895199 11 | """ 12 | 13 | def __init__(self, num_in, num_hidden, num_out): 14 | super().__init__() 15 | self.num_in = num_in 16 | self.num_hidden = num_hidden 17 | self.num_out = num_out 18 | self.lin1 = nn.Linear(num_in, num_hidden) 19 | self.lin2 = nn.Linear(num_hidden, num_out) 20 | self.relu1 = nn.ReLU() 21 | self.sigmoid = nn.Sigmoid() 22 | 23 | def forward(self, input): 24 | lin1 = self.lin1(input) 25 | lin2 = self.lin2(self.relu1(lin1)) 26 | return self.sigmoid(lin2) 27 | 28 | 29 | class SoftmaxModel(nn.Module): 30 | """ 31 | Model architecture from: 32 | https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/ 33 | """ 34 | 35 | def __init__(self, num_in, num_hidden, num_out, inplace=False): 36 | super().__init__() 37 | self.num_in = num_in 38 | self.num_hidden = num_hidden 39 | self.num_out = num_out 40 | self.lin1 = nn.Linear(num_in, num_hidden) 41 | self.lin2 = nn.Linear(num_hidden, num_hidden) 42 | self.lin3 = nn.Linear(num_hidden, num_out) 43 | self.relu1 = nn.ReLU(inplace=inplace) 44 | self.relu2 = nn.ReLU(inplace=inplace) 45 | self.softmax = nn.Softmax(dim=1) 46 | 47 | def forward(self, input): 48 | lin1 = self.relu1(self.lin1(input)) 49 | lin2 = self.relu2(self.lin2(lin1)) 50 | lin3 = self.lin3(lin2) 51 | return self.softmax(lin3) 52 | 53 | 54 | class SigmoidDeepLiftModel(nn.Module): 55 | """ 56 | Model architecture from: 57 | https://medium.com/coinmonks/create-a-neural-network-in 58 | -pytorch-and-make-your-life-simpler-ec5367895199 59 | """ 60 | 61 | def __init__(self, num_in, num_hidden, num_out): 62 | super().__init__() 63 | self.num_in = num_in 64 | self.num_hidden = num_hidden 65 | self.num_out = num_out 66 | self.lin1 = nn.Linear(num_in, num_hidden, bias=False) 67 | self.lin2 = nn.Linear(num_hidden, num_out, bias=False) 68 | self.lin1.weight = nn.Parameter(torch.ones(num_hidden, num_in)) 69 | self.lin2.weight = nn.Parameter(torch.ones(num_out, num_hidden)) 70 | self.relu1 = 
nn.ReLU() 71 | self.sigmoid = nn.Sigmoid() 72 | 73 | def forward(self, input): 74 | lin1 = self.lin1(input) 75 | lin2 = self.lin2(self.relu1(lin1)) 76 | return self.sigmoid(lin2) 77 | 78 | 79 | class SoftmaxDeepLiftModel(nn.Module): 80 | """ 81 | Model architecture from: 82 | https://adventuresinmachinelearning.com/pytorch-tutorial-deep-learning/ 83 | """ 84 | 85 | def __init__(self, num_in, num_hidden, num_out): 86 | super().__init__() 87 | self.num_in = num_in 88 | self.num_hidden = num_hidden 89 | self.num_out = num_out 90 | self.lin1 = nn.Linear(num_in, num_hidden) 91 | self.lin2 = nn.Linear(num_hidden, num_hidden) 92 | self.lin3 = nn.Linear(num_hidden, num_out) 93 | self.lin1.weight = nn.Parameter(torch.ones(num_hidden, num_in)) 94 | self.lin2.weight = nn.Parameter(torch.ones(num_hidden, num_hidden)) 95 | self.lin3.weight = nn.Parameter(torch.ones(num_out, num_hidden)) 96 | self.relu1 = nn.ReLU() 97 | self.relu2 = nn.ReLU() 98 | self.softmax = nn.Softmax(dim=1) 99 | 100 | def forward(self, input): 101 | lin1 = self.relu1(self.lin1(input)) 102 | lin2 = self.relu2(self.lin2(lin1)) 103 | lin3 = self.lin3(lin2) 104 | return self.softmax(lin3) 105 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/helpers/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import random 3 | import unittest 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def assertArraysAlmostEqual(inputArr, refArr, delta=0.05): 10 | for index, (input, ref) in enumerate(zip(inputArr, refArr)): 11 | almost_equal = abs(input - ref) <= delta 12 | if hasattr(almost_equal, "__iter__"): 13 | almost_equal = almost_equal.all() 14 | assert ( 15 | almost_equal 16 | ), "Values at index {}, {} and {}, \ 17 | differ more than by {}".format( 18 | index, input, ref, delta 19 | ) 20 | 21 | 22 | def assertTensorAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"): 23 | assert isinstance(actual, torch.Tensor), ( 24 | "Actual parameter given for " "comparison must be a tensor." 
25 | ) 26 | actual = actual.squeeze() 27 | if not isinstance(expected, torch.Tensor): 28 | expected = torch.tensor(expected, dtype=actual.dtype) 29 | if mode == "sum": 30 | test.assertAlmostEqual( 31 | torch.sum(torch.abs(actual - expected)).item(), 0.0, delta=delta 32 | ) 33 | elif mode == "max": 34 | test.assertAlmostEqual( 35 | torch.max(torch.abs(actual - expected)).item(), 0.0, delta=delta 36 | ) 37 | else: 38 | raise ValueError("Mode for assertion comparison must be one of `max` or `sum`.") 39 | 40 | 41 | def assertTensorTuplesAlmostEqual(test, actual, expected, delta=0.0001, mode="sum"): 42 | if isinstance(expected, tuple): 43 | for i in range(len(expected)): 44 | assertTensorAlmostEqual(test, actual[i], expected[i], delta, mode) 45 | else: 46 | assertTensorAlmostEqual(test, actual, expected, delta, mode) 47 | 48 | 49 | def assertAttributionComparision(test, attributions1, attributions2): 50 | for attribution1, attribution2 in zip(attributions1, attributions2): 51 | for attr_row1, attr_row2 in zip( 52 | attribution1.detach().numpy(), attribution2.detach().numpy() 53 | ): 54 | if isinstance(attr_row1, np.ndarray): 55 | assertArraysAlmostEqual(attr_row1, attr_row2, delta=0.05) 56 | else: 57 | test.assertAlmostEqual(attr_row1, attr_row2, delta=0.05) 58 | 59 | 60 | def assert_delta(test, delta): 61 | delta_condition = all(abs(delta.numpy().flatten()) < 0.00001) 62 | test.assertTrue( 63 | delta_condition, 64 | "The sum of attribution values {} for relu layer is not " 65 | "nearly equal to the difference between the endpoint for " 66 | "some samples".format(delta), 67 | ) 68 | 69 | 70 | class BaseTest(unittest.TestCase): 71 | """ 72 | This class provides a basic framework for all Captum tests by providing 73 | a set up fixture, which sets a fixed random seed. Since many torch 74 | initializations are random, this ensures that tests run deterministically. 75 | """ 76 | 77 | def setUp(self): 78 | random.seed(1234) 79 | np.random.seed(1234) 80 | torch.manual_seed(1234) 81 | torch.cuda.manual_seed_all(1234) 82 | torch.backends.cudnn.deterministic = True 83 | 84 | 85 | class BaseGPUTest(BaseTest): 86 | """ 87 | This class provides a basic framework for all Captum tests requiring 88 | CUDA and available GPUs to run appropriately, such as tests for 89 | DataParallel models. If CUDA is not available, these tests are skipped. 
90 | """ 91 | 92 | def setUp(self): 93 | super().setUp() 94 | if not torch.cuda.is_available() or torch.cuda.device_count() == 0: 95 | raise unittest.SkipTest("Skipping GPU test since CUDA not available.") 96 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tests/attr/layer/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/layer/test_grad_cam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import unittest 4 | from typing import Any, List, Tuple, Union 5 | 6 | import torch 7 | from torch import Tensor 8 | from torch.nn import Module 9 | 10 | from captum.attr._core.layer.grad_cam import LayerGradCam 11 | 12 | from ..helpers.basic_models import BasicModel_ConvNet_One_Conv, BasicModel_MultiLayer 13 | from ..helpers.utils import BaseTest, assertTensorTuplesAlmostEqual 14 | 15 | 16 | class Test(BaseTest): 17 | def test_simple_input_non_conv(self) -> None: 18 | net = BasicModel_MultiLayer() 19 | inp = torch.tensor([[0.0, 100.0, 0.0]], requires_grad=True) 20 | self._grad_cam_test_assert(net, net.linear0, inp, [400.0]) 21 | 22 | def test_simple_multi_input_non_conv(self) -> None: 23 | net = BasicModel_MultiLayer(multi_input_module=True) 24 | inp = torch.tensor([[0.0, 6.0, 0.0]], requires_grad=True) 25 | self._grad_cam_test_assert(net, net.relu, inp, ([21.0], [21.0])) 26 | 27 | def test_simple_input_conv(self) -> None: 28 | net = BasicModel_ConvNet_One_Conv() 29 | inp = torch.arange(16).view(1, 1, 4, 4).float() 30 | self._grad_cam_test_assert(net, net.conv1, inp, [[11.25, 13.5], [20.25, 22.5]]) 31 | 32 | def test_simple_input_conv_no_grad(self) -> None: 33 | net = BasicModel_ConvNet_One_Conv() 34 | 35 | # this way we deactivate require_grad. Some models explicitly 36 | # do that before interpreting the model. 
37 | for param in net.parameters(): 38 | param.requires_grad = False 39 | 40 | inp = torch.arange(16).view(1, 1, 4, 4).float() 41 | self._grad_cam_test_assert(net, net.conv1, inp, [[11.25, 13.5], [20.25, 22.5]]) 42 | 43 | def test_simple_input_conv_relu(self) -> None: 44 | net = BasicModel_ConvNet_One_Conv() 45 | inp = torch.arange(16).view(1, 1, 4, 4).float() 46 | self._grad_cam_test_assert(net, net.relu1, inp, [[0.0, 4.0], [28.0, 32.5]]) 47 | 48 | def test_simple_input_conv_without_final_relu(self) -> None: 49 | net = BasicModel_ConvNet_One_Conv() 50 | inp = torch.arange(16).view(1, 1, 4, 4).float() 51 | inp.requires_grad_() 52 | # Adding negative value to test final relu is not applied by default 53 | inp[0, 0, 1, 1] = -4.0 54 | self._grad_cam_test_assert( 55 | net, net.conv1, inp, (0.5625 * inp,), attribute_to_layer_input=True 56 | ) 57 | 58 | def test_simple_input_conv_fc_with_final_relu(self) -> None: 59 | net = BasicModel_ConvNet_One_Conv() 60 | inp = torch.arange(16).view(1, 1, 4, 4).float() 61 | inp.requires_grad_() 62 | # Adding negative value to test final relu is applied 63 | inp[0, 0, 1, 1] = -4.0 64 | exp = 0.5625 * inp 65 | exp[0, 0, 1, 1] = 0.0 66 | self._grad_cam_test_assert( 67 | net, 68 | net.conv1, 69 | inp, 70 | (exp,), 71 | attribute_to_layer_input=True, 72 | relu_attributions=True, 73 | ) 74 | 75 | def test_simple_multi_input_conv(self) -> None: 76 | net = BasicModel_ConvNet_One_Conv() 77 | inp = torch.arange(16).view(1, 1, 4, 4).float() 78 | inp2 = torch.ones((1, 1, 4, 4)) 79 | self._grad_cam_test_assert( 80 | net, net.conv1, (inp, inp2), [[14.5, 19.0], [32.5, 37.0]] 81 | ) 82 | 83 | def _grad_cam_test_assert( 84 | self, 85 | model: Module, 86 | target_layer: Module, 87 | test_input: Union[Tensor, Tuple[Tensor, ...]], 88 | expected_activation: Union[ 89 | List[float], Tuple[List[float], ...], List[List[float]], Tuple[Tensor, ...] 
90 | ], 91 | additional_input: Any = None, 92 | attribute_to_layer_input: bool = False, 93 | relu_attributions: bool = False, 94 | ): 95 | layer_gc = LayerGradCam(model, target_layer) 96 | attributions = layer_gc.attribute( 97 | test_input, 98 | target=0, 99 | additional_forward_args=additional_input, 100 | attribute_to_layer_input=attribute_to_layer_input, 101 | relu_attributions=relu_attributions, 102 | ) 103 | assertTensorTuplesAlmostEqual( 104 | self, attributions, expected_activation, delta=0.01 105 | ) 106 | 107 | 108 | if __name__ == "__main__": 109 | unittest.main() 110 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/layer/test_layer_activation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import unittest 4 | from typing import Any, List, Tuple, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | from torch.nn import Module 10 | 11 | from captum.attr._core.layer.layer_activation import LayerActivation 12 | 13 | from ..helpers.basic_models import ( 14 | BasicModel_MultiLayer, 15 | BasicModel_MultiLayer_MultiInput, 16 | ) 17 | from ..helpers.utils import ( 18 | BaseTest, 19 | assertTensorAlmostEqual, 20 | assertTensorTuplesAlmostEqual, 21 | ) 22 | 23 | 24 | class Test(BaseTest): 25 | def test_simple_input_activation(self) -> None: 26 | net = BasicModel_MultiLayer() 27 | inp = torch.tensor([[0.0, 100.0, 0.0]], requires_grad=True) 28 | self._layer_activation_test_assert(net, net.linear0, inp, [0.0, 100.0, 0.0]) 29 | 30 | def test_simple_linear_activation(self) -> None: 31 | net = BasicModel_MultiLayer() 32 | inp = torch.tensor([[0.0, 100.0, 0.0]]) 33 | self._layer_activation_test_assert( 34 | net, net.linear1, inp, [90.0, 101.0, 101.0, 101.0] 35 | ) 36 | 37 | def test_simple_relu_activation_input_inplace(self) -> None: 38 | net = BasicModel_MultiLayer(inplace=True) 39 | inp = torch.tensor([[2.0, -5.0, 4.0]]) 40 | self._layer_activation_test_assert( 41 | net, net.relu, inp, ([-9.0, 2.0, 2.0, 2.0],), attribute_to_layer_input=True 42 | ) 43 | 44 | def test_simple_linear_activation_inplace(self) -> None: 45 | net = BasicModel_MultiLayer(inplace=True) 46 | inp = torch.tensor([[2.0, -5.0, 4.0]]) 47 | self._layer_activation_test_assert(net, net.linear1, inp, [-9.0, 2.0, 2.0, 2.0]) 48 | 49 | def test_simple_relu_activation(self) -> None: 50 | net = BasicModel_MultiLayer() 51 | inp = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True) 52 | self._layer_activation_test_assert(net, net.relu, inp, [0.0, 8.0, 8.0, 8.0]) 53 | 54 | def test_simple_output_activation(self) -> None: 55 | net = BasicModel_MultiLayer() 56 | inp = torch.tensor([[0.0, 100.0, 0.0]]) 57 | self._layer_activation_test_assert(net, net.linear2, inp, [392.0, 394.0]) 58 | 59 | def test_simple_multi_output_activation(self) -> None: 60 | net = BasicModel_MultiLayer(multi_input_module=True) 61 | inp = torch.tensor([[0.0, 6.0, 0.0]]) 62 | self._layer_activation_test_assert( 63 | net, net.relu, inp, ([0.0, 7.0, 7.0, 7.0], [0.0, 7.0, 7.0, 7.0]) 64 | ) 65 | 66 | def test_simple_multi_input_activation(self) -> None: 67 | net = BasicModel_MultiLayer(multi_input_module=True) 68 | inp = torch.tensor([[0.0, 6.0, 0.0]]) 69 | self._layer_activation_test_assert( 70 | net, 71 | net.relu, 72 | inp, 73 | ([-4.0, 7.0, 7.0, 7.0], [-4.0, 7.0, 7.0, 7.0]), 74 | attribute_to_layer_input=True, 75 | ) 76 | 77 | def test_simple_multi_input_linear2_activation(self) -> None: 78 | net = 
BasicModel_MultiLayer_MultiInput() 79 | inp1 = torch.tensor([[0.0, 10.0, 0.0]]) 80 | inp2 = torch.tensor([[0.0, 10.0, 0.0]]) 81 | inp3 = torch.tensor([[0.0, 5.0, 0.0]]) 82 | self._layer_activation_test_assert( 83 | net, net.model.linear2, (inp1, inp2, inp3), [392.0, 394.0], (4,) 84 | ) 85 | 86 | def test_simple_multi_input_relu_activation(self) -> None: 87 | net = BasicModel_MultiLayer_MultiInput() 88 | inp1 = torch.tensor([[0.0, 10.0, 1.0]]) 89 | inp2 = torch.tensor([[0.0, 4.0, 5.0]]) 90 | inp3 = torch.tensor([[0.0, 0.0, 0.0]]) 91 | self._layer_activation_test_assert( 92 | net, net.model.relu, (inp1, inp2), [90.0, 101.0, 101.0, 101.0], (inp3, 5) 93 | ) 94 | 95 | def test_sequential_in_place(self) -> None: 96 | model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(inplace=True)) 97 | layer_act = LayerActivation(model, model[0]) 98 | input = torch.randn(1, 3, 5, 5) 99 | assertTensorAlmostEqual(self, layer_act.attribute(input), model[0](input)) 100 | 101 | def _layer_activation_test_assert( 102 | self, 103 | model: Module, 104 | target_layer: Module, 105 | test_input: Union[Tensor, Tuple[Tensor, ...]], 106 | expected_activation: Union[List[float], Tuple[List[float], ...]], 107 | additional_input: Any = None, 108 | attribute_to_layer_input: bool = False, 109 | ): 110 | layer_act = LayerActivation(model, target_layer) 111 | attributions = layer_act.attribute( 112 | test_input, 113 | additional_forward_args=additional_input, 114 | attribute_to_layer_input=attribute_to_layer_input, 115 | ) 116 | assertTensorTuplesAlmostEqual( 117 | self, attributions, expected_activation, delta=0.01 118 | ) 119 | 120 | 121 | if __name__ == "__main__": 122 | unittest.main() 123 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/layer/test_layer_gradient_x_activation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import unittest 3 | from typing import Any, List, Tuple, Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import Module 8 | 9 | from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation 10 | 11 | from ..helpers.basic_models import ( 12 | BasicModel_MultiLayer, 13 | BasicModel_MultiLayer_MultiInput, 14 | ) 15 | from ..helpers.utils import BaseTest, assertTensorTuplesAlmostEqual 16 | 17 | 18 | class Test(BaseTest): 19 | def test_simple_input_gradient_activation(self) -> None: 20 | net = BasicModel_MultiLayer() 21 | inp = torch.tensor([[0.0, 100.0, 0.0]], requires_grad=True) 22 | self._layer_activation_test_assert(net, net.linear0, inp, [0.0, 400.0, 0.0]) 23 | 24 | def test_simple_linear_gradient_activation(self) -> None: 25 | net = BasicModel_MultiLayer() 26 | inp = torch.tensor([[0.0, 100.0, 0.0]]) 27 | self._layer_activation_test_assert( 28 | net, net.linear1, inp, [90.0, 101.0, 101.0, 101.0] 29 | ) 30 | 31 | def test_simple_linear_gradient_activation_no_grad(self) -> None: 32 | net = BasicModel_MultiLayer() 33 | inp = torch.tensor([[0.0, 100.0, 0.0]]) 34 | 35 | # this way we deactivate require_grad. Some models explicitly 36 | # do that before interpreting the model. 
37 | for param in net.parameters(): 38 | param.requires_grad = False 39 | 40 | self._layer_activation_test_assert( 41 | net, net.linear1, inp, [90.0, 101.0, 101.0, 101.0] 42 | ) 43 | 44 | def test_simple_multi_gradient_activation(self) -> None: 45 | net = BasicModel_MultiLayer(multi_input_module=True) 46 | inp = torch.tensor([[3.0, 4.0, 0.0]]) 47 | self._layer_activation_test_assert( 48 | net, net.relu, inp, ([0.0, 8.0, 8.0, 8.0], [0.0, 8.0, 8.0, 8.0]) 49 | ) 50 | 51 | def test_simple_relu_gradient_activation(self) -> None: 52 | net = BasicModel_MultiLayer() 53 | inp = torch.tensor([[3.0, 4.0, 0.0]], requires_grad=True) 54 | self._layer_activation_test_assert(net, net.relu, inp, [0.0, 8.0, 8.0, 8.0]) 55 | 56 | def test_simple_output_gradient_activation(self) -> None: 57 | net = BasicModel_MultiLayer() 58 | inp = torch.tensor([[0.0, 100.0, 0.0]]) 59 | self._layer_activation_test_assert(net, net.linear2, inp, [392.0, 0.0]) 60 | 61 | def test_simple_gradient_activation_multi_input_linear2(self) -> None: 62 | net = BasicModel_MultiLayer_MultiInput() 63 | inp1 = torch.tensor([[0.0, 10.0, 0.0]]) 64 | inp2 = torch.tensor([[0.0, 10.0, 0.0]]) 65 | inp3 = torch.tensor([[0.0, 5.0, 0.0]]) 66 | self._layer_activation_test_assert( 67 | net, net.model.linear2, (inp1, inp2, inp3), [392.0, 0.0], (4,) 68 | ) 69 | 70 | def test_simple_gradient_activation_multi_input_relu(self) -> None: 71 | net = BasicModel_MultiLayer_MultiInput() 72 | inp1 = torch.tensor([[0.0, 10.0, 1.0]]) 73 | inp2 = torch.tensor([[0.0, 4.0, 5.0]]) 74 | inp3 = torch.tensor([[0.0, 0.0, 0.0]]) 75 | self._layer_activation_test_assert( 76 | net, net.model.relu, (inp1, inp2), [90.0, 101.0, 101.0, 101.0], (inp3, 5) 77 | ) 78 | 79 | def _layer_activation_test_assert( 80 | self, 81 | model: Module, 82 | target_layer: Module, 83 | test_input: Union[Tensor, Tuple[Tensor, ...]], 84 | expected_activation: Union[List[float], Tuple[List[float], ...]], 85 | additional_input: Any = None, 86 | ) -> None: 87 | layer_act = LayerGradientXActivation(model, target_layer) 88 | attributions = layer_act.attribute( 89 | test_input, target=0, additional_forward_args=additional_input 90 | ) 91 | assertTensorTuplesAlmostEqual( 92 | self, attributions, expected_activation, delta=0.01 93 | ) 94 | 95 | 96 | if __name__ == "__main__": 97 | unittest.main() 98 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tests/attr/models/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/neuron/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tests/attr/neuron/__init__.py -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/neuron/test_neuron_gradient_shap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from typing import Callable, Union 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.nn import Module 7 | 8 | from captum.attr._core.neuron.neuron_gradient_shap import NeuronGradientShap 9 | from 
captum.attr._core.neuron.neuron_integrated_gradients import ( 10 | NeuronIntegratedGradients, 11 | ) 12 | 13 | from ..helpers.basic_models import BasicModel_MultiLayer 14 | from ..helpers.classification_models import SoftmaxModel 15 | from ..helpers.utils import BaseTest, assertTensorAlmostEqual 16 | 17 | 18 | class Test(BaseTest): 19 | def test_basic_multilayer(self) -> None: 20 | model = BasicModel_MultiLayer(inplace=True) 21 | model.eval() 22 | 23 | inputs = torch.tensor([[1.0, 20.0, 10.0]]) 24 | baselines = torch.randn(2, 3) 25 | 26 | self._assert_attributions(model, model.linear1, inputs, baselines, 0) 27 | 28 | def test_classification(self) -> None: 29 | def custom_baseline_fn(inputs: Tensor) -> Tensor: 30 | num_in = inputs.shape[1] # type: ignore 31 | return torch.arange(0.0, num_in * 5.0).reshape(5, num_in) 32 | 33 | num_in = 40 34 | n_samples = 100 35 | 36 | # 10-class classification model 37 | model = SoftmaxModel(num_in, 20, 10) 38 | model.eval() 39 | 40 | inputs = torch.arange(0.0, num_in * 2.0).reshape(2, num_in) 41 | baselines = custom_baseline_fn 42 | 43 | self._assert_attributions(model, model.relu1, inputs, baselines, 1, n_samples) 44 | 45 | def _assert_attributions( 46 | self, 47 | model: Module, 48 | layer: Module, 49 | inputs: Tensor, 50 | baselines: Union[Tensor, Callable[..., Tensor]], 51 | neuron_ind: Union[int, tuple], 52 | n_samples: int = 5, 53 | ) -> None: 54 | ngs = NeuronGradientShap(model, layer) 55 | nig = NeuronIntegratedGradients(model, layer) 56 | attrs_gs = ngs.attribute( 57 | inputs, neuron_ind, baselines=baselines, n_samples=n_samples, stdevs=0.09 58 | ) 59 | 60 | if callable(baselines): 61 | baselines = baselines(inputs) 62 | 63 | attrs_ig = [] 64 | for baseline in torch.unbind(baselines): 65 | attrs_ig.append( 66 | nig.attribute(inputs, neuron_ind, baselines=baseline.unsqueeze(0)) 67 | ) 68 | combined_attrs_ig = torch.stack(attrs_ig, dim=0).mean(dim=0) 69 | assertTensorAlmostEqual(self, attrs_gs, combined_attrs_ig, 0.5) 70 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/test_approximation_methods.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import unittest 4 | 5 | from captum.attr._utils.approximation_methods import Riemann, riemann_builders 6 | 7 | from .helpers.utils import assertArraysAlmostEqual 8 | 9 | 10 | class Test(unittest.TestCase): 11 | def __init__(self, methodName="runTest"): 12 | super().__init__(methodName) 13 | 14 | def test_riemann_0(self): 15 | with self.assertRaises(AssertionError): 16 | step_sizes, alphas = riemann_builders() 17 | step_sizes(0) 18 | alphas(0) 19 | 20 | def test_riemann_2(self): 21 | expected_step_sizes_lrm = [0.5, 0.5] 22 | expected_step_sizes_trapezoid = [0.25, 0.25] 23 | expected_left = [0.0, 0.5] 24 | expected_right = [0.5, 1.0] 25 | expected_middle = [0.25, 0.75] 26 | expected_trapezoid = [0.0, 1.0] 27 | self._assert_steps_and_alphas( 28 | 2, 29 | expected_step_sizes_lrm, 30 | expected_step_sizes_trapezoid, 31 | expected_left, 32 | expected_right, 33 | expected_middle, 34 | expected_trapezoid, 35 | ) 36 | 37 | def test_riemann_3(self): 38 | expected_step_sizes = [1 / 3] * 3 39 | expected_step_sizes_trapezoid = [1 / 6, 1 / 3, 1 / 6] 40 | expected_left = [0.0, 1 / 3, 2 / 3] 41 | expected_right = [1 / 3, 2 / 3, 1.0] 42 | expected_middle = [1 / 6, 0.5, 1 - 1 / 6] 43 | expected_trapezoid = [0.0, 0.5, 1.0] 44 | self._assert_steps_and_alphas( 45 | 3, 46 | expected_step_sizes, 
47 | expected_step_sizes_trapezoid, 48 | expected_left, 49 | expected_right, 50 | expected_middle, 51 | expected_trapezoid, 52 | ) 53 | 54 | def test_riemann_4(self): 55 | expected_step_sizes = [1 / 4] * 4 56 | expected_step_sizes_trapezoid = [1 / 8, 1 / 4, 1 / 4, 1 / 8] 57 | expected_left = [0.0, 0.25, 0.5, 0.75] 58 | expected_right = [0.25, 0.5, 0.75, 1.0] 59 | expected_middle = [0.125, 0.375, 0.625, 0.875] 60 | expected_trapezoid = [0.0, 1 / 3, 2 / 3, 1.0] 61 | self._assert_steps_and_alphas( 62 | 4, 63 | expected_step_sizes, 64 | expected_step_sizes_trapezoid, 65 | expected_left, 66 | expected_right, 67 | expected_middle, 68 | expected_trapezoid, 69 | ) 70 | 71 | def _assert_steps_and_alphas( 72 | self, 73 | n, 74 | expected_step_sizes, 75 | expected_step_sizes_trapezoid, 76 | expected_left, 77 | expected_right, 78 | expected_middle, 79 | expected_trapezoid, 80 | ): 81 | step_sizes_left, alphas_left = riemann_builders(Riemann.left) 82 | step_sizes_right, alphas_right = riemann_builders(Riemann.right) 83 | step_sizes_middle, alphas_middle = riemann_builders(Riemann.middle) 84 | step_sizes_trapezoid, alphas_trapezoid = riemann_builders(Riemann.trapezoid) 85 | assertArraysAlmostEqual(expected_step_sizes, step_sizes_left(n)) 86 | assertArraysAlmostEqual(expected_step_sizes, step_sizes_right(n)) 87 | assertArraysAlmostEqual(expected_step_sizes, step_sizes_middle(n)) 88 | assertArraysAlmostEqual(expected_step_sizes_trapezoid, step_sizes_trapezoid(n)) 89 | assertArraysAlmostEqual(expected_left, alphas_left(n)) 90 | assertArraysAlmostEqual(expected_right, alphas_right(n)) 91 | assertArraysAlmostEqual(expected_middle, alphas_middle(n)) 92 | assertArraysAlmostEqual(expected_trapezoid, alphas_trapezoid(n)) 93 | 94 | 95 | # TODO write a test case for gauss-legendre 96 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/test_common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import List, Tuple, cast 4 | 5 | import torch 6 | 7 | from captum.attr._core.noise_tunnel import SUPPORTED_NOISE_TUNNEL_TYPES 8 | from captum.attr._utils.common import ( 9 | _select_targets, 10 | _validate_input, 11 | _validate_noise_tunnel_type, 12 | ) 13 | 14 | from .helpers.utils import BaseTest, assertTensorAlmostEqual 15 | 16 | 17 | class Test(BaseTest): 18 | def test_validate_input(self) -> None: 19 | with self.assertRaises(AssertionError): 20 | _validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),)) 21 | _validate_input( 22 | (torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), n_steps=-1 23 | ) 24 | _validate_input( 25 | (torch.tensor([-1.0, 1.0]),), 26 | (torch.tensor([-1.0, 1.0]),), 27 | method="abcde", 28 | ) 29 | _validate_input((torch.tensor([-1.0]),), (torch.tensor([-2.0]),)) 30 | _validate_input( 31 | (torch.tensor([-1.0]),), (torch.tensor([-2.0]),), method="gausslegendre" 32 | ) 33 | 34 | def test_validate_nt_type(self) -> None: 35 | with self.assertRaises(AssertionError): 36 | _validate_noise_tunnel_type("abc", SUPPORTED_NOISE_TUNNEL_TYPES) 37 | _validate_noise_tunnel_type("smoothgrad", SUPPORTED_NOISE_TUNNEL_TYPES) 38 | _validate_noise_tunnel_type("smoothgrad_sq", SUPPORTED_NOISE_TUNNEL_TYPES) 39 | _validate_noise_tunnel_type("vargrad", SUPPORTED_NOISE_TUNNEL_TYPES) 40 | 41 | def test_select_target_2d(self) -> None: 42 | output_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 43 | assertTensorAlmostEqual(self, 
_select_targets(output_tensor, 1), [2, 5, 8]) 44 | assertTensorAlmostEqual( 45 | self, _select_targets(output_tensor, torch.tensor(0)), [1, 4, 7] 46 | ) 47 | assertTensorAlmostEqual( 48 | self, _select_targets(output_tensor, torch.tensor([1, 2, 0])), [2, 6, 7] 49 | ) 50 | assertTensorAlmostEqual( 51 | self, _select_targets(output_tensor, [1, 2, 0]), [2, 6, 7] 52 | ) 53 | 54 | # Verify error is raised if too many dimensions are provided. 55 | with self.assertRaises(AssertionError): 56 | _select_targets(output_tensor, (1, 2)) 57 | 58 | def test_select_target_3d(self) -> None: 59 | output_tensor = torch.tensor( 60 | [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[9, 8, 7], [6, 5, 4], [3, 2, 1]]] 61 | ) 62 | assertTensorAlmostEqual(self, _select_targets(output_tensor, (0, 1)), [2, 8]) 63 | assertTensorAlmostEqual( 64 | self, 65 | _select_targets( 66 | output_tensor, cast(List[Tuple[int, ...]], [(0, 1), (2, 0)]) 67 | ), 68 | [2, 3], 69 | ) 70 | 71 | # Verify error is raised if list is longer than number of examples. 72 | with self.assertRaises(AssertionError): 73 | _select_targets( 74 | output_tensor, cast(List[Tuple[int, ...]], [(0, 1), (2, 0), (3, 2)]) 75 | ) 76 | 77 | # Verify error is raised if too many dimensions are provided. 78 | with self.assertRaises(AssertionError): 79 | _select_targets(output_tensor, (1, 2, 3)) 80 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/test_deconvolution.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from __future__ import print_function 4 | 5 | import unittest 6 | from typing import Any, List, Tuple, Union 7 | 8 | import torch 9 | from torch.nn import Module 10 | 11 | from captum.attr._core.guided_backprop_deconvnet import Deconvolution 12 | from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( 13 | NeuronDeconvolution, 14 | ) 15 | from captum.attr._utils.typing import TensorOrTupleOfTensorsGeneric 16 | 17 | from .helpers.basic_models import BasicModel_ConvNet_One_Conv 18 | from .helpers.utils import BaseTest, assertTensorAlmostEqual 19 | 20 | 21 | class Test(BaseTest): 22 | def test_simple_input_conv_deconv(self) -> None: 23 | net = BasicModel_ConvNet_One_Conv() 24 | inp = 1.0 * torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 25 | exp = [ 26 | [2.0, 3.0, 3.0, 1.0], 27 | [3.0, 5.0, 5.0, 2.0], 28 | [3.0, 5.0, 5.0, 2.0], 29 | [1.0, 2.0, 2.0, 1.0], 30 | ] 31 | self._deconv_test_assert(net, (inp,), (exp,)) 32 | 33 | def test_simple_input_conv_neuron_deconv(self) -> None: 34 | net = BasicModel_ConvNet_One_Conv() 35 | inp = 1.0 * torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 36 | exp = [ 37 | [2.0, 3.0, 3.0, 1.0], 38 | [3.0, 5.0, 5.0, 2.0], 39 | [3.0, 5.0, 5.0, 2.0], 40 | [1.0, 2.0, 2.0, 1.0], 41 | ] 42 | self._neuron_deconv_test_assert(net, net.fc1, (0,), (inp,), (exp,)) 43 | 44 | def test_simple_multi_input_conv_deconv(self) -> None: 45 | net = BasicModel_ConvNet_One_Conv() 46 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 47 | inp2 = torch.ones((1, 1, 4, 4)) 48 | ex_attr = [ 49 | [2.0, 3.0, 3.0, 1.0], 50 | [3.0, 5.0, 5.0, 2.0], 51 | [3.0, 5.0, 5.0, 2.0], 52 | [1.0, 2.0, 2.0, 1.0], 53 | ] 54 | self._deconv_test_assert(net, (inp, inp2), (ex_attr, ex_attr)) 55 | 56 | def test_simple_multi_input_conv_neuron_deconv(self) -> None: 57 | net = BasicModel_ConvNet_One_Conv() 58 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 59 | inp2 = torch.ones((1, 1, 4, 4)) 60 | ex_attr = [ 61 | 
[2.0, 3.0, 3.0, 1.0], 62 | [3.0, 5.0, 5.0, 2.0], 63 | [3.0, 5.0, 5.0, 2.0], 64 | [1.0, 2.0, 2.0, 1.0], 65 | ] 66 | self._neuron_deconv_test_assert( 67 | net, net.fc1, (3,), (inp, inp2), (ex_attr, ex_attr) 68 | ) 69 | 70 | def test_deconv_matching(self) -> None: 71 | net = BasicModel_ConvNet_One_Conv() 72 | inp = 100.0 * torch.randn(1, 1, 4, 4) 73 | self._deconv_matching_assert(net, net.relu2, inp) 74 | 75 | def _deconv_test_assert( 76 | self, 77 | model: Module, 78 | test_input: TensorOrTupleOfTensorsGeneric, 79 | expected: Tuple[List[List[float]], ...], 80 | additional_input: Any = None, 81 | ) -> None: 82 | deconv = Deconvolution(model) 83 | attributions = deconv.attribute( 84 | test_input, target=0, additional_forward_args=additional_input 85 | ) 86 | for i in range(len(test_input)): 87 | assertTensorAlmostEqual(self, attributions[i], expected[i], delta=0.01) 88 | 89 | def _neuron_deconv_test_assert( 90 | self, 91 | model: Module, 92 | layer: Module, 93 | neuron_index: Union[int, Tuple[int, ...]], 94 | test_input: TensorOrTupleOfTensorsGeneric, 95 | expected: Tuple[List[List[float]], ...], 96 | additional_input: Any = None, 97 | ) -> None: 98 | deconv = NeuronDeconvolution(model, layer) 99 | attributions = deconv.attribute( 100 | test_input, 101 | neuron_index=neuron_index, 102 | additional_forward_args=additional_input, 103 | ) 104 | for i in range(len(test_input)): 105 | assertTensorAlmostEqual(self, attributions[i], expected[i], delta=0.01) 106 | 107 | def _deconv_matching_assert( 108 | self, 109 | model: Module, 110 | output_layer: Module, 111 | test_input: TensorOrTupleOfTensorsGeneric, 112 | ) -> None: 113 | out = model(test_input) 114 | attrib = Deconvolution(model) 115 | neuron_attrib = NeuronDeconvolution(model, output_layer) 116 | for i in range(out.shape[1]): 117 | deconv_vals = attrib.attribute(test_input, target=i) 118 | neuron_deconv_vals = neuron_attrib.attribute(test_input, (i,)) 119 | assertTensorAlmostEqual(self, deconv_vals, neuron_deconv_vals, delta=0.01) 120 | 121 | 122 | if __name__ == "__main__": 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/test_guided_backprop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import unittest 4 | from typing import Any, List, Tuple, Union 5 | 6 | import torch 7 | from torch.nn import Module 8 | 9 | from captum.attr._core.guided_backprop_deconvnet import GuidedBackprop 10 | from captum.attr._core.neuron.neuron_guided_backprop_deconvnet import ( 11 | NeuronGuidedBackprop, 12 | ) 13 | from captum.attr._utils.typing import TensorOrTupleOfTensorsGeneric 14 | 15 | from .helpers.basic_models import BasicModel_ConvNet_One_Conv 16 | from .helpers.utils import BaseTest, assertTensorAlmostEqual 17 | 18 | 19 | class Test(BaseTest): 20 | def test_simple_input_conv_gb(self) -> None: 21 | net = BasicModel_ConvNet_One_Conv() 22 | inp = 1.0 * torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 23 | exp = [ 24 | [0.0, 1.0, 1.0, 1.0], 25 | [1.0, 3.0, 3.0, 2.0], 26 | [1.0, 3.0, 3.0, 2.0], 27 | [1.0, 2.0, 2.0, 1.0], 28 | ] 29 | self._guided_backprop_test_assert(net, (inp,), (exp,)) 30 | 31 | def test_simple_input_conv_neuron_gb(self) -> None: 32 | net = BasicModel_ConvNet_One_Conv() 33 | inp = 1.0 * torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 34 | exp = [ 35 | [0.0, 1.0, 1.0, 1.0], 36 | [1.0, 3.0, 3.0, 2.0], 37 | [1.0, 3.0, 3.0, 2.0], 38 | [1.0, 2.0, 2.0, 1.0], 39 | ] 40 | 
self._neuron_guided_backprop_test_assert(net, net.fc1, (0,), (inp,), (exp,)) 41 | 42 | def test_simple_multi_input_conv_gb(self) -> None: 43 | net = BasicModel_ConvNet_One_Conv() 44 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 45 | inp2 = torch.ones((1, 1, 4, 4)) 46 | ex_attr = [ 47 | [1.0, 2.0, 2.0, 1.0], 48 | [2.0, 4.0, 4.0, 2.0], 49 | [2.0, 4.0, 4.0, 2.0], 50 | [1.0, 2.0, 2.0, 1.0], 51 | ] 52 | self._guided_backprop_test_assert(net, (inp, inp2), (ex_attr, ex_attr)) 53 | 54 | def test_simple_multi_input_conv_neuron_gb(self) -> None: 55 | net = BasicModel_ConvNet_One_Conv() 56 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 57 | inp2 = torch.ones((1, 1, 4, 4)) 58 | ex_attr = [ 59 | [1.0, 2.0, 2.0, 1.0], 60 | [2.0, 4.0, 4.0, 2.0], 61 | [2.0, 4.0, 4.0, 2.0], 62 | [1.0, 2.0, 2.0, 1.0], 63 | ] 64 | self._neuron_guided_backprop_test_assert( 65 | net, net.fc1, (3,), (inp, inp2), (ex_attr, ex_attr) 66 | ) 67 | 68 | def test_gb_matching(self) -> None: 69 | net = BasicModel_ConvNet_One_Conv() 70 | inp = 100.0 * torch.randn(1, 1, 4, 4) 71 | self._guided_backprop_matching_assert(net, net.relu2, inp) 72 | 73 | def _guided_backprop_test_assert( 74 | self, 75 | model: Module, 76 | test_input: TensorOrTupleOfTensorsGeneric, 77 | expected: Tuple[List[List[float]], ...], 78 | additional_input: Any = None, 79 | ) -> None: 80 | guided_backprop = GuidedBackprop(model) 81 | attributions = guided_backprop.attribute( 82 | test_input, target=0, additional_forward_args=additional_input 83 | ) 84 | for i in range(len(test_input)): 85 | assertTensorAlmostEqual(self, attributions[i], expected[i], delta=0.01) 86 | 87 | def _neuron_guided_backprop_test_assert( 88 | self, 89 | model: Module, 90 | layer: Module, 91 | neuron_index: Union[int, Tuple[int, ...]], 92 | test_input: TensorOrTupleOfTensorsGeneric, 93 | expected: Tuple[List[List[float]], ...], 94 | additional_input: Any = None, 95 | ) -> None: 96 | guided_backprop = NeuronGuidedBackprop(model, layer) 97 | attributions = guided_backprop.attribute( 98 | test_input, 99 | neuron_index=neuron_index, 100 | additional_forward_args=additional_input, 101 | ) 102 | for i in range(len(test_input)): 103 | assertTensorAlmostEqual(self, attributions[i], expected[i], delta=0.01) 104 | 105 | def _guided_backprop_matching_assert( 106 | self, 107 | model: Module, 108 | output_layer: Module, 109 | test_input: TensorOrTupleOfTensorsGeneric, 110 | ): 111 | out = model(test_input) 112 | attrib = GuidedBackprop(model) 113 | neuron_attrib = NeuronGuidedBackprop(model, output_layer) 114 | for i in range(out.shape[1]): 115 | gbp_vals = attrib.attribute(test_input, target=i) 116 | neuron_gbp_vals = neuron_attrib.attribute(test_input, (i,)) 117 | assertTensorAlmostEqual(self, gbp_vals, neuron_gbp_vals, delta=0.01) 118 | 119 | 120 | if __name__ == "__main__": 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/test_guided_grad_cam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import unittest 4 | from typing import Any 5 | 6 | import torch 7 | from torch.nn import Module 8 | 9 | from captum.attr._core.guided_grad_cam import GuidedGradCam 10 | from captum.attr._utils.typing import TensorOrTupleOfTensorsGeneric 11 | 12 | from .helpers.basic_models import BasicModel_ConvNet_One_Conv 13 | from .helpers.utils import BaseTest, assertTensorAlmostEqual 14 | 15 | 16 | class Test(BaseTest): 17 | def 
test_simple_input_conv(self) -> None: 18 | net = BasicModel_ConvNet_One_Conv() 19 | inp = 1.0 * torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 20 | ex = [ 21 | [0.0, 0.0, 4.0, 4.0], 22 | [0.0, 0.0, 12.0, 8.0], 23 | [28.0, 84.0, 97.5, 65.0], 24 | [28.0, 56.0, 65.0, 32.5], 25 | ] 26 | self._guided_grad_cam_test_assert(net, net.relu1, inp, ex) 27 | 28 | def test_simple_multi_input_conv(self) -> None: 29 | net = BasicModel_ConvNet_One_Conv() 30 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 31 | inp2 = torch.ones((1, 1, 4, 4)) 32 | ex = [ 33 | [14.5, 29.0, 38.0, 19.0], 34 | [29.0, 58.0, 76.0, 38.0], 35 | [65.0, 130.0, 148.0, 74.0], 36 | [32.5, 65.0, 74.0, 37.0], 37 | ] 38 | self._guided_grad_cam_test_assert(net, net.conv1, (inp, inp2), (ex, ex)) 39 | 40 | def test_simple_multi_input_relu_input_inplace(self) -> None: 41 | net = BasicModel_ConvNet_One_Conv(inplace=True) 42 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 43 | inp2 = torch.ones((1, 1, 4, 4)) 44 | ex = [ 45 | [14.5, 29.0, 38.0, 19.0], 46 | [29.0, 58.0, 76.0, 38.0], 47 | [65.0, 130.0, 148.0, 74.0], 48 | [32.5, 65.0, 74.0, 37.0], 49 | ] 50 | self._guided_grad_cam_test_assert( 51 | net, net.relu1, (inp, inp2), (ex, ex), attribute_to_layer_input=True 52 | ) 53 | 54 | def test_simple_multi_input_conv_inplace(self) -> None: 55 | net = BasicModel_ConvNet_One_Conv(inplace=True) 56 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 57 | inp2 = torch.ones((1, 1, 4, 4)) 58 | ex = [ 59 | [14.5, 29.0, 38.0, 19.0], 60 | [29.0, 58.0, 76.0, 38.0], 61 | [65.0, 130.0, 148.0, 74.0], 62 | [32.5, 65.0, 74.0, 37.0], 63 | ] 64 | self._guided_grad_cam_test_assert(net, net.conv1, (inp, inp2), (ex, ex)) 65 | 66 | def test_improper_dims_multi_input_conv(self) -> None: 67 | net = BasicModel_ConvNet_One_Conv() 68 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 69 | inp2 = torch.ones(1) 70 | ex = [ 71 | [14.5, 29.0, 38.0, 19.0], 72 | [29.0, 58.0, 76.0, 38.0], 73 | [65.0, 130.0, 148.0, 74.0], 74 | [32.5, 65.0, 74.0, 37.0], 75 | ] 76 | self._guided_grad_cam_test_assert(net, net.conv1, (inp, inp2), (ex, [])) 77 | 78 | def test_improper_method_multi_input_conv(self) -> None: 79 | net = BasicModel_ConvNet_One_Conv() 80 | inp = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4) 81 | inp2 = torch.ones(1) 82 | self._guided_grad_cam_test_assert( 83 | net, net.conv1, (inp, inp2), ([], []), interpolate_mode="triilinear" 84 | ) 85 | 86 | def _guided_grad_cam_test_assert( 87 | self, 88 | model: Module, 89 | target_layer: Module, 90 | test_input: TensorOrTupleOfTensorsGeneric, 91 | expected, 92 | additional_input: Any = None, 93 | interpolate_mode: str = "nearest", 94 | attribute_to_layer_input: bool = False, 95 | ) -> None: 96 | guided_gc = GuidedGradCam(model, target_layer) 97 | attributions = guided_gc.attribute( 98 | test_input, 99 | target=0, 100 | additional_forward_args=additional_input, 101 | interpolate_mode=interpolate_mode, 102 | attribute_to_layer_input=attribute_to_layer_input, 103 | ) 104 | if isinstance(test_input, tuple): 105 | for i in range(len(test_input)): 106 | assertTensorAlmostEqual(self, attributions[i], expected[i], delta=0.01) 107 | else: 108 | assertTensorAlmostEqual(self, attributions, expected, delta=0.01) 109 | 110 | 111 | if __name__ == "__main__": 112 | unittest.main() 113 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/test_input_x_gradient.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from typing import Any 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.nn import Module 7 | 8 | from captum.attr._core.input_x_gradient import InputXGradient 9 | from captum.attr._core.noise_tunnel import NoiseTunnel 10 | from captum.attr._utils.typing import TensorOrTupleOfTensorsGeneric 11 | 12 | from .helpers.classification_models import SoftmaxModel 13 | from .helpers.utils import BaseTest, assertArraysAlmostEqual 14 | from .test_saliency import _get_basic_config, _get_multiargs_basic_config 15 | 16 | 17 | class Test(BaseTest): 18 | def test_input_x_gradient_test_basic_vanilla(self) -> None: 19 | self._input_x_gradient_base_assert(*_get_basic_config()) 20 | 21 | def test_input_x_gradient_test_basic_smoothgrad(self) -> None: 22 | self._input_x_gradient_base_assert(*_get_basic_config(), nt_type="smoothgrad") 23 | 24 | def test_input_x_gradient_test_basic_vargrad(self) -> None: 25 | self._input_x_gradient_base_assert(*_get_basic_config(), nt_type="vargrad") 26 | 27 | def test_saliency_test_basic_multi_variable_vanilla(self) -> None: 28 | self._input_x_gradient_base_assert(*_get_multiargs_basic_config()) 29 | 30 | def test_saliency_test_basic_multi_variable_smoothgrad(self) -> None: 31 | self._input_x_gradient_base_assert( 32 | *_get_multiargs_basic_config(), nt_type="smoothgrad" 33 | ) 34 | 35 | def test_saliency_test_basic_multi_vargrad(self) -> None: 36 | self._input_x_gradient_base_assert( 37 | *_get_multiargs_basic_config(), nt_type="vargrad" 38 | ) 39 | 40 | def test_input_x_gradient_classification_vanilla(self) -> None: 41 | self._input_x_gradient_classification_assert() 42 | 43 | def test_input_x_gradient_classification_smoothgrad(self) -> None: 44 | self._input_x_gradient_classification_assert(nt_type="smoothgrad") 45 | 46 | def test_input_x_gradient_classification_vargrad(self) -> None: 47 | self._input_x_gradient_classification_assert(nt_type="vargrad") 48 | 49 | def _input_x_gradient_base_assert( 50 | self, 51 | model: Module, 52 | inputs: TensorOrTupleOfTensorsGeneric, 53 | expected_grads: TensorOrTupleOfTensorsGeneric, 54 | additional_forward_args: Any = None, 55 | nt_type: str = "vanilla", 56 | ) -> None: 57 | input_x_grad = InputXGradient(model) 58 | if nt_type == "vanilla": 59 | attributions = input_x_grad.attribute( 60 | inputs, additional_forward_args=additional_forward_args 61 | ) 62 | else: 63 | nt = NoiseTunnel(input_x_grad) 64 | attributions = nt.attribute( 65 | inputs, 66 | nt_type=nt_type, 67 | n_samples=10, 68 | stdevs=0.0002, 69 | additional_forward_args=additional_forward_args, 70 | ) 71 | 72 | if isinstance(attributions, tuple): 73 | for input, attribution, expected_grad in zip( 74 | inputs, attributions, expected_grads 75 | ): 76 | if nt_type == "vanilla": 77 | assertArraysAlmostEqual( 78 | attribution.reshape(-1), (expected_grad * input).reshape(-1) 79 | ) 80 | self.assertEqual(input.shape, attribution.shape) 81 | elif isinstance(attributions, Tensor): 82 | if nt_type == "vanilla": 83 | assertArraysAlmostEqual( 84 | attributions.reshape(-1), 85 | (expected_grads * inputs).reshape(-1), 86 | delta=0.5, 87 | ) 88 | self.assertEqual(inputs.shape, attributions.shape) 89 | 90 | def _input_x_gradient_classification_assert( 91 | self, nt_type: str = "vanilla", 92 | ) -> None: 93 | num_in = 5 94 | input = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0]], requires_grad=True) 95 | target = torch.tensor(5) 96 | 97 | # 10-class classification model 98 | 
model = SoftmaxModel(num_in, 20, 10) 99 | input_x_grad = InputXGradient(model.forward) 100 | if nt_type == "vanilla": 101 | attributions = input_x_grad.attribute(input, target) 102 | output = model(input)[:, target] 103 | output.backward() 104 | expected = input.grad * input 105 | self.assertEqual( 106 | expected.detach().numpy().tolist(), 107 | attributions.detach().numpy().tolist(), 108 | ) 109 | else: 110 | nt = NoiseTunnel(input_x_grad) 111 | attributions = nt.attribute( 112 | input, nt_type=nt_type, n_samples=10, stdevs=1.0, target=target 113 | ) 114 | 115 | self.assertEqual(attributions.shape, input.shape) 116 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tests/attr/test_summarizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | 4 | from captum.attr import CommonSummarizer 5 | 6 | from .helpers.utils import BaseTest 7 | 8 | 9 | class Test(BaseTest): 10 | def test_single_input(self): 11 | size = (2, 3) 12 | summarizer = CommonSummarizer() 13 | for _ in range(10): 14 | attrs = torch.randn(size) 15 | summarizer.update(attrs) 16 | 17 | summ = summarizer.summary 18 | self.assertIsNotNone(summ) 19 | self.assertTrue(isinstance(summ, dict)) 20 | 21 | for k in summ: 22 | self.assertTrue(summ[k].size() == size) 23 | 24 | def test_multi_input(self): 25 | size1 = (10, 5, 5) 26 | size2 = (3, 5) 27 | 28 | summarizer = CommonSummarizer() 29 | for _ in range(10): 30 | a1 = torch.randn(size1) 31 | a2 = torch.randn(size2) 32 | summarizer.update((a1, a2)) 33 | 34 | summ = summarizer.summary 35 | self.assertIsNotNone(summ) 36 | self.assertTrue(len(summ) == 2) 37 | self.assertTrue(isinstance(summ[0], dict)) 38 | self.assertTrue(isinstance(summ[1], dict)) 39 | 40 | for k in summ[0]: 41 | self.assertTrue(summ[0][k].size() == size1) 42 | self.assertTrue(summ[1][k].size() == size2) 43 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/README.md: -------------------------------------------------------------------------------- 1 | This folder contains Jupyter tutorial notebooks. These notebooks are also 2 | parsed and embedded into the captum website.
3 | -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/bert/visuals_of_start_end_predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/bert/visuals_of_start_end_predictions.png -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/captum_insights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/captum_insights.png -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/captum_insights_vqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/captum_insights_vqa.png -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/resnet/swan-3299528_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/resnet/swan-3299528_1280.jpg -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/resnet/swan-3299528_1280_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/resnet/swan-3299528_1280_attribution.jpg -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/sentiment_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/sentiment_analysis.png -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/vqa/elephant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/vqa/elephant.jpg -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/vqa/elephant_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/vqa/elephant_attribution.jpg -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/vqa/siamese.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/vqa/siamese.jpg -------------------------------------------------------------------------------- 
/src/docker/captum_xil/tutorials/img/vqa/siamese_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/vqa/siamese_attribution.jpg -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/vqa/zebra.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/vqa/zebra.jpg -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/img/vqa/zebra_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/img/vqa/zebra_attribution.jpg -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/models/boston_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/models/boston_model.pt -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/models/cifar_torchvision.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/models/cifar_torchvision.pt -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/models/imdb-model-cnn.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/models/imdb-model-cnn.pt -------------------------------------------------------------------------------- /src/docker/captum_xil/tutorials/models/titanic_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-research/NeSyXIL/8a8b2a5e2c8ce57ec73d68ca76d4d95a65530cdd/src/docker/captum_xil/tutorials/models/titanic_model.pt -------------------------------------------------------------------------------- /src/docker/captum_xil/website/core/Footer.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2017-present, Facebook, Inc. 3 | * 4 | * This source code is licensed under the BSD license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | * 7 | * @format 8 | */ 9 | 10 | const PropTypes = require("prop-types"); 11 | const React = require("react"); 12 | 13 | function SocialFooter(props) { 14 | const repoUrl = `https://github.com/${props.config.organizationName}/${props.config.projectName}`; 15 | return ( 16 |
<div className="footerSection"> 17 | <h5>Social</h5> 18 | <div className="social"> 19 | <a 20 | className="github-button" 21 | href={repoUrl} 22 | data-count-href={`${repoUrl}/stargazers`} 23 | data-show-count="true" 24 | data-count-aria-label="# stargazers on GitHub" 25 | aria-label="Star this project on GitHub" 26 | > 27 | {props.config.projectName} 28 | </a> 29 | </div> 30 | </div>
31 | ); 32 | } 33 | 34 | SocialFooter.propTypes = { 35 | config: PropTypes.object 36 | }; 37 | 38 | class Footer extends React.Component { 39 | docUrl(doc, language) { 40 | const baseUrl = this.props.config.baseUrl; 41 | const docsUrl = this.props.config.docsUrl; 42 | const docsPart = `${docsUrl ? `${docsUrl}/` : ""}`; 43 | const langPart = `${language ? `${language}/` : ""}`; 44 | return `${baseUrl}${docsPart}${langPart}${doc}`; 45 | } 46 | 47 | pageUrl(doc, language) { 48 | const baseUrl = this.props.config.baseUrl; 49 | return baseUrl + (language ? `${language}/` : "") + doc; 50 | } 51 | 52 | render() { 53 | const currentYear = new Date().getFullYear(); 54 | return ( 55 |