├── .gitattributes ├── .gitignore ├── .travis.yml ├── LICENSE.md ├── README.md ├── docs ├── about.md ├── global_explainers.md ├── gradient_backprop.md ├── gradient_cam.md ├── gradient_explainer.md ├── index.md ├── tutorial_global.md ├── tutorial_local.md ├── xdeep_image.md ├── xdeep_tabular.md └── xdeep_text.md ├── result_img └── ensemble_fig.png ├── setup.py ├── tests ├── __init__.py ├── xglobal │ ├── __init__.py │ ├── images │ │ └── jay.jpg │ ├── results │ │ ├── deepdream.jpg │ │ ├── filter.jpg │ │ ├── inverted │ │ │ ├── Inverted_Iteration_0.jpg │ │ │ ├── Inverted_Iteration_1.jpg │ │ │ ├── Inverted_Iteration_2.jpg │ │ │ ├── Inverted_Iteration_3.jpg │ │ │ ├── Inverted_Iteration_4.jpg │ │ │ ├── Inverted_Iteration_5.jpg │ │ │ ├── Inverted_Iteration_6.jpg │ │ │ ├── Inverted_Iteration_7.jpg │ │ │ ├── Inverted_Iteration_8.jpg │ │ │ └── Inverted_Iteration_9.jpg │ │ ├── layer.jpg │ │ └── logit.jpg │ └── test_global.py ├── xlocal_gradient │ ├── __init__.py │ ├── images │ │ └── ILSVRC2012_val_00000073.JPEG │ ├── results │ │ ├── GuidedBP_ILSVRC2012_val_00000073.JPEG │ │ └── gradcam_ILSVRC2012_val_00000073.JPEG │ ├── test_explainers.py │ ├── test_gradcam.py │ └── test_guided_backprop.py └── xlocal_perturbation │ ├── __init__.py │ ├── data │ ├── Eskimo dog, husky.JPEG │ ├── Great Pyrenees.JPEG │ ├── Ibizan hound.JPEG │ ├── adult │ │ └── adult.data │ ├── bee eater.JPEG │ ├── car mirror.JPEG │ ├── eggnog.JPEG │ ├── kitten.jpg │ ├── oystercatcher.jpg │ ├── rt-polaritydata │ │ ├── rt-polaritydata.README.1.0.txt │ │ └── rt-polaritydata │ │ │ ├── rt-polarity.neg │ │ │ └── rt-polarity.pos │ ├── sea_snake.JPEG │ ├── sky.jpg │ ├── tusker.JPEG │ └── violin.JPEG │ ├── image_util │ ├── imagenet_lsvrc_2015_synsets.txt │ └── imagenet_metadata.txt │ ├── test_image.py │ ├── test_tabular.py │ └── test_text.py ├── tutorial ├── notebook │ ├── image - inceptionV3-checkpoint.ipynb │ ├── image - tutorial-checkpoint.ipynb │ ├── tabular - RF-checkpoint.ipynb │ ├── tabular - tutorial-checkpoint.ipynb 
│ ├── text - Bert-checkpoint.ipynb │ ├── text - LR - CountVetorizer-checkpoint.ipynb │ ├── text - LSTM - Word2Vec-checkpoint.ipynb │ ├── text - RF - CountVectorizer-checkpoint.ipynb │ ├── text - SVC - CountVetorizer-checkpoint.ipynb │ └── text - tutorial-checkpoint.ipynb ├── xglobal │ ├── images │ │ └── jay.jpg │ ├── results │ │ ├── deepdream.jpg │ │ ├── filter.jpg │ │ ├── inverted │ │ │ ├── Inverted_Iteration_0.jpg │ │ │ ├── Inverted_Iteration_1.jpg │ │ │ ├── Inverted_Iteration_2.jpg │ │ │ ├── Inverted_Iteration_3.jpg │ │ │ ├── Inverted_Iteration_4.jpg │ │ │ ├── Inverted_Iteration_5.jpg │ │ │ ├── Inverted_Iteration_6.jpg │ │ │ ├── Inverted_Iteration_7.jpg │ │ │ ├── Inverted_Iteration_8.jpg │ │ │ └── Inverted_Iteration_9.jpg │ │ ├── layer.jpg │ │ └── logit.jpg │ └── test_global.py ├── xlocal_gradient │ ├── images │ │ └── ILSVRC2012_val_00000073.JPEG │ ├── results │ │ ├── GuidedBP_ILSVRC2012_val_00000073.JPEG │ │ └── gradcam_ILSVRC2012_val_00000073.JPEG │ ├── test_explainers.py │ ├── test_gradcam.py │ └── test_guided_backprop.py └── xlocal_perturbation │ ├── data │ ├── Eskimo dog, husky.JPEG │ ├── Great Pyrenees.JPEG │ ├── Ibizan hound.JPEG │ ├── adult │ │ └── adult.data │ ├── bee eater.JPEG │ ├── car mirror.JPEG │ ├── eggnog.JPEG │ ├── kitten.jpg │ ├── oystercatcher.jpg │ ├── rt-polaritydata │ │ ├── rt-polaritydata.README.1.0.txt │ │ └── rt-polaritydata │ │ │ ├── rt-polarity.neg │ │ │ └── rt-polarity.pos │ ├── sea_snake.JPEG │ ├── sky.jpg │ ├── tusker.JPEG │ └── violin.JPEG │ ├── image - tutorial.html │ ├── image - tutorial.ipynb │ ├── image - tutorial.py │ ├── image_util │ ├── .DS_Store │ ├── imagenet_lsvrc_2015_synsets.txt │ ├── imagenet_metadata.txt │ ├── inception_preprocessing.py │ └── inception_v3.py │ ├── tabular - tutorial.html │ ├── tabular - tutorial.ipynb │ ├── tabular - tutorial.py │ ├── text - tutorial.html │ ├── text - tutorial.ipynb │ └── text - tutorial.py └── xdeep ├── __init__.py ├── utils ├── __init__.py ├── imagenet.py └── resources │ ├── 
__init__.py │ └── imagenet_class_index.json ├── xglobal ├── __init__.py ├── global_explainers.py └── methods │ ├── __init__.py │ ├── activation.py │ └── inverted_representation.py └── xlocal ├── __init__.py ├── gradient ├── __init__.py ├── backprop │ ├── __init__.py │ ├── base.py │ ├── guided_backprop.py │ ├── integrate_grad.py │ └── smooth_grad.py ├── cam │ ├── __init__.py │ ├── basecam.py │ ├── gradcam.py │ ├── gradcampp.py │ └── scorecam.py └── explainers.py └── perturbation ├── .ipynb_checkpoints ├── Image_Tutorial-checkpoint.ipynb ├── Tutorial - image-checkpoint.ipynb ├── Tutorial - tabular-checkpoint.ipynb └── Tutorial - text-checkpoint.ipynb ├── __init__.py ├── cle ├── __init__.py ├── cle_image.py ├── cle_tabular.py └── cle_text.py ├── exceptions.py ├── explainer.py ├── image_explainers ├── __init__.py ├── anchor_image.py ├── cle_image.py ├── lime_image.py └── shap_image.py ├── tabular_explainers ├── __init__.py ├── anchor_tabular.py ├── cle_tabular.py ├── lime_tabular.py └── shap_tabular.py ├── text_explainers ├── __init__.py ├── anchor_text.py ├── cle_text.py ├── lime_text.py └── shap_text.py ├── xdeep_base.py ├── xdeep_image.py ├── xdeep_tabular.py └── xdeep_text.py /.gitattributes: -------------------------------------------------------------------------------- 1 | tests/xlocal_perturbation/model/inception_v3.ckpt filter=lfs diff=lfs merge=lfs -text 2 | tutorial/xlocal_perturbation/model/inception_v3.ckpt filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | before_script: 4 | - pip install -e .[tests] 5 | - pip install python-coveralls 6 | - pip install pytest-cover 7 | 
script: 8 | - py.test tests/ 9 | after_success: 10 | - coveralls -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Data Analytics Lab at Texas A&M University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # XDeep 2 | ## -- *for interpretable deep learning developers* 3 | 4 | [![Codacy Badge](https://api.codacy.com/project/badge/Grade/2c0aff755250450c90ba167987aaebe5)](https://www.codacy.com/manual/nacoyang/xdeep?utm_source=github.com&utm_medium=referral&utm_content=datamllab/xdeep&utm_campaign=Badge_Grade) 5 | [![Build Status](https://travis-ci.com/datamllab/xdeep.svg?branch=master)](https://travis-ci.com/datamllab/xdeep) 6 | 7 | XDeep is an open source Python library for **Interpretable Machine Learning**. It is developed by [DATA Lab](http://faculty.cs.tamu.edu/xiahu/index.html) at Texas A&M University. The goal of XDeep is to provide easily accessible interpretation tools for people who want to figure out how deep models work. XDeep provides a variety of methods to interpret a model both locally and globally. 8 | 9 | ## Installation 10 | 11 | To install the package, please use the pip installation as follows (https://pypi.org/project/x-deep/): 12 | 13 | pip install x-deep 14 | 15 | **Note**: currently, XDeep is only compatible with: **Python 3.6**. 16 | 17 | ## Example 18 | 19 | Here is a short example of using the package. 20 | 21 | ```python 22 | import torchvision.models as models 23 | from xdeep.xlocal.gradient.explainers import * 24 | 25 | # load the input image & target deep model 26 | image = load_image('input.jpg') 27 | model = models.vgg16(pretrained=True) 28 | 29 | # build the xdeep explainer 30 | model_explainer = ImageInterpreter(model) 31 | 32 | # generate the local interpretation 33 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True) 34 | ``` 35 | 36 | For detailed tutorial, please check the docs directory of this repository [here](https://github.com/datamllab/xdeep/tree/master/docs). 
37 | 38 | ## Sample Results 39 | 40 | 41 | 42 | ## Cite this work 43 | 44 | Fan Yang, Zijian Zhang, Haofan Wang, Yuening Li, Xia Hu "XDeep: An Interpretation Tool for Deep Neural Networks." arXiv:1911.01005, 2019. ([Download](https://arxiv.org/abs/1911.01005)) 45 | 46 | Biblatex entry: 47 | 48 | @article{yang2019xdeep, 49 | title={XDeep: An Interpretation Tool for Deep Neural Networks}, 50 | author={Yang, Fan and Zhang, Zijian and Wang, Haofan and Li, Yuening and Hu, Xia}, 51 | journal={arXiv preprint arXiv:1911.01005}, 52 | year={2019} 53 | } 54 | 55 | ## DISCLAIMER 56 | 57 | Please note that this is a **pre-release** version of the XDeep which is still undergoing final testing before its official release. The website, its software and all contents found on it are provided on an 58 | “as is” and “as available” basis. XDeep does **not** give any warranties, whether express or implied, as to the suitability or usability of the website, its software or any of its content. XDeep will **not** be liable for any loss, whether such loss is direct, indirect, special or consequential, suffered by any party as a result of their use of the libraries or content. Any usage of the libraries is done at the user’s own risk and the user will be solely responsible for any damage to any computer system or loss of data that results from such activities. Should you encounter any bugs, glitches, lack of functionality or 59 | other problems on the website, please let us know immediately so we 60 | can rectify these accordingly. Your help in this regard is greatly appreciated. 61 | 62 | ## Contact 63 | 64 | Please send emails to *nacoyang@tamu.edu* if you have any issues for this project. Thanks. 65 | 66 | ## Acknowledgements 67 | 68 | The authors gratefully acknowledge the XAI program of the Defense Advanced Research Projects Agency (DARPA) administered through grant N66001-17-2-4031; the Texas A&M College of Engineering, and Texas A&M. 
69 | -------------------------------------------------------------------------------- /docs/about.md: -------------------------------------------------------------------------------- 1 | # About 2 | 3 | This package is developed by [DATA LAB](http://faculty.cs.tamu.edu/xiahu/) at Texas A&M University. 4 | 5 | ## Main Contributors 6 | 7 | [**Fan Yang**](./): 8 | Overall package design. 9 | 10 | [**Zijian Zhang**](./): 11 | Developer for the perturbation-based local explanation part. 12 | 13 | [**Haofan Wang**](./): 14 | Developer for the gradient-based local explanation and global explanation parts. 15 | 16 | [**Yuening Li**](./): 17 | Code review and testing. 18 | 19 | [**Xia "Ben" Hu**](./): 20 | Advisor. 21 | -------------------------------------------------------------------------------- /docs/global_explainers.md: -------------------------------------------------------------------------------- 1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xglobal/global_explainers.py#L11) 2 | ## GlobalImageInterpreter class 3 | 4 | ```python 5 | xdeep.xglobal.global_explainers.GlobalImageInterpreter(model) 6 | ``` 7 | 8 | Class for global image explanations. 9 | 10 | GlobalImageInterpreter provides a unified interface to call different visualization methods. 11 | Users can call it in a few lines. It includes '__init__()' and 'explain()'. The 12 | users can specify the name of the visualization method and the target layer. If params are 13 | not specified, default params will be used. 14 | 15 | 16 | --- 17 | ## GlobalImageInterpreter methods 18 | 19 | ### explain 20 | 21 | 22 | ```python 23 | explain(method_name=None, target_layer=None, target_filter=None, input_=None, num_iter=10, save_path=None) 24 | ``` 25 | 26 | 27 | Function to call different visualization methods. 28 | 29 | __Arguments__ 30 | 31 | - __method_name__: str. The name of the interpreter method. Currently, global explanation methods support 'filter', 32 | 'layer', 'logit', 'deepdream', 'inverted'.
33 | - __target_layer__: torch.nn.Linear or torch.nn.conv.Conv2d. The objective layer. 34 | - __target_filter__: int or list. Index of filter or filters. 35 | - __input___: str or Tensor. Path of input image or normalized tensor. Defaults to None. 36 | - __num_iter__: int. Number of iterations. Defaults to 10. 37 | - __save_path__: str. Path to save generated explanations. Defaults to None. 38 | 39 | --- 40 | ### inverted_feature 41 | 42 | 43 | ```python 44 | inverted_feature(model) 45 | ``` 46 | 47 | --- 48 | ### maximize_activation 49 | 50 | 51 | ```python 52 | maximize_activation(model) 53 | ``` 54 | -------------------------------------------------------------------------------- /docs/gradient_backprop.md: -------------------------------------------------------------------------------- 1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/backprop/base.py#L5) 2 | ## BaseProp class 3 | 4 | ```python 5 | xdeep.xlocal.gradient.backprop.base.BaseProp(model) 6 | ``` 7 | 8 | 9 | Base class for backpropagation. 10 | 11 | ---- 12 | 13 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/backprop/guided_backprop.py#L4) 14 | ## Backprop class 15 | 16 | ```python 17 | xdeep.xlocal.gradient.backprop.guided_backprop.Backprop(model, guided=False) 18 | ``` 19 | 20 | Generates vanilla or guided backprop gradients of a target class output w.r.t. an input image. 21 | 22 | __Arguments:__ 23 | 24 | - __model__: torchvision.models. A pretrained model. 25 | - __guided__: bool. If True, perform guided backpropagation. Defaults to False. 26 | 27 | __Return:__ 28 | 29 | Backprop Class. 30 | 31 | 32 | --- 33 | ## Backprop methods 34 | 35 | ### calculate_gradients 36 | 37 | 38 | ```python 39 | calculate_gradients(input_, target_class=None, take_max=False, use_gpu=False) 40 | ``` 41 | 42 | 43 | Calculates the gradient. 44 | 45 | __Arguments__ 46 | 47 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W).
48 | - __target_class__: int. Index of target class. Defaults to None, in which case the predicted class is used as the target class. 49 | - __take_max__: bool. Take the maximum across colour channels. Defaults to False. 50 | 51 | - __use_gpu__: bool. Whether to use the GPU. Defaults to False. 52 | 53 | __Return:__ 54 | 55 | Gradient (torch.Tensor) with shape (C, H, W). If take_max is True, the shape is (1, H, W). 56 | 57 | ---- 58 | 59 | ### generate_integrated_grad 60 | 61 | 62 | ```python 63 | xdeep.xlocal.gradient.backprop.integrate_grad.generate_integrated_grad(explainer, input_, target_class=None, n=25) 64 | ``` 65 | 66 | 67 | Generates integrated gradients for the given explainer. It can be used with both vanilla 68 | and guided backprop. 69 | 70 | __Arguments:__ 71 | 72 | - __explainer__: class. Backprop method. 73 | - __input___: torch.Tensor. Preprocessed image with shape (N, C, H, W). 74 | - __target_class__: int. Index of target class. Defaults to None. 75 | - __n__: int. Number of integration steps. Defaults to 25. 76 | 77 | __Return:__ 78 | 79 | Integrated gradient (torch.Tensor) with shape (C, H, W). 80 | 81 | ---- 82 | 83 | ### generate_smooth_grad 84 | 85 | 86 | ```python 87 | xdeep.xlocal.gradient.backprop.smooth_grad.generate_smooth_grad(explainer, input_, target_class=None, n=50, mean=0, sigma_multiplier=4) 88 | ``` 89 | 90 | 91 | Generates smooth gradients for the given explainer. 92 | 93 | __Arguments__ 94 | 95 | - __explainer__: class. Backprop method. 96 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W). 97 | - __target_class__: int. Index of target class. Defaults to None. 98 | - __n__: int. Number of noisy images used to smooth the gradient. Defaults to 50. 99 | - __mean__: int. Mean of the normal distribution used when generating noise. Defaults to 0. 100 | 101 | - __sigma_multiplier__: int. Sigma multiplier used when calculating the std of the noise. Defaults to 4. 102 | 103 | __Return:__ 104 | 105 | Smooth gradient (torch.Tensor) with shape (C, H, W).
106 | -------------------------------------------------------------------------------- /docs/gradient_cam.md: -------------------------------------------------------------------------------- 1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/basecam.py#L3) 2 | ## BaseCAM class 3 | 4 | ```python 5 | xdeep.xlocal.gradient.cam.basecam.BaseCAM(model_dict) 6 | ``` 7 | 8 | 9 | Base class for Class Activation Mapping. 10 | 11 | 12 | --- 13 | ## BaseCAM methods 14 | 15 | ### find_layer 16 | 17 | 18 | ```python 19 | find_layer(arch, target_layer_name) 20 | ``` 21 | 22 | --- 23 | ### forward 24 | 25 | 26 | ```python 27 | forward(input_, class_idx=None, retain_graph=False) 28 | ``` 29 | 30 | ---- 31 | 32 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/gradcam.py#L5) 33 | ## GradCAM class 34 | 35 | ```python 36 | xdeep.xlocal.gradient.cam.gradcam.GradCAM(model_dict) 37 | ``` 38 | 39 | 40 | GradCAM, inherits from BaseCAM. 41 | 42 | 43 | --- 44 | ## GradCAM methods 45 | 46 | ### find_layer 47 | 48 | 49 | ```python 50 | find_layer(arch, target_layer_name) 51 | ``` 52 | 53 | --- 54 | ### forward 55 | 56 | 57 | ```python 58 | forward(input_, class_idx=None, retain_graph=False) 59 | ``` 60 | 61 | 62 | Generates the GradCAM result. 63 | 64 | __Arguments__ 65 | 66 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W). 67 | - __class_idx__: int. Index of target class. Defaults to the index of the predicted class. 68 | 69 | __Return__ 70 | 71 | Result of GradCAM (torch.Tensor) with shape (1, H, W).
72 | 73 | ---- 74 | 75 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/gradcampp.py#L6) 76 | ## GradCAMpp class 77 | 78 | ```python 79 | xdeep.xlocal.gradient.cam.gradcampp.GradCAMpp(model_dict) 80 | ``` 81 | 82 | 83 | GradCAM++, inherits from BaseCAM. 84 | 85 | 86 | --- 87 | ## GradCAMpp methods 88 | 89 | ### find_layer 90 | 91 | 92 | ```python 93 | find_layer(arch, target_layer_name) 94 | ``` 95 | 96 | --- 97 | ### forward 98 | 99 | 100 | ```python 101 | forward(input, class_idx=None, retain_graph=False) 102 | ``` 103 | 104 | 105 | Generates the GradCAM++ result. 106 | 107 | __Arguments__ 108 | 109 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W). 110 | - __class_idx__: int. Index of target class. Defaults to the index of the predicted class. 111 | 112 | __Return__ 113 | 114 | Result of GradCAM++ (torch.Tensor) with shape (1, H, W). 115 | 116 | ---- 117 | 118 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/scorecam.py#L6) 119 | ## ScoreCAM class 120 | 121 | ```python 122 | xdeep.xlocal.gradient.cam.scorecam.ScoreCAM(model_dict) 123 | ``` 124 | 125 | 126 | ScoreCAM, inherits from BaseCAM. 127 | 128 | 129 | --- 130 | ## ScoreCAM methods 131 | 132 | ### find_layer 133 | 134 | 135 | ```python 136 | find_layer(arch, target_layer_name) 137 | ``` 138 | 139 | --- 140 | ### forward 141 | 142 | 143 | ```python 144 | forward(input, class_idx=None, retain_graph=False) 145 | ``` 146 | 147 | 148 | Generates the ScoreCAM result. 149 | 150 | __Arguments__ 151 | 152 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W). 153 | - __class_idx__: int. Index of target class. Defaults to the index of the predicted class. 154 | 155 | __Return__ 156 | 157 | Result of ScoreCAM (torch.Tensor) with shape (1, H, W).
158 | -------------------------------------------------------------------------------- /docs/gradient_explainer.md: -------------------------------------------------------------------------------- 1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/explainers.py#L18) 2 | ## ImageInterpreter class 3 | 4 | ```python 5 | xdeep.xlocal.gradient.explainers.ImageInterpreter(model) 6 | ``` 7 | 8 | Class for image explanations. 9 | 10 | ImageInterpreter provides a unified interface to call different visualization methods. 11 | Users can call it in a few lines. It includes '__init__()' and 'explain()'. The 12 | users can specify the name of the visualization method and the target layer. If params are 13 | not specified, default params will be used. 14 | 15 | 16 | --- 17 | ## ImageInterpreter methods 18 | 19 | ### explain 20 | 21 | 22 | ```python 23 | explain(image, method_name, viz=True, target_layer_name=None, target_class=None, save_path=None) 24 | ``` 25 | 26 | 27 | Function to call different local visualization methods. 28 | 29 | __Arguments__ 30 | 31 | - __image__: str or PIL.Image. The input image to ImageInterpreter. Users can directly provide the path of the input 32 | image, or provide an image in PIL.Image format. For example, image can be './test.jpg' or 33 | Image.open('./test.jpg').convert('RGB'). 34 | - __method_name__: str. The name of the interpreter method. Currently supports 'vallina_backprop', 'guided_backprop', 35 | 'smooth_grad', 'smooth_guided_grad', 'integrate_grad', 'integrate_guided_grad', 'gradcam', 36 | 'gradcampp', 'scorecam'. 37 | - __viz__: bool. Whether to visualize the result. Defaults to True. 38 | - __target_layer_name__: str. The layer to hook gradients and activation map. Defaults to the name of the last 39 | activation map. Users can also provide their own target layer, like 'features_29' or 'layer4', 40 | depending on the network architecture. 41 | - __target_class__: int. The index of target class.
Default to be the index of predicted class. 42 | - __save_path__: str. Path to save the saliency map. Default to be None. 43 | 44 | --- 45 | ### gradcam 46 | 47 | 48 | ```python 49 | gradcam(model_dict) 50 | ``` 51 | 52 | --- 53 | ### gradcampp 54 | 55 | 56 | ```python 57 | gradcampp(model_dict) 58 | ``` 59 | 60 | --- 61 | ### guided_backprop 62 | 63 | 64 | ```python 65 | guided_backprop(model) 66 | ``` 67 | 68 | --- 69 | ### integrate_grad 70 | 71 | 72 | ```python 73 | integrate_grad(model, input_, guided=False) 74 | ``` 75 | 76 | --- 77 | ### integrate_guided_grad 78 | 79 | 80 | ```python 81 | integrate_guided_grad(model, input_, guided=True) 82 | ``` 83 | 84 | --- 85 | ### scorecam 86 | 87 | 88 | ```python 89 | scorecam(model_dict) 90 | ``` 91 | 92 | --- 93 | ### smooth_grad 94 | 95 | 96 | ```python 97 | smooth_grad(model, input_, guided=False) 98 | ``` 99 | 100 | --- 101 | ### smooth_guided_grad 102 | 103 | 104 | ```python 105 | smooth_guided_grad(model, input_, guided=True) 106 | ``` 107 | 108 | --- 109 | ### vallina_backprop 110 | 111 | 112 | ```python 113 | vallina_backprop(model) 114 | ``` 115 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to XDeep 2 | 3 | XDeep is an open source software library for automated Interpretable Machine Learning. It is developed by [DATA Lab](http://faculty.cs.tamu.edu/xiahu/index.html) at Texas A&M University. The ultimate goal of XDeep is to provide easily accessible interpretation tools to people who want to figure out why the model predicts so. XDeep provides a variety of methods to interpret a model locally and globally. 4 | 5 | ## Installation 6 | 7 | To install the package, please use the pip installation as follows: 8 | 9 | pip install x-deep 10 | 11 | **Note**: currently, XDeep is only compatible with: **Python 3.6**. 
12 | 13 | ## Interpretation Methods 14 | 15 | * Local 16 | - Gradient 17 | 1. CAM\-Based 18 | - Grad-CAM 19 | - Grad-CAM++ 20 | - Score-CAM 21 | 2. Gradient\-Based 22 | - Vanilla Backpropagation 23 | - Guided Backpropagation 24 | - Smooth Gradient 25 | - Integrated Gradients 26 | 27 | - Perturbation 28 | 1. LIME 29 | 2. Anchors 30 | 3. SHAP 31 | 4. CLE 32 | 33 | * Global 34 | - Maximum Activation (filter, layer, logit, deepdream) 35 | - Inverted Features 36 | 37 | ## Example 38 | 39 | Here is a short example of using the package. 40 | 41 | from xdeep.xlocal.perturbation.xdeep_tabular import TabularExplainer 42 | 43 | iris = sklearn.datasets.load_iris() 44 | train, test, labels_train, labels_test = \ 45 | sklearn.model_selection.train_test_split(iris.data, iris.target, train_size=0.80) 46 | rf = sklearn.ensemble.RandomForestClassifier(n_estimators=500) 47 | rf.fit(train, labels_train) 48 | 49 | explainer = TabularExplainer(rf.predict_proba, 50 | train, feature_names=iris.feature_names, class_names=iris.target_names) 51 | explainer.set_parameters('lime', discretize_continuous=True) 52 | explainer.explain('lime', test[0]) 53 | explainer.show_explanation('lime') 54 | 55 | For detailed tutorials, please see [here](./tutorial_local) for local methods and [here](./tutorial_global) for global methods.
56 | -------------------------------------------------------------------------------- /docs/tutorial_global.md: -------------------------------------------------------------------------------- 1 | # Tutorial 2 | 3 | ## Global \- Maximum Activation 4 | 5 | ### GlobalExplainer 6 | 7 | from xdeep.xglobal.global_explainers import * 8 | import torchvision.models as models 9 | 10 | # load model 11 | model = models.vgg16(pretrained=True) 12 | 13 | # load global image interpreter 14 | model_explainer = GlobalImageInterpreter(model) 15 | 16 | # generate global explanation via activation maximization 17 | model_explainer.explain(method_name='filter', target_layer='features_24', target_filter=20, num_iter=10, save_path='results/filter.jpg') 18 | 19 | model_explainer.explain(method_name='layer', target_layer='features_24', num_iter=10, save_path='results/layer.jpg') 20 | 21 | model_explainer.explain(method_name='logit', target_layer='features_24', target_filter=20, num_iter=10, save_path='results/logit.jpg') 22 | 23 | model_explainer.explain(method_name='deepdream', target_layer='features_24', target_filter=20, input_='images/jay.jpg', num_iter=10, save_path='results/deepdream.jpg') 24 | 25 | # generate global explanation via inverted representation 26 | model_explainer.explain(method_name='inverted', target_layer='features_24', input_='images/jay.jpg', num_iter=10) 27 | -------------------------------------------------------------------------------- /docs/tutorial_local.md: -------------------------------------------------------------------------------- 1 | # Tutorial 2 | 3 | ## Local \- Gradient 4 | 5 | import torchvision.models as models 6 | from xdeep.xlocal.gradient.explainers import * 7 | 8 | # load image 9 | image = load_image('tutorial/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG') 10 | 11 | # load model 12 | model = models.vgg16(pretrained=True) 13 | 14 | # load embedded image interpreter 15 | model_explainer = ImageInterpreter(model) 16 | 17 | # generate saliency
map of method specified by 'method_name' 18 | model_explainer.explain(image, method_name='vallina_backprop', viz=True, save_path='results/bp.jpg') 19 | 20 | model_explainer.explain(image, method_name='guided_backprop', viz=True, save_path='results/guided.jpg') 21 | 22 | model_explainer.explain(image, method_name='smooth_grad', viz=True, save_path='results/smooth_grad.jpg') 23 | 24 | model_explainer.explain(image, method_name='integrate_grad', viz=True, save_path='results/integrate.jpg') 25 | 26 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True, save_path='results/gradcam.jpg') 27 | 28 | model_explainer.explain(image, method_name='gradcampp', target_layer_name='features_29', viz=True, save_path='results/gradcampp.jpg') 29 | 30 | model_explainer.explain(image, method_name='scorecam', target_layer_name='features_29', viz=True, save_path='results/scorecam.jpg') 31 | 32 | ## Local \- Perturbation 33 | 34 | ### Text 35 | 36 | from xdeep.xlocal.perturbation import xdeep_text 37 | 38 | # Load data and model 39 | 40 | model = load_model() 41 | dataset = load_data() 42 | 43 | # Start explaining 44 | 45 | text = dataset.test[0] 46 | 47 | explainer = xdeep_text.TextExplainer(model.predict_proba, dataset.class_names) 48 | 49 | explainer.explain('lime', text) 50 | explainer.show_explanation('lime') 51 | 52 | explainer.explain('cle', text) 53 | explainer.show_explanation('cle') 54 | 55 | explainer.explain('anchor', text) 56 | explainer.show_explanation('anchor') 57 | 58 | # for SHAP please check the Jupyter Notebook tutorial on GitHub. 
59 | 60 | ### Tabular 61 | 62 | from xdeep.xlocal.perturbation import xdeep_tabular 63 | 64 | # Load data and model 65 | 66 | model = load_model() 67 | dataset = load_data() 68 | 69 | explainer = xdeep_tabular.TabularExplainer( 70 | model.predict_proba, dataset.train, dataset.class_names, dataset.feature_names, 71 | categorical_features=dataset.categorical_features, 72 | categorical_names=dataset.categorical_names) 73 | 74 | explainer.set_parameters('lime', discretize_continuous=True) 75 | explainer.explain('lime', dataset.test[0]) 76 | explainer.show_explanation('lime') 77 | 78 | explainer.set_parameters('cle', discretize_continuous=True) 79 | explainer.explain('cle', dataset.test[0]) 80 | explainer.show_explanation('cle') 81 | 82 | explainer.explain('shap', dataset.test[0]) 83 | explainer.show_explanation('shap') 84 | 85 | # For Anchor, please check the Jupyter Notebook tutorial on GitHub. 86 | 87 | ### Image 88 | 89 | from xdeep.xlocal.perturbation import xdeep_image 90 | 91 | # Load data and model 92 | 93 | model = load_model() 94 | dataset = load_data() 95 | 96 | # Start explaining 97 | 98 | explainer = xdeep_image.ImageExplainer(model.predict_proba, dataset.class_names) 99 | 100 | # deprocess function (if the image range is [-1, 1], this maps it to [0, 1]) 101 | def f(x): 102 | return x / 2 + 0.5 103 | 104 | image = dataset.test[0] 105 | 106 | explainer.explain('lime', image, top_labels=3) 107 | 108 | explainer.show_explanation('lime', deprocess=f, positive_only=False) 109 | 110 | explainer.explain('cle', image, top_labels=3) 111 | 112 | explainer.show_explanation('cle', deprocess=f, positive_only=False) 113 | 114 | explainer.explain('anchor', image) 115 | 116 | explainer.show_explanation('anchor') 117 | 118 | from skimage.segmentation import slic 119 | segments_slic = slic(image, n_segments=50, compactness=30, sigma=3) 120 | explainer.initialize_shap(n_segment=50, segment=segments_slic) 121 | explainer.explain('shap', image, nsamples=200) 122 |
explainer.show_explanation('shap', deprocess=f) -------------------------------------------------------------------------------- /docs/xdeep_text.md: -------------------------------------------------------------------------------- 1 | [[source]](.) 2 | ## TextExplainer class 3 | 4 | ```python 5 | xdeep.xlocal.perturbation.xdeep_text.TextExplainer(predict_proba, class_names) 6 | ``` 7 | 8 | Integrated explainer which explains text classifiers. 9 | 10 | --- 11 | ## TextExplainer methods 12 | 13 | ### explain 14 | 15 | 16 | ```python 17 | explain(method, instance, top_labels=None, labels=(1,)) 18 | ``` 19 | 20 | 21 | Generate explanation for a prediction using certain method. 22 | 23 | __Arguments__ 24 | 25 | - __method__: Str. The method you want to use. 26 | - __instance__: Instance to be explained. 27 | - __top_labels__: Integer. Number of labels you care about. 28 | - __labels__: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 29 | - __**kwargs__: Parameters setter. For more detail, please check 'explain_instance' in corresponding method. 30 | 31 | --- 32 | ### get_explanation 33 | 34 | 35 | ```python 36 | get_explanation(method) 37 | ``` 38 | 39 | 40 | Explanation getter. 41 | 42 | __Arguments__ 43 | 44 | - __method__: Str. The method you want to use. 45 | 46 | __Return__ 47 | 48 | An Explanation object with the corresponding method. 49 | 50 | --- 51 | ### initialize_shap 52 | 53 | 54 | ```python 55 | initialize_shap(predict_vectorized, vectorizer, train) 56 | ``` 57 | 58 | 59 | Explicit shap initializer. 60 | 61 | __Arguments__ 62 | 63 | - __predict_vectorized__: Function. Classifier prediction probability function which need vector passed in. 64 | - __vectorizer__: Vectorizer. A vectorizer which has 'transform' function that transforms list of str to vector. 65 | - __train__: Array. Train data, in this case a list of str. 
66 | 67 | --- 68 | ### set_parameters 69 | 70 | 71 | ```python 72 | set_parameters(method) 73 | ``` 74 | 75 | 76 | Parameter setter. 77 | 78 | __Arguments__ 79 | 80 | - __method__: Str. The method you want to use. 81 | - __**kwargs__: Other parameters that depend on which method you use. For more detail, please check xdeep documentation. 82 | 83 | --- 84 | ### show_explanation 85 | 86 | 87 | ```python 88 | show_explanation(method) 89 | ``` 90 | 91 | 92 | Visualization of explanation of the corresponding method. 93 | 94 | __Arguments__ 95 | 96 | - __method__: Str. The method you want to use. 97 | - __**kwargs__: Parameters setter. For more detail, please check xdeep documentation. 98 | 99 | ---- 100 | 101 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/perturbation/text_explainers/lime_text.py#L8) 102 | ## XDeepLimeTextExplainer class 103 | 104 | ```python 105 | xdeep.xlocal.perturbation.text_explainers.lime_text.XDeepLimeTextExplainer(predict_proba, class_names) 106 | ``` 107 | 108 | 109 | --- 110 | ## XDeepLimeTextExplainer methods 111 | 112 | ### explain 113 | 114 | 115 | ```python 116 | explain(instance, top_labels=None, labels=(1,)) 117 | ``` 118 | 119 | 120 | Generate explanation for a prediction using certain method. 121 | 122 | __Arguments__ 123 | 124 | - __instance__: Str. A raw text string to be explained. 125 | - __top_labels__: Integer. Number of labels you care about. 126 | - __labels__: Tuple. Labels you care about; if top_labels is not None, they will be replaced by the predicted top labels. 127 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 128 | 129 | --- 130 | ### set_parameters 131 | 132 | 133 | ```python 134 | set_parameters() 135 | ``` 136 | 137 | 138 | Parameter setter for lime_text. 139 | 140 | __Arguments__ 141 | 142 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
143 | 144 | --- 145 | ### show_explanation 146 | 147 | 148 | ```python 149 | show_explanation(span=3) 150 | ``` 151 | 152 | 153 | Visualization of explanation of lime_text. 154 | 155 | __Arguments__ 156 | 157 | - __span__: Integer. Number of features shown in each row. 158 | 159 | ---- 160 | 161 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/perturbation/text_explainers/cle_text.py#L9) 162 | ## XDeepCLETextExplainer class 163 | 164 | ```python 165 | xdeep.xlocal.perturbation.text_explainers.cle_text.XDeepCLETextExplainer(predict_proba, class_names) 166 | ``` 167 | 168 | 169 | --- 170 | ## XDeepCLETextExplainer methods 171 | 172 | ### explain 173 | 174 | 175 | ```python 176 | explain(instance, top_labels=None, labels=(1,), care_words=None, spans=(2,), include_original_feature=True) 177 | ``` 178 | 179 | 180 | Generate explanation for a prediction using certain method. 181 | 182 | __Arguments__ 183 | 184 | - __instance__: Str. A raw text string to be explained. 185 | - __top_labels__: Integer. Number of labels you care about. 186 | - __labels__: Tuple. Labels you care about; if top_labels is not None, they will be replaced by the predicted top labels. 187 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 188 | 189 | --- 190 | ### set_parameters 191 | 192 | 193 | ```python 194 | set_parameters() 195 | ``` 196 | 197 | 198 | Parameter setter for cle_text. 199 | 200 | __Arguments__ 201 | 202 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
219 | 220 | ---- 221 | 222 | [[source]](https://github.com/keras-team/keras/blob/master/xdeep/xlocal/perturbation/text_explainers/anchor_text.py#L9) 223 | ## XDeepAnchorTextExplainer class 224 | 225 | ```python 226 | xdeep.xlocal.perturbation.text_explainers.anchor_text.XDeepAnchorTextExplainer(predict_proba, class_names) 227 | ``` 228 | 229 | 230 | --- 231 | ## XDeepAnchorTextExplainer methods 232 | 233 | ### explain 234 | 235 | 236 | ```python 237 | explain(instance, top_labels=None, labels=(1,)) 238 | ``` 239 | 240 | 241 | Generate explanation for a prediction using certain method. 242 | Anchor does not use top_labels and labels. 243 | 244 | __Arguments__ 245 | 246 | - __instance__: Str. A raw text string to be explained. 247 | - __**kwargs__: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor. 248 | 249 | --- 250 | ### set_parameters 251 | 252 | 253 | ```python 254 | set_parameters(nlp=None) 255 | ``` 256 | 257 | 258 | Parameter setter for anchor_text. 259 | 260 | __Arguments__ 261 | 262 | - __nlp__: Object. A spacy model object. 263 | - __**kwargs__: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor 264 | 265 | --- 266 | ### show_explanation 267 | 268 | 269 | ```python 270 | show_explanation(show_in_note_book=True, verbose=True) 271 | ``` 272 | 273 | 274 | Visualization of explanation of anchor_text. 275 | 276 | __Arguments__ 277 | 278 | - __show_in_note_book__: Boolean. Whether show in jupyter notebook. 279 | - __verbose__: Boolean. Whether print out examples and counter examples. 
280 | 281 | ---- 282 | 283 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/perturbation/text_explainers/shap_text.py#L10) 284 | ## XDeepShapTextExplainer class 285 | 286 | ```python 287 | xdeep.xlocal.perturbation.text_explainers.shap_text.XDeepShapTextExplainer(predict_vectorized, class_names, vectorizer, train) 288 | ``` 289 | 290 | 291 | --- 292 | ## XDeepShapTextExplainer methods 293 | 294 | ### explain 295 | 296 | 297 | ```python 298 | explain(instance, top_labels=None, labels=(1,)) 299 | ``` 300 | 301 | 302 | Generate explanation for a prediction using certain method. 303 | 304 | __Arguments__ 305 | 306 | - __instance__: Str. A raw text string to be explained. 307 | - __top_labels__: Integer. Number of labels you care about. 308 | - __labels__: Tuple. Labels you care about; if top_labels is not None, they will be replaced by the predicted top labels. 309 | - __**kwargs__: Parameters setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 310 | 311 | --- 312 | ### set_parameters 313 | 314 | 315 | ```python 316 | set_parameters() 317 | ``` 318 | 319 | 320 | Parameter setter for shap. The data passed to 'shap' must be vectors, not str. As a result, you need to pass in the vectorizer. 321 | 322 | __Arguments__ 323 | 324 | - __**kwargs__: Shap kernel explainer parameter setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 325 | 326 | --- 327 | ### show_explanation 328 | 329 | 330 | ```python 331 | show_explanation(show_in_note_book=True) 332 | ``` 333 | 334 | 335 | Visualization of explanation of shap. 336 | 337 | __Arguments__ 338 | 339 | - __show_in_note_book__: Boolean. Whether to show in a Jupyter notebook.
340 | -------------------------------------------------------------------------------- /result_img/ensemble_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/result_img/ensemble_fig.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="x-deep", 8 | version="0.0.8", 9 | author="DATA LAB at Texas A&M University", 10 | author_email="hu@cse.tamu.edu", 11 | description="XDeep is an open-source package for Interpretable Machine Learning.", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/datamllab/xdeep", 15 | packages=setuptools.find_packages(), 16 | install_requires=[ 17 | 'anchor-exp', 18 | 'shap', 19 | 'torch', 20 | 'torchvision', 21 | 'matplotlib', 22 | 'importlib_resources', 23 | ], 24 | extras_require={ 25 | 'tests': ['IPython', 26 | 'tensorflow < 2.0.0', 27 | 'pytest', 28 | ], 29 | }, 30 | classifiers=[ 31 | "Programming Language :: Python :: 3.6", 32 | "License :: OSI Approved :: MIT License", 33 | "Operating System :: OS Independent", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/__init__.py -------------------------------------------------------------------------------- /tests/xglobal/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/__init__.py -------------------------------------------------------------------------------- /tests/xglobal/images/jay.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/images/jay.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/deepdream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/deepdream.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/filter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/filter.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_0.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_1.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_2.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_3.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_4.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_5.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_6.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_7.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_7.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_8.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/inverted/Inverted_Iteration_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_9.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/layer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/layer.jpg -------------------------------------------------------------------------------- /tests/xglobal/results/logit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/logit.jpg -------------------------------------------------------------------------------- /tests/xglobal/test_global.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | from xdeep.xglobal.global_explainers import * 3 | from xdeep.xglobal.methods.inverted_representation import * 4 | 5 | 6 | def test_explainer():  # Exercise each GlobalImageInterpreter method (filter, layer, logit, deepdream, inverted) on a pretrained VGG16. 7 | 8 | model = models.vgg16(pretrained=True) 9 | model_explainer = GlobalImageInterpreter(model) 
10 | 11 | model_explainer.explain(method_name='filter', target_layer='features_24', target_filter=20, num_iter=10, 12 | save_path='tests/xglobal/results/filter.jpg') 13 | model_explainer.explain(method_name='layer', target_layer='features_24', num_iter=10, save_path='tests/xglobal/results/layer.jpg') 14 | model_explainer.explain(method_name='logit', target_layer='features_24', target_filter=20, num_iter=10,  # NOTE(review): test_logit below targets model.classifier[-1]; confirm 'features_24' is the intended layer for logit visualization 15 | save_path='tests/xglobal/results/logit.jpg') 16 | model_explainer.explain(method_name='deepdream', target_layer='features_24', target_filter=20, 17 | input_='tests/xglobal/images/jay.jpg', num_iter=10, save_path='tests/xglobal/results/deepdream.jpg') 18 | model_explainer.explain(method_name='inverted', target_layer='features_24', input_='tests/xglobal/images/jay.jpg', num_iter=10) 19 | 20 | 21 | def test_filter():  # Gradient-ascent visualization of selected filters of VGG16 features[24]. 22 | model = models.vgg16(pretrained=True) 23 | layer = model.features[24] 24 | filters = [45, 271, 363, 409] 25 | g_ascent = GradientAscent(model.features) 26 | g_ascent.visualize(layer, filters, num_iter=30, title='filter visualization', save_path='tests/xglobal/results/filter.jpg')  # NOTE: same save_path as test_explainer's 'filter' call; whichever test runs last wins 27 | 28 | 29 | def test_layer():  # Gradient-ascent visualization of the whole VGG16 features[24] layer. 30 | model = models.vgg16(pretrained=True) 31 | layer = model.features[24] 32 | g_ascent = GradientAscent(model.features) 33 | g_ascent.visualize(layer, num_iter=100, title='layer visualization', save_path='tests/xglobal/results/layer.jpg') 34 | 35 | 36 | def test_deepdream():  # DeepDream-style visualization of filter 33 in features[24], starting from jay.jpg. 37 | model = models.vgg16(pretrained=True) 38 | layer = model.features[24] 39 | g_ascent = GradientAscent(model.features) 40 | g_ascent.visualize(layer=layer, filter_idxs=33, input_='tests/xglobal/images/jay.jpg', title='deepdream',num_iter=50, save_path='tests/xglobal/results/deepdream.jpg') 41 | 42 | 43 | def test_logit():  # Gradient-ascent visualization of output logit 17 of VGG16's last classifier layer. 44 | model = models.vgg16(pretrained=True) 45 | layer = model.classifier[-1] 46 | logit = 17 47 | g_ascent = GradientAscent(model) 48 | g_ascent.visualize(layer, logit, num_iter=30, title='logit visualization',
save_path='tests/xglobal/results/logit.jpg') 49 | 50 | 51 | def test_inverted():  # Inverted-representation visualization of jay.jpg at VGG16 layer 'features_29'. 52 | image_path = 'tests/xglobal/images/jay.jpg' 53 | image = load_image(image_path) 54 | norm_image = apply_transforms(image) 55 | 56 | model = models.vgg16(pretrained=True) 57 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224)) 58 | IR = InvertedRepresentation(model_dict) 59 | 60 | IR.visualize(norm_image) 61 | -------------------------------------------------------------------------------- /tests/xlocal_gradient/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/__init__.py -------------------------------------------------------------------------------- /tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG -------------------------------------------------------------------------------- /tests/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG -------------------------------------------------------------------------------- /tests/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG --------------------------------------------------------------------------------
/tests/xlocal_gradient/test_explainers.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | from xdeep.xlocal.gradient.explainers import * 3 | 4 | image = load_image('tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG')  # Loaded once at import time and shared by every test in this module, together with model/model_explainer below. 5 | model = models.vgg16(pretrained=True) 6 | model_explainer = ImageInterpreter(model) 7 | 8 | def test_backprop():  # Vanilla and guided backprop saliency maps. 9 | model_explainer.explain(image, method_name='vanilla_backprop', viz=True, save_path='tests/xlocal_gradient/results/bp.jpg') 10 | model_explainer.explain(image, method_name='guided_backprop', viz=True, save_path='tests/xlocal_gradient/results/guided.jpg') 11 | 12 | def test_grad():  # Smooth-grad and integrated-gradient variants, plain and guided. 13 | model_explainer.explain(image, method_name='smooth_grad', viz=True, save_path='tests/xlocal_gradient/results/smooth_grad.jpg') 14 | model_explainer.explain(image, method_name='integrate_grad', viz=True, save_path='tests/xlocal_gradient/results/integrate.jpg') 15 | model_explainer.explain(image, method_name='smooth_guided_grad', viz=True, save_path='tests/xlocal_gradient/results/smooth_guided_grad.jpg') 16 | model_explainer.explain(image, method_name='integrate_guided_grad', viz=True, save_path='tests/xlocal_gradient/results/integrate_guided.jpg') 17 | 18 | def test_cam():  # CAM-family methods (gradcam, gradcampp, scorecam) on VGG16 layer 'features_29'. 19 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True, save_path='tests/xlocal_gradient/results/gradcam.jpg') 20 | model_explainer.explain(image, method_name='gradcampp', target_layer_name='features_29', viz=True, save_path='tests/xlocal_gradient/results/gradcampp.jpg') 21 | model_explainer.explain(image, method_name='scorecam', target_layer_name='features_29', viz=True, save_path='tests/xlocal_gradient/results/scorecam.jpg') 22 | -------------------------------------------------------------------------------- /tests/xlocal_gradient/test_gradcam.py: -------------------------------------------------------------------------------- 1 | import
torchvision.models as models 2 | 3 | from xdeep.utils import * 4 | from xdeep.xlocal.gradient.cam.gradcam import * 5 | 6 | 7 | def test_gradcam(): 8 | 9 | image_path = 'tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG' 10 | image_name = image_path.split('/')[1] 11 | image = load_image(image_path) 12 | norm_image = apply_transforms(image) 13 | 14 | model = models.vgg16(pretrained=True) 15 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224)) 16 | 17 | gradcam = GradCAM(model_dict) 18 | output = gradcam(norm_image) 19 | visualize(norm_image, output, save_path='tests/xlocal_gradient/results/gradcam_' + image_name) 20 | -------------------------------------------------------------------------------- /tests/xlocal_gradient/test_guided_backprop.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | 3 | from xdeep.utils import * 4 | from xdeep.xlocal.gradient.backprop.guided_backprop import * 5 | 6 | 7 | def test_bp(): 8 | 9 | image_path = 'tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG' 10 | image_name = image_path.split('/')[1] 11 | image = load_image(image_path) 12 | norm_image = apply_transforms(image) 13 | 14 | model = models.vgg16(pretrained=True) 15 | Guided_BP = Backprop(model, guided=True) 16 | output = Guided_BP.calculate_gradients(norm_image) 17 | visualize(norm_image, output, save_path='tests/xlocal_gradient/results/GuidedBP_' + image_name) 18 | -------------------------------------------------------------------------------- /tests/xlocal_perturbation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/__init__.py -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/Eskimo dog, husky.JPEG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/Eskimo dog, husky.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/Great Pyrenees.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/Great Pyrenees.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/Ibizan hound.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/Ibizan hound.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/bee eater.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/bee eater.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/car mirror.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/car mirror.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/eggnog.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/eggnog.JPEG 
-------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/kitten.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/kitten.jpg -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/oystercatcher.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/oystercatcher.jpg -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata.README.1.0.txt: -------------------------------------------------------------------------------- 1 | 2 | ======= 3 | 4 | Introduction 5 | 6 | This README v1.0 (June, 2005) for the v1.0 sentence polarity dataset comes 7 | from the URL 8 | http://www.cs.cornell.edu/people/pabo/movie-review-data . 9 | 10 | ======= 11 | 12 | Citation Info 13 | 14 | This data was first used in Bo Pang and Lillian Lee, 15 | ``Seeing stars: Exploiting class relationships for sentiment categorization 16 | with respect to rating scales.'', Proceedings of the ACL, 2005. 17 | 18 | @InProceedings{Pang+Lee:05a, 19 | author = {Bo Pang and Lillian Lee}, 20 | title = {Seeing stars: Exploiting class relationships for sentiment 21 | categorization with respect to rating scales}, 22 | booktitle = {Proceedings of the ACL}, 23 | year = 2005 24 | } 25 | 26 | ======= 27 | 28 | Data Format Summary 29 | 30 | - rt-polaritydata.tar.gz: contains this readme and two data files that 31 | were used in the experiments described in Pang/Lee ACL 2005. 
32 | 33 | Specifically: 34 | * rt-polarity.pos contains 5331 positive snippets 35 | * rt-polarity.neg contains 5331 negative snippets 36 | 37 | Each line in these two files corresponds to a single snippet (usually 38 | containing roughly one single sentence); all snippets are down-cased. 39 | The snippets were labeled automatically, as described below (see 40 | section "Label Decision"). 41 | 42 | Note: The original source files from which the data in 43 | rt-polaritydata.tar.gz was derived can be found in the subjective 44 | part (Rotten Tomatoes pages) of subjectivity_html.tar.gz (released 45 | with subjectivity dataset v1.0). 46 | 47 | 48 | ======= 49 | 50 | Label Decision 51 | 52 | We assumed snippets (from Rotten Tomatoes webpages) for reviews marked with 53 | ``fresh'' are positive, and those for reviews marked with ``rotten'' are 54 | negative. 55 | -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/sea_snake.JPEG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/sea_snake.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/sky.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/sky.jpg -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/tusker.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/tusker.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/data/violin.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/violin.JPEG -------------------------------------------------------------------------------- /tests/xlocal_perturbation/test_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | import numpy as np 4 | # from skimage.segmentation import slic 5 | import xdeep.xlocal.perturbation.xdeep_image as xdeep_image 6 | from xdeep.utils import * 7 | 8 | image = load_image('tests/xlocal_perturbation/data/violin.JPEG') 9 | image = np.asarray(image) 10 | model = models.vgg16(pretrained=True) 11 | 12 | 13 | def test_image_data(): 14 | with torch.no_grad(): 15 | names = create_readable_names_for_imagenet_labels() 16 | 17 | class_names = [] 18 | for item in names: 19 | class_names.append(names[item]) 20 | 21 | def predict_fn(x): 22 | return 
model(torch.from_numpy(x.reshape(-1, 3, 224, 224)).float()).numpy() 23 | 24 | explainer = xdeep_image.ImageExplainer(predict_fn, class_names) 25 | 26 | explainer.explain('lime', image, top_labels=1, num_samples=100) 27 | explainer.show_explanation('lime', positive_only=False) 28 | 29 | explainer.explain('cle', image, top_labels=1, num_samples=100) 30 | explainer.show_explanation('cle', positive_only=False) 31 | 32 | # explainer.explain('anchor', image, threshold=0.7, coverage_samples=5000) 33 | # explainer.show_explanation('anchor') 34 | 35 | # segments_slic = slic(image, n_segments=10, compactness=30, sigma=3) 36 | # explainer.initialize_shap(n_segment=10, segment=segments_slic) 37 | # explainer.explain('shap',image, nsamples=10) 38 | # explainer.show_explanation('shap') 39 | 40 | 41 | def create_readable_names_for_imagenet_labels(): 42 | """Create a dict mapping label id to human readable string. 43 | 44 | Returns: 45 | labels_to_names: dictionary where keys are integers from to 1000 46 | and values are human-readable names. 47 | 48 | We retrieve a synset file, which contains a list of valid synset labels used 49 | by ILSVRC competition. There is one synset one per line, eg. 50 | # n01440764 51 | # n01443537 52 | We also retrieve a synset_to_human_file, which contains a mapping from synsets 53 | to human-readable names for every synset in Imagenet. These are stored in a 54 | tsv format, as follows: 55 | # n02119247 black fox 56 | # n02119359 silver fox 57 | We assign each synset (in alphabetical order) an integer, starting from 1 58 | (since 0 is reserved for the background class). 
59 | 60 | Code is based on 61 | https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py#L463 62 | """ 63 | 64 | filename = "tests/xlocal_perturbation/image_util/imagenet_lsvrc_2015_synsets.txt" 65 | synset_list = [s.strip() for s in open(filename).readlines()] 66 | num_synsets_in_ilsvrc = len(synset_list) 67 | if not num_synsets_in_ilsvrc == 1000: 68 | raise AssertionError() 69 | 70 | filename = "tests/xlocal_perturbation/image_util/imagenet_metadata.txt" 71 | synset_to_human_list = open(filename).readlines() 72 | num_synsets_in_all_imagenet = len(synset_to_human_list) 73 | if not num_synsets_in_all_imagenet == 21842: 74 | raise AssertionError() 75 | 76 | synset_to_human = {} 77 | for s in synset_to_human_list: 78 | parts = s.strip().split('\t') 79 | if not len(parts) == 2: 80 | raise AssertionError() 81 | synset = parts[0] 82 | human = parts[1] 83 | synset_to_human[synset] = human 84 | 85 | label_index = 1 86 | labels_to_names = {0: 'background'} 87 | for synset in synset_list: 88 | name = synset_to_human[synset] 89 | labels_to_names[label_index] = name 90 | label_index += 1 91 | 92 | return labels_to_names 93 | -------------------------------------------------------------------------------- /tests/xlocal_perturbation/test_tabular.py: -------------------------------------------------------------------------------- 1 | from sklearn.ensemble import RandomForestClassifier 2 | from anchor import utils 3 | import xdeep.xlocal.perturbation.xdeep_tabular as xdeep_tabular 4 | 5 | # Please download the dataset at 6 | # https://archive.ics.uci.edu/ml/datasets/adult 7 | # Then reset the path. 
8 | 9 | def test_tabular_data(): 10 | dataset_folder = 'tests/xlocal_perturbation/data/' 11 | dataset = utils.load_dataset('adult', balance=True, dataset_folder=dataset_folder) 12 | 13 | c = RandomForestClassifier(n_estimators=50, n_jobs=5) 14 | c.fit(dataset.train, dataset.labels_train) 15 | 16 | explainer = xdeep_tabular.TabularExplainer(c.predict_proba, ['<=50K', '>50K'], dataset.feature_names, dataset.train[0:50], 17 | categorical_features=dataset.categorical_features, categorical_names=dataset.categorical_names) 18 | 19 | explainer.set_parameters('lime', discretize_continuous=True) 20 | explainer.explain('lime', dataset.test[0]) 21 | explainer.show_explanation('lime') 22 | 23 | explainer.set_parameters('cle', discretize_continuous=True) 24 | explainer.explain('cle', dataset.test[0]) 25 | explainer.show_explanation('cle') 26 | 27 | explainer.explain('shap', dataset.test[0]) 28 | explainer.show_explanation('shap') 29 | 30 | explainer.initialize_anchor(dataset.data) 31 | encoder = explainer.get_anchor_encoder(dataset.train[0:50], dataset.labels_train, dataset.validation, dataset.labels_validation) 32 | c_new = RandomForestClassifier(n_estimators=50, n_jobs=5) 33 | c_new.fit(encoder.transform(dataset.train), dataset.labels_train) 34 | explainer.set_anchor_predict_proba(c_new.predict_proba) 35 | explainer.explain('anchor', dataset.test[0]) 36 | explainer.show_explanation('anchor') 37 | 38 | -------------------------------------------------------------------------------- /tests/xlocal_perturbation/test_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import spacy 3 | import numpy as np 4 | from sklearn import model_selection 5 | from sklearn.linear_model import LogisticRegression 6 | from sklearn.pipeline import make_pipeline 7 | from sklearn.feature_extraction.text import CountVectorizer 8 | import xdeep.xlocal.perturbation.xdeep_text as xdeep_text 9 | 10 | # Please download dataset at 11 | # 
http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz 12 | # Then reset the path. 13 | 14 | def load_polarity(path='tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata'): 15 | data = [] 16 | labels = [] 17 | f_names = ['rt-polarity.neg', 'rt-polarity.pos'] 18 | for (l, f) in enumerate(f_names): 19 | for line in open(os.path.join(path, f), 'rb'): 20 | try: 21 | line.decode('utf8') 22 | except: 23 | continue 24 | data.append(line.strip()) 25 | labels.append(l) 26 | return data, labels 27 | 28 | def test_text_data(): 29 | data, labels = load_polarity() 30 | train, test, train_labels, test_labels = model_selection.train_test_split(data, labels, test_size=.2, random_state=42) 31 | train, val, train_labels, val_labels = model_selection.train_test_split(train, train_labels, test_size=.1, random_state=42) 32 | train_labels = np.array(train_labels) 33 | test_labels = np.array(test_labels) 34 | val_labels = np.array(val_labels) 35 | 36 | vectorizer = CountVectorizer(min_df=1) 37 | vectorizer.fit(train) 38 | train_vectors = vectorizer.transform(train) 39 | test_vectors = vectorizer.transform(test) 40 | val_vectors = vectorizer.transform(val) 41 | 42 | x = LogisticRegression() 43 | x.fit(train_vectors, train_labels) 44 | c = make_pipeline(vectorizer, x) 45 | 46 | explainer = xdeep_text.TextExplainer(c.predict_proba, ['negative', 'positive']) 47 | 48 | text = 'This is a good movie .' 
49 | 50 | explainer.explain('lime', text) 51 | explainer.show_explanation('lime') 52 | 53 | explainer.explain('cle', text) 54 | explainer.show_explanation('cle') 55 | 56 | try: 57 | nlp = spacy.load('en_core_web_sm') 58 | explainer.explain('anchor', text) 59 | explainer.show_explanation('anchor') 60 | except OSError: 61 | pass 62 | 63 | explainer.initialize_shap(x.predict_proba, vectorizer, train[0:10]) 64 | explainer.explain('shap', text) 65 | explainer.show_explanation('shap') 66 | -------------------------------------------------------------------------------- /tutorial/notebook/tabular - tutorial-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [], 3 | "metadata": {}, 4 | "nbformat": 4, 5 | "nbformat_minor": 2 6 | } 7 | -------------------------------------------------------------------------------- /tutorial/xglobal/images/jay.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/images/jay.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/deepdream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/deepdream.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/filter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/filter.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_0.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_0.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_1.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_2.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_3.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_4.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_5.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_5.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_6.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_7.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_8.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/inverted/Inverted_Iteration_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_9.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/results/layer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/layer.jpg 
-------------------------------------------------------------------------------- /tutorial/xglobal/results/logit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/logit.jpg -------------------------------------------------------------------------------- /tutorial/xglobal/test_global.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | from xdeep.xglobal.global_explainers import * 3 | from xdeep.xglobal.methods.inverted_representation import * 4 | 5 | 6 | def test_explainer(): 7 | 8 | model = models.vgg16(pretrained=True) 9 | model_explainer = GlobalImageInterpreter(model) 10 | 11 | model_explainer.explain(method_name='filter', target_layer='features_24', target_filter=20, num_iter=10, 12 | save_path='results/filter.jpg') 13 | model_explainer.explain(method_name='layer', target_layer='features_24', num_iter=10, save_path='results/layer.jpg') 14 | model_explainer.explain(method_name='logit', target_layer='features_24', target_filter=20, num_iter=10, 15 | save_path='results/logit.jpg') 16 | model_explainer.explain(method_name='deepdream', target_layer='features_24', target_filter=20, 17 | input_='images/jay.jpg', num_iter=10, save_path='results/deepdream.jpg') 18 | model_explainer.explain(method_name='inverted', target_layer='features_24', input_='images/jay.jpg', num_iter=10) 19 | 20 | 21 | def test_filter(): 22 | model = models.vgg16(pretrained=True) 23 | layer = model.features[24] 24 | filters = [45, 271, 363, 409] 25 | g_ascent = GradientAscent(model.features) 26 | g_ascent.visualize(layer, filters, num_iter=30, title='filter visualization', save_path='results/filter.jpg') 27 | 28 | 29 | def test_layer(): 30 | model = models.vgg16(pretrained=True) 31 | layer = model.features[24] 32 | g_ascent = GradientAscent(model.features) 33 | 
g_ascent.visualize(layer, num_iter=100, title='layer visualization', save_path='results/layer.jpg') 34 | 35 | 36 | def test_deepdream(): 37 | model = models.vgg16(pretrained=True) 38 | layer = model.features[24] 39 | g_ascent = GradientAscent(model.features) 40 | g_ascent.visualize(layer=layer, filter_idxs=33, input_='images/jay.jpg', title='deepdream',num_iter=50, save_path='results/deepdream.jpg') 41 | 42 | 43 | def test_logit(): 44 | model = models.vgg16(pretrained=True) 45 | layer = model.classifier[-1] 46 | logit = 17 47 | g_ascent = GradientAscent(model) 48 | g_ascent.visualize(layer, logit, num_iter=30, title='logit visualization', save_path='results/logit.jpg') 49 | 50 | 51 | def test_inverted(): 52 | image_path = 'images/jay.jpg' 53 | image = load_image(image_path) 54 | norm_image = apply_transforms(image) 55 | 56 | model = models.vgg16(pretrained=True) 57 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224)) 58 | IR = InvertedRepresentation(model_dict) 59 | 60 | IR.visualize(norm_image) 61 | -------------------------------------------------------------------------------- /tutorial/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_gradient/test_explainers.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | from xdeep.xlocal.gradient.explainers import * 3 | 4 | 5 | def test_explainer(): 6 | 7 | image = load_image('images/ILSVRC2012_val_00000073.JPEG') 8 | model = models.vgg16(pretrained=True) 9 | model_explainer = ImageInterpreter(model) 10 | 11 | # generate explanations 12 | model_explainer.explain(image, method_name='vanilla_backprop', viz=True, save_path='results/bp.jpg') 13 | model_explainer.explain(image, method_name='guided_backprop', viz=True, save_path='results/guided.jpg') 14 | model_explainer.explain(image, method_name='smooth_grad', viz=True, save_path='results/smooth_grad.jpg') 15 | model_explainer.explain(image, method_name='integrate_grad', viz=True, save_path='results/integrate.jpg') 16 | model_explainer.explain(image, method_name='smooth_guided_grad', viz=True, save_path='results/smooth_guided_grad.jpg') 17 | model_explainer.explain(image, method_name='integrate_guided_grad', viz=True, save_path='results/integrate_guided.jpg') 18 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True, save_path='results/gradcam.jpg') 19 | model_explainer.explain(image, method_name='gradcampp', target_layer_name='features_29', viz=True, save_path='results/gradcampp.jpg') 20 | model_explainer.explain(image, method_name='scorecam', target_layer_name='features_29', viz=True, save_path='results/scorecam.jpg') 21 | -------------------------------------------------------------------------------- /tutorial/xlocal_gradient/test_gradcam.py: 
-------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | 3 | from xdeep.utils import * 4 | from xdeep.xlocal.gradient.cam.gradcam import * 5 | 6 | 7 | def test_gradcam(): 8 | 9 | image_path = 'images/ILSVRC2012_val_00000073.JPEG' 10 | image_name = image_path.split('/')[1] 11 | image = load_image(image_path) 12 | norm_image = apply_transforms(image) 13 | 14 | model = models.vgg16(pretrained=True) 15 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224)) 16 | 17 | gradcam = GradCAM(model_dict) 18 | output = gradcam(norm_image) 19 | visualize(norm_image, output, save_path='results/gradcam_' + image_name) 20 | -------------------------------------------------------------------------------- /tutorial/xlocal_gradient/test_guided_backprop.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | 3 | from xdeep.utils import * 4 | from xdeep.xlocal.gradient.backprop.guided_backprop import * 5 | 6 | 7 | def test_bp(): 8 | 9 | image_path = 'images/ILSVRC2012_val_00000073.JPEG' 10 | image_name = image_path.split('/')[1] 11 | image = load_image(image_path) 12 | norm_image = apply_transforms(image) 13 | 14 | model = models.vgg16(pretrained=True) 15 | Guided_BP = Backprop(model, guided=True) 16 | output = Guided_BP.calculate_gradients(norm_image) 17 | visualize(norm_image, output, save_path='results/GuidedBP_' + image_name) 18 | -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/Eskimo dog, husky.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/Eskimo dog, husky.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/Great 
Pyrenees.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/Great Pyrenees.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/Ibizan hound.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/Ibizan hound.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/bee eater.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/bee eater.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/car mirror.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/car mirror.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/eggnog.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/eggnog.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/kitten.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/kitten.jpg 
-------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/oystercatcher.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/oystercatcher.jpg -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata.README.1.0.txt: -------------------------------------------------------------------------------- 1 | 2 | ======= 3 | 4 | Introduction 5 | 6 | This README v1.0 (June, 2005) for the v1.0 sentence polarity dataset comes 7 | from the URL 8 | http://www.cs.cornell.edu/people/pabo/movie-review-data . 9 | 10 | ======= 11 | 12 | Citation Info 13 | 14 | This data was first used in Bo Pang and Lillian Lee, 15 | ``Seeing stars: Exploiting class relationships for sentiment categorization 16 | with respect to rating scales.'', Proceedings of the ACL, 2005. 17 | 18 | @InProceedings{Pang+Lee:05a, 19 | author = {Bo Pang and Lillian Lee}, 20 | title = {Seeing stars: Exploiting class relationships for sentiment 21 | categorization with respect to rating scales}, 22 | booktitle = {Proceedings of the ACL}, 23 | year = 2005 24 | } 25 | 26 | ======= 27 | 28 | Data Format Summary 29 | 30 | - rt-polaritydata.tar.gz: contains this readme and two data files that 31 | were used in the experiments described in Pang/Lee ACL 2005. 32 | 33 | Specifically: 34 | * rt-polarity.pos contains 5331 positive snippets 35 | * rt-polarity.neg contains 5331 negative snippets 36 | 37 | Each line in these two files corresponds to a single snippet (usually 38 | containing roughly one single sentence); all snippets are down-cased. 39 | The snippets were labeled automatically, as described below (see 40 | section "Label Decision"). 
41 | 42 | Note: The original source files from which the data in 43 | rt-polaritydata.tar.gz was derived can be found in the subjective 44 | part (Rotten Tomatoes pages) of subjectivity_html.tar.gz (released 45 | with subjectivity dataset v1.0). 46 | 47 | 48 | ======= 49 | 50 | Label Decision 51 | 52 | We assumed snippets (from Rotten Tomatoes webpages) for reviews marked with 53 | ``fresh'' are positive, and those for reviews marked with ``rotten'' are 54 | negative. 55 | -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/sea_snake.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/sea_snake.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/sky.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/sky.jpg -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/tusker.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/tusker.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/data/violin.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/violin.JPEG -------------------------------------------------------------------------------- /tutorial/xlocal_perturbation/image - tutorial.py: -------------------------------------------------------------------------------- 1 | # Please set the folder slim according to your own path. 
2 | # Download slim at https://github.com/tensorflow/models/tree/master/research/slim 3 | 4 | import os 5 | import torch 6 | import torchvision.models as models 7 | import tensorflow as tf 8 | import numpy as np 9 | from image_util import inception_preprocessing, inception_v3 10 | from skimage.segmentation import slic 11 | import xdeep.xlocal.perturbation.xdeep_image as xdeep_image 12 | from xdeep.utils import * 13 | 14 | def test_image_data_torch(): 15 | with torch.no_grad(): 16 | image = load_image('./xlocal_perturbation/data/violin.JPEG') 17 | image = np.asarray(image) 18 | model = models.vgg16(pretrained=True) 19 | 20 | names = create_readable_names_for_imagenet_labels() 21 | 22 | class_names = [] 23 | for item in names: 24 | class_names.append(names[item]) 25 | 26 | def predict_fn(x): 27 | return model(torch.from_numpy(x.reshape(-1, 3, 224, 224)).float()).numpy() 28 | 29 | explainer = xdeep_image.ImageExplainer(predict_fn, class_names) 30 | 31 | explainer.explain('lime', image, top_labels=1, num_samples=30) 32 | explainer.show_explanation('lime', positive_only=False) 33 | 34 | explainer.explain('cle', image, top_labels=1, num_samples=30) 35 | explainer.show_explanation('cle', positive_only=False) 36 | 37 | explainer.explain('anchor', image, threshold=0.7, coverage_samples=5000) 38 | explainer.show_explanation('anchor') 39 | 40 | segments_slic = slic(image, n_segments=30, compactness=30, sigma=3) 41 | explainer.initialize_shap(n_segment=30, segment=segments_slic) 42 | explainer.explain('shap',image,nsamples=30) 43 | explainer.show_explanation('shap') 44 | 45 | def test_image_data_tf(): 46 | slim = tf.contrib.slim 47 | tf.reset_default_graph() 48 | session = tf.Session() 49 | 50 | names = create_readable_names_for_imagenet_labels() 51 | 52 | processed_images = tf.placeholder(tf.float32, shape=(None, 299, 299, 3)) 53 | 54 | with slim.arg_scope(inception_v3.inception_v3_arg_scope()): 55 | logits, _ = inception_v3.inception_v3(processed_images, num_classes=1001, 
is_training=False) 56 | probabilities = tf.nn.softmax(logits) 57 | 58 | # Please correctly set the model path. 59 | # Download the model at https://github.com/tensorflow/models/tree/master/research/slim 60 | checkpoints_dir = 'model' 61 | init_fn = slim.assign_from_checkpoint_fn( 62 | os.path.join(checkpoints_dir, 'inception_v3.ckpt'), 63 | slim.get_model_variables('InceptionV3')) 64 | init_fn(session) 65 | 66 | def predict_fn(images): 67 | return session.run(probabilities, feed_dict={processed_images: images}) 68 | 69 | def f(x): 70 | return x / 2 + 0.5 71 | 72 | class_names = [] 73 | for item in names: 74 | class_names.append(names[item]) 75 | 76 | def transform_img_fn(path_list): 77 | out = [] 78 | image_size = 299 79 | for f in path_list: 80 | with open(f,'rb') as img: 81 | image_raw = tf.image.decode_jpeg(img.read(), channels=3) 82 | image = inception_preprocessing.preprocess_image(image_raw, image_size, image_size, is_training=False) 83 | out.append(image) 84 | return session.run([out])[0] 85 | 86 | images = transform_img_fn(['./data/violin.JPEG']) 87 | image = images[0] 88 | 89 | explainer = xdeep_image.ImageExplainer(predict_fn, class_names) 90 | 91 | explainer.explain('lime', image, top_labels=3) 92 | explainer.show_explanation('lime', deprocess=f, positive_only=False) 93 | 94 | explainer.explain('cle', image, top_labels=3) 95 | explainer.show_explanation('cle', deprocess=f, positive_only=False) 96 | 97 | explainer.explain('anchor', image) 98 | explainer.show_explanation('anchor') 99 | 100 | segments_slic = slic(image, n_segments=50, compactness=30, sigma=3) 101 | explainer.initialize_shap(n_segment=50, segment=segments_slic) 102 | explainer.explain('shap',image,nsamples=400) 103 | explainer.show_explanation('shap',deprocess=f) 104 | 105 | def create_readable_names_for_imagenet_labels(): 106 | """Create a dict mapping label id to human readable string. 
107 | 108 | Returns: 109 | labels_to_names: dictionary where keys are integers from to 1000 110 | and values are human-readable names. 111 | 112 | We retrieve a synset file, which contains a list of valid synset labels used 113 | by ILSVRC competition. There is one synset one per line, eg. 114 | # n01440764 115 | # n01443537 116 | We also retrieve a synset_to_human_file, which contains a mapping from synsets 117 | to human-readable names for every synset in Imagenet. These are stored in a 118 | tsv format, as follows: 119 | # n02119247 black fox 120 | # n02119359 silver fox 121 | We assign each synset (in alphabetical order) an integer, starting from 1 122 | (since 0 is reserved for the background class). 123 | 124 | Code is based on 125 | https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py#L463 126 | """ 127 | 128 | filename = "./image_util/imagenet_lsvrc_2015_synsets.txt" 129 | synset_list = [s.strip() for s in open(filename).readlines()] 130 | num_synsets_in_ilsvrc = len(synset_list) 131 | if not num_synsets_in_ilsvrc == 1000: 132 | raise AssertionError() 133 | 134 | filename = "./image_util/imagenet_metadata.txt" 135 | synset_to_human_list = open(filename).readlines() 136 | num_synsets_in_all_imagenet = len(synset_to_human_list) 137 | if not num_synsets_in_all_imagenet == 21842: 138 | raise AssertionError() 139 | 140 | synset_to_human = {} 141 | for s in synset_to_human_list: 142 | parts = s.strip().split('\t') 143 | if not len(parts) == 2: 144 | raise AssertionError() 145 | synset = parts[0] 146 | human = parts[1] 147 | synset_to_human[synset] = human 148 | 149 | label_index = 1 150 | labels_to_names = {0: 'background'} 151 | for synset in synset_list: 152 | name = synset_to_human[synset] 153 | labels_to_names[label_index] = name 154 | label_index += 1 155 | 156 | return labels_to_names 157 | -------------------------------------------------------------------------------- 
# ------------------------------------------------------------------------------
# /tutorial/xlocal_perturbation/image_util/.DS_Store
# https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/image_util/.DS_Store
# ------------------------------------------------------------------------------
# /tutorial/xlocal_perturbation/tabular - tutorial.py
# ------------------------------------------------------------------------------
from sklearn.ensemble import RandomForestClassifier
from anchor import utils
import xdeep.xlocal.perturbation.xdeep_tabular as xdeep_tabular

# Please download the dataset at
# https://archive.ics.uci.edu/ml/datasets/adult
# Then reset the path, e.g. 'data/'.


def test_tabular_data():
    """Run every xdeep tabular explainer on the UCI Adult dataset.

    Loads the dataset through ``anchor.utils.load_dataset``, fits a random
    forest and explains the first test instance with LIME, CLE, SHAP and
    Anchor. For Anchor, a second random forest is trained on the
    encoder-transformed training data and passed to the explainer.
    """
    dataset_folder = 'data/'
    dataset = utils.load_dataset('adult', balance=True, dataset_folder=dataset_folder)

    c = RandomForestClassifier(n_estimators=50, n_jobs=5)
    c.fit(dataset.train, dataset.labels_train)

    explainer = xdeep_tabular.TabularExplainer(
        c.predict_proba, ['<=50K', '>50K'], dataset.feature_names, dataset.train[0:50],
        categorical_features=dataset.categorical_features,
        categorical_names=dataset.categorical_names)

    explainer.set_parameters('lime', discretize_continuous=True)
    explainer.explain('lime', dataset.test[0])
    explainer.show_explanation('lime')

    explainer.set_parameters('cle', discretize_continuous=True)
    explainer.explain('cle', dataset.test[0])
    explainer.show_explanation('cle')

    explainer.explain('shap', dataset.test[0])
    explainer.show_explanation('shap')

    # Anchor is given a classifier trained on the encoder's output features.
    explainer.initialize_anchor(dataset.data)
    encoder = explainer.get_anchor_encoder(dataset.train[0:50], dataset.labels_train,
                                           dataset.validation, dataset.labels_validation)
    c_new = RandomForestClassifier(n_estimators=50, n_jobs=5)
    c_new.fit(encoder.transform(dataset.train), dataset.labels_train)
    explainer.set_anchor_predict_proba(c_new.predict_proba)
    explainer.explain('anchor', dataset.test[0])
    explainer.show_explanation('anchor')


# ------------------------------------------------------------------------------
# /tutorial/xlocal_perturbation/text - tutorial.py
# ------------------------------------------------------------------------------
import os
import numpy as np
import spacy
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import CountVectorizer
import xdeep.xlocal.perturbation.xdeep_text as xdeep_text

# Please download dataset at
# http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz
# Then reset the path.


def load_polarity(path='./data/rt-polaritydata/rt-polaritydata'):
    """Load the rt-polarity sentence dataset.

    # Arguments
        path: str. Directory containing 'rt-polarity.neg' and 'rt-polarity.pos'.

    # Returns
        (data, labels): `data` is a list of stripped raw `bytes` lines, and
        `labels` is a parallel list of ints (0 for lines read from
        'rt-polarity.neg', 1 for 'rt-polarity.pos'). Lines that are not valid
        UTF-8 are skipped.
    """
    data = []
    labels = []
    f_names = ['rt-polarity.neg', 'rt-polarity.pos']
    for label, f_name in enumerate(f_names):
        # Context manager closes each file handle deterministically.
        with open(os.path.join(path, f_name), 'rb') as source:
            for line in source:
                try:
                    line.decode('utf8')
                except UnicodeDecodeError:
                    # Skip lines that are not valid UTF-8.
                    continue
                data.append(line.strip())
                labels.append(label)
    return data, labels


def test_text_data():
    """Run the xdeep text explainers on a bag-of-words sentiment classifier.

    Trains a CountVectorizer + LogisticRegression pipeline on rt-polarity and
    explains a single sentence with LIME, CLE, Anchor (only when the spaCy
    model 'en_core_web_sm' can be loaded) and SHAP.
    """
    data, labels = load_polarity()
    train, test, train_labels, test_labels = model_selection.train_test_split(
        data, labels, test_size=.2, random_state=42)
    train, val, train_labels, val_labels = model_selection.train_test_split(
        train, train_labels, test_size=.1, random_state=42)
    train_labels = np.array(train_labels)

    vectorizer = CountVectorizer(min_df=1)
    vectorizer.fit(train)
    train_vectors = vectorizer.transform(train)

    x = LogisticRegression()
    x.fit(train_vectors, train_labels)
    c = make_pipeline(vectorizer, x)

    explainer = xdeep_text.TextExplainer(c.predict_proba, ['negative', 'positive'])

    text = 'This is a good movie .'

    explainer.explain('lime', text)
    explainer.show_explanation('lime')

    explainer.explain('cle', text)
    explainer.show_explanation('cle')

    # Only the model-availability probe is guarded, so errors raised by the
    # anchor explainer itself are not silently swallowed.
    try:
        spacy.load('en_core_web_sm')
    except OSError:
        pass
    else:
        explainer.explain('anchor', text)
        explainer.show_explanation('anchor')

    explainer.initialize_shap(x.predict_proba, vectorizer, train[0:10])
    explainer.explain('shap', text)
    explainer.show_explanation('shap')


# ------------------------------------------------------------------------------
# /xdeep/__init__.py
# ------------------------------------------------------------------------------
name = "x-deep"


# ------------------------------------------------------------------------------
# /xdeep/utils/__init__.py
# ------------------------------------------------------------------------------
#!/usr/bin/env python
"""
Misa Ogura. (2019, September 9).
MisaOgura/flashtorch: 0.1.0 (Version v0.1.0).
Zenodo. http://doi.org/10.5281/zenodo.3403214

This module provides utility functions for image handling and tensor
transformation.

"""
from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.transforms.functional as F

from .imagenet import *


def load_image(image_path):
    """Loads image as a PIL RGB image.

    Args:
        - **image_path (str) - **: A path to the image

    Returns:
        An instance of PIL.Image.Image in RGB

    """

    return Image.open(image_path).convert('RGB')


def apply_transforms(image, size=224):
    """Transforms a PIL image to torch.Tensor.

    Applies a series of transformations on a PIL image, including a conversion
    to a tensor. The returned tensor has a shape of :math:`(N, C, H, W)` with
    N = 1, has ``requires_grad`` set, and is ready to be used as an input to
    neural networks.

    The image is first resized so that its shorter side equals ``size``, then
    center-cropped to ``size`` x ``size``. The `means` and `stds` used for
    normalisation are the ImageNet statistics, because the package currently
    targets visualizing pre-trained ImageNet models.

    The plan is to expand this to handle custom size/mean/std.

    Args:
        image (PIL.Image.Image or numpy array)
        size (int, optional, default=224): Desired size (width/height) of the
            output tensor

    Shape:
        Input: :math:`(C, H, W)` for numpy array
        Output: :math:`(N, C, H, W)`

    Returns:
        torch.Tensor (torch.float32): Transformed image tensor

    Note:
        Symbols used to describe dimensions:
            - N: number of images in a batch
            - C: number of channels
            - H: height of the image
            - W: width of the image

    """

    if not isinstance(image, Image.Image):
        image = F.to_pil_image(image)

    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(means, stds)
    ])

    tensor = transform(image).unsqueeze(0)

    # Gradient-based explainers differentiate w.r.t. the input image.
    tensor.requires_grad = True

    return tensor


def denormalize(tensor):
    """Reverses the normalisation on a tensor.

    Performs a reverse operation on a tensor, so the pixel value range is
    between 0 and 1. Useful for when plotting a tensor into an image.
    Every image in the batch is denormalised.

    Normalisation: (image - mean) / std
    Denormalisation: image * std + mean

    Args:
        tensor (torch.Tensor, dtype=torch.float32): Normalized image tensor

    Shape:
        Input: :math:`(N, C, H, W)`
        Output: :math:`(N, C, H, W)` (same shape as input)

    Return:
        torch.Tensor (torch.float32): Denormalised image tensor with pixel
        values between [0, 1]

    Note:
        Symbols used to describe dimensions:
            - N: number of images in a batch
            - C: number of channels
            - H: height of the image
            - W: width of the image

    """

    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]

    denormalized = tensor.clone()

    # Index the channel axis across the whole batch (dim 1) instead of only the
    # first image; zip with range() stops at min(C, 3) channels.
    for channel_idx, mean, std in zip(range(denormalized.shape[1]), means, stds):
        denormalized[:, channel_idx].mul_(std).add_(mean)

    return denormalized


def standardize_and_clip(tensor, min_value=0.0, max_value=1.0,
                         saturation=0.1, brightness=0.5):

    """Standardizes and clips input tensor.

    Standardizes the input tensor (mean = 0.0, std = 1.0). The color saturation
    and brightness are adjusted, before tensor values are clipped to min/max
    (default: 0.0/1.0).

    Args:
        tensor (torch.Tensor):
        min_value (float, optional, default=0.0)
        max_value (float, optional, default=1.0)
        saturation (float, optional, default=0.1)
        brightness (float, optional, default=0.5)

    Shape:
        Input: :math:`(C, H, W)`
        Output: Same as the input

    Return:
        torch.Tensor (torch.float32): Normalised tensor with values between
        [min_value, max_value]
    """

    tensor = tensor.detach().cpu()

    mean = tensor.mean()
    std = tensor.std()

    # Avoid division by zero for constant tensors.
    if std == 0:
        std += 1e-7

    standardized = tensor.sub(mean).div(std).mul(saturation)
    clipped = standardized.add(brightness).clamp(min_value, max_value)

    return clipped


def format_for_plotting(tensor):
    """Formats the shape of tensor for plotting.

    Tensors typically have a shape of :math:`(N, C, H, W)` or :math:`(C, H, W)`
    which is not suitable for plotting as images. This function formats an
    input tensor into :math:`(H, W, C)` for RGB and :math:`(H, W)` for
    mono-channel data.

    Args:
        tensor (torch.Tensor, torch.float32): Image tensor

    Shape:
        Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        Output: :math:`(H, W, C)` or :math:`(H, W)`, respectively

    Return:
        torch.Tensor (torch.float32): Formatted image tensor (detached)

    Note:
        Symbols used to describe dimensions:
            - N: number of images in a batch
            - C: number of channels
            - H: height of the image
            - W: width of the image

    """

    has_batch_dimension = len(tensor.shape) == 4

    # Only copy when no batch dimension is squeezed away.
    formatted = tensor.squeeze(0) if has_batch_dimension else tensor.clone()

    if formatted.shape[0] == 1:
        return formatted.squeeze(0).detach()
    else:
        return formatted.permute(1, 2, 0).detach()


def visualize(input_, gradients, save_path=None, cmap='viridis', alpha=0.7):

    """ Method to plot the explanation.

    # Arguments
        input_: Tensor. Original image.
        gradients: Tensor. Saliency map result.
        save_path: String. Defaults to None.
        cmap: Defaults to be 'viridis'.
        alpha: Defaults to be 0.7.

    """
    input_ = format_for_plotting(denormalize(input_))
    gradients = format_for_plotting(standardize_and_clip(gradients))

    subplots = [
        ('Input image', [(input_, None, None)]),
        ('Saliency map', [(gradients, None, None)]),
        ('Overlay', [(input_, None, None), (gradients, cmap, alpha)])
    ]

    num_subplots = len(subplots)

    fig = plt.figure(figsize=(16, 3))

    for i, (title, images) in enumerate(subplots):
        ax = fig.add_subplot(1, num_subplots, i + 1)
        ax.set_axis_off()

        # Distinct loop names so the cmap/alpha parameters are not shadowed.
        for image, image_cmap, image_alpha in images:
            ax.imshow(image, cmap=image_cmap, alpha=image_alpha)

        ax.set_title(title)
    if save_path is not None:
        plt.savefig(save_path)


# ------------------------------------------------------------------------------
# /xdeep/utils/imagenet.py
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
Misa Ogura. (2019, September 9).
MisaOgura/flashtorch: 0.1.0 (Version v0.1.0).
Zenodo. http://doi.org/10.5281/zenodo.3403214
'''

import json

from collections.abc import Mapping
from importlib_resources import path

from . import resources


class ImageNetIndex(Mapping):
    """Interface to retrieve ImageNet class indices from class names.

    This class implements a dictionary-like object, aiming to provide an
    easy-to-use look-up table for finding a target class index from an ImageNet
    class name.

    Reference:
        - ImageNet class index: https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
        - Synsets: http://image-net.org/challenges/LSVRC/2015/browse-synsets

    Note:
        Class names in `imagenet_class_index.json` have been slightly modified
        from the source due to duplicated class names (e.g. crane). This helps
        make the use of this tool simpler.
    """

    def __init__(self):
        self._index = {}

        with path(resources, 'imagenet_class_index.json') as source_path:
            with open(str(source_path), 'r') as source:
                data = json.load(source)

        # Keys are normalised to lower case with underscores replaced by spaces.
        for index, (_, class_name) in data.items():
            class_name = class_name.lower().replace('_', ' ')
            self._index[class_name] = int(index)

    def __len__(self):
        return len(self._index)

    def __iter__(self):
        return iter(self._index)

    def __getitem__(self, phrase):
        """Return the class index for `phrase`.

        Exact names are looked up first, then partial matches. Returns None
        when nothing matches.

        Raises:
            TypeError: if `phrase` is not a string.
            ValueError: if several distinct classes match `phrase`.
        """
        if not isinstance(phrase, str):
            raise TypeError('Target class needs to be a string.')

        if phrase in self._index:
            return self._index[phrase]

        partial_matches = self._find_partial_matches(phrase)

        if not partial_matches:
            return None
        elif len(partial_matches) > 1:
            raise ValueError('Multiple potential matches found: {}'
                             .format(', '.join(map(str, partial_matches))))

        target_class = partial_matches.pop()

        return self._index[target_class]

    def __contains__(self, key):
        return any(key in name for name in self._index)

    def keys(self):
        return self._index.keys()

    def items(self):
        return self._index.items()

    def _find_partial_matches(self, phrase):
        """Return the distinct class names that partially match `phrase`."""
        words = phrase.lower().split(' ')

        # Find the intersection between search words and class names to
        # prioritise whole word matches
        # e.g. If words = {'dalmatian', 'dog'} then matches 'dalmatian'

        matches = set(words).intersection(set(self.keys()))

        if not matches:
            # Find substring matches between search words and class names to
            # accommodate fuzzy matches to some extent
            # e.g. If words = {'foxhound'} then matches 'english foxhound'
            # dict.fromkeys de-duplicates (keeping order) so a class hit by
            # several words is not reported as multiple matches.
            matches = list(dict.fromkeys(
                key for word in words for key in self.keys() if word in key))

        return matches


# ------------------------------------------------------------------------------
# /xdeep/utils/resources/__init__.py  (empty)
# ------------------------------------------------------------------------------
# /xdeep/xglobal/__init__.py
# https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xglobal/__init__.py
# ------------------------------------------------------------------------------
# /xdeep/xglobal/global_explainers.py
# ------------------------------------------------------------------------------
"""
Created on Thu Sep 19 18:37:22 2019
@author: Haofan Wang - github.com/haofanwang
"""

from xdeep.xglobal.methods import *
from xdeep.xglobal.methods.activation import *
from xdeep.xglobal.methods.inverted_representation import *


class GlobalImageInterpreter(object):
    """ Class for global image explanations.

    GlobalImageInterpreter provides a unified interface to call different
    visualization methods in a few lines. It exposes '__init__()' and
    'explain()'. Users specify the name of the visualization method and the
    target layer; parameters that are not specified use defaults.
    """

    def __init__(self, model):
        """Init

        # Arguments
            model: torchvision.models. A pretrained model.

        # Attributes
            result: Output of the last `explain()` call (None until then).

        """
        self.model = model
        self.result = None

    def maximize_activation(self, model):
        """
        Return class of activation maximum.

        """

        return GradientAscent(model)

    def inverted_feature(self, model_dict):
        """
        Return class of inverted representation.
        """
        return InvertedRepresentation(model_dict)

    def explain(self, method_name=None, target_layer=None, target_filter=None, input_=None, num_iter=10, save_path=None):

        """Function to call different visualization methods.

        # Arguments
            method_name: str. The name of interpreter method. Currently, global explanation methods support 'filter',
                'layer', 'logit', 'deepdream', 'inverted'.
            target_layer: torch.nn.Linear or torch.nn.conv.Conv2d, or str. The objective layer; a string is
                resolved with `find_layer`.
            target_filter: int or list. Index of filter or filters.
            input_: str or Tensor. Path of input image or normalized tensor. Default to be None.
            num_iter: int. Iter times. Default to 10.
            save_path: str. Path to save generated explanations. Default to None.

        # Raises
            TypeError: if `method_name` is None, or `target_layer` is None for a non-'inverted' method.
            KeyError: if `method_name` is not one of the supported names.
        """

        method_switch = {"filter": self.maximize_activation,
                         "layer": self.maximize_activation,
                         "logit": self.maximize_activation,
                         "deepdream": self.maximize_activation,
                         "inverted": self.inverted_feature}

        if method_name is None:
            raise TypeError("method_name has to be specified!")

        # Fail early with a descriptive message instead of a bare KeyError
        # raised after the layer lookup.
        if method_name not in method_switch:
            raise KeyError("Unknown method_name {!r}; expected one of: {}".format(
                method_name, ', '.join(sorted(method_switch))))

        if method_name == "inverted":
            model_dict = dict(arch=self.model, layer_name=target_layer, input_size=(224, 224))
            explainer = method_switch[method_name](model_dict)

            self.result = explainer.visualize(input_=input_, img_size=(224, 224), alpha_reg_alpha=6,
                                              alpha_reg_lambda=1e-7, tv_reg_beta=2, tv_reg_lambda=1e-8,
                                              n=num_iter, save_path=save_path)
        else:
            if target_layer is None:
                raise TypeError("target_layer cannot be None!")
            elif isinstance(target_layer, str):
                target_layer = find_layer(self.model, target_layer)

            explainer = method_switch[method_name](self.model)

            self.result = explainer.visualize(layer=target_layer, filter_idxs=target_filter,
                                              input_=input_, lr=1, num_iter=num_iter, num_subplots=4,
                                              figsize=(4, 4), title='Visualization Result', return_output=True,
                                              save_path=save_path)


# ------------------------------------------------------------------------------
# /xdeep/xglobal/methods/__init__.py
# ------------------------------------------------------------------------------

def find_layer(arch, target_layer_name):
    """Resolve an underscore-separated layer name to a submodule of `arch`.

    Each '_'-separated component is looked up in the `_modules` of the module
    found so far, e.g. 'features_29' -> arch._modules['features']._modules['29'].

    # Arguments
        arch: torch.nn.Module. The model to search.
        target_layer_name: str or None. When None, defaults to 'layer4' for
            models whose type name contains 'resnet', and to 'features' for
            alexnet / vgg / squeezenet / densenet.

    # Returns
        The resolved submodule.

    # Raises
        Exception: if no default applies or any component is not found.
    """
    if target_layer_name is None:
        arch_type = str(type(arch))
        if 'resnet' in arch_type:
            target_layer_name = 'layer4'
        elif any(family in arch_type for family in ('alexnet', 'vgg', 'squeezenet', 'densenet')):
            target_layer_name = 'features'
        else:
            raise Exception('Invalid layer name! Please specify layer name.', target_layer_name)

    # Walk every component; previously components past the 4th were ignored.
    target_layer = arch
    for part in target_layer_name.split('_'):
        if part not in target_layer._modules.keys():
            raise Exception('Invalid layer name!', target_layer_name)
        target_layer = target_layer._modules[part]

    return target_layer


# ------------------------------------------------------------------------------
# /xdeep/xglobal/methods/inverted_representation.py
# ------------------------------------------------------------------------------
import os

import torch
import torch.optim as optim
from torchvision.utils import save_image

| 7 | from xdeep.utils import (load_image, apply_transforms, denormalize) 8 | 9 | 10 | class InvertedRepresentation(object): 11 | def __init__(self, model_dict): 12 | 13 | """Init 14 | 15 | # Arguments: 16 | layer_name: str. Name of target layer. 17 | model: torchvision.models. A pretrained model. 18 | img_size: tuple. Size of input image, default to (224,224). 19 | """ 20 | 21 | layer_name = model_dict['layer_name'] 22 | self.model = model_dict['arch'] 23 | self.model.eval() 24 | 25 | self.img_size = model_dict['input_size'] 26 | 27 | self.activations = dict() 28 | self.gradients = dict() 29 | 30 | def forward_hook(module, input, output): 31 | self.activations['value'] = output 32 | return None 33 | 34 | def backward_hook(module, grad_input, grad_output): 35 | self.gradients['value'] = grad_output[0] 36 | return None 37 | 38 | target_layer = self.find_layer(self.model, layer_name) 39 | target_layer.register_forward_hook(forward_hook) 40 | target_layer.register_backward_hook(backward_hook) 41 | 42 | def find_layer(self, arch, target_layer_name): 43 | """ 44 | Return target layer 45 | """ 46 | 47 | if target_layer_name is None: 48 | if 'resnet' in str(type(arch)): 49 | target_layer_name = 'layer4' 50 | elif 'alexnet' in str(type(arch)) or 'vgg' in str(type(arch)) or 'squeezenet' in str(type(arch)) or 'densenet' in str(type(arch)): 51 | target_layer_name = 'features' 52 | else: 53 | raise Exception('Invalid layer name! 
Please specify layer name.', target_layer_name) 54 | 55 | hierarchy = target_layer_name.split('_') 56 | 57 | if hierarchy[0] not in arch._modules.keys(): 58 | raise Exception('Invalid layer name!', target_layer_name) 59 | 60 | target_layer = arch._modules[hierarchy[0]] 61 | 62 | if len(hierarchy) >= 2: 63 | if hierarchy[1] not in target_layer._modules.keys(): 64 | raise Exception('Invalid layer name!', target_layer_name) 65 | target_layer = target_layer._modules[hierarchy[1]] 66 | 67 | if len(hierarchy) >= 3: 68 | if hierarchy[2] not in target_layer._modules.keys(): 69 | raise Exception('Invalid layer name!', target_layer_name) 70 | target_layer = target_layer._modules[hierarchy[2]] 71 | 72 | if len(hierarchy) >= 4: 73 | if hierarchy[3] not in target_layer._modules.keys(): 74 | raise Exception('Invalid layer name!', target_layer_name) 75 | target_layer = target_layer._modules[hierarchy[3]] 76 | 77 | return target_layer 78 | 79 | def alpha_norm(self, input_, alpha): 80 | """ 81 | Converts matrix to vector then calculates the alpha norm 82 | """ 83 | return ((input_.view(-1))**alpha).sum() 84 | 85 | def total_variation_norm(self, input_, beta): 86 | """ 87 | Total variation norm is the second norm in the paper 88 | represented as R_V(x) 89 | """ 90 | to_check = input_[:, :-1, :-1] 91 | one_bottom = input_[:, 1:, :-1] 92 | one_right = input_[:, :-1, 1:] 93 | total_variation = (((to_check - one_bottom)**2 + (to_check - one_right)**2)**(beta/2)).sum() 94 | return total_variation 95 | 96 | def euclidian_loss(self, org_matrix, target_matrix): 97 | """ 98 | Euclidian loss is the main loss function. 
99 | """ 100 | distance_matrix = target_matrix - org_matrix 101 | euclidian_distance = self.alpha_norm(distance_matrix, 2) 102 | normalized_euclidian_distance = euclidian_distance / self.alpha_norm(org_matrix, 2) 103 | return normalized_euclidian_distance 104 | 105 | def visualize(self, input_, img_size=(224, 224), alpha_reg_alpha=6, alpha_reg_lambda=1e-7, 106 | tv_reg_beta=2, tv_reg_lambda=1e-8, n=20, save_path=None): 107 | 108 | if isinstance(input_, str): 109 | input_ = load_image(input_) 110 | input_ = apply_transforms(input_) 111 | 112 | if save_path is None: 113 | if not os.path.exists('results/inverted'): 114 | os.makedirs('results/inverted') 115 | save_path = 'results/inverted' 116 | 117 | # generate random image to be optimized 118 | opt_img = 1e-1 * torch.randn(1, 3, img_size[0], img_size[1]) 119 | opt_img.requires_grad = True 120 | 121 | # define optimizer 122 | optimizer = optim.SGD([opt_img], lr=1e4, momentum=0.9) 123 | 124 | # forward propagation to activate hook function 125 | image_output = self.model(input_) 126 | input_image_layer_output = self.activations['value'] 127 | 128 | for i in range(n): 129 | optimizer.zero_grad() 130 | opt_img_output = self.model(opt_img) 131 | output = self.activations['value'] 132 | 133 | euc_loss = 1e-1 * self.euclidian_loss(input_image_layer_output.detach(), output) 134 | 135 | reg_alpha = alpha_reg_lambda * self.alpha_norm(opt_img, alpha_reg_alpha) 136 | reg_total_variation = tv_reg_lambda * self.total_variation_norm(opt_img, tv_reg_beta) 137 | 138 | loss = euc_loss + reg_alpha + reg_total_variation 139 | 140 | loss.backward() 141 | optimizer.step() 142 | 143 | # save image 144 | if i % 1 == 0: 145 | print('Iteration:', str(i), 'Loss:', loss.data.numpy()) 146 | recreated_img = denormalize(opt_img) 147 | img_path = os.path.join(save_path, 'Inverted_Iteration_' + str(i) + '.jpg') 148 | save_image(recreated_img, img_path) 149 | 150 | if i % 40 == 0: 151 | for param_group in optimizer.param_groups: 152 | 
param_group['lr'] *= 0.1 153 | -------------------------------------------------------------------------------- /xdeep/xlocal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/gradient/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/backprop/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/gradient/backprop/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/backprop/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseProp(object): 6 | """ 7 | Base class for backpropagation. 8 | """ 9 | 10 | def __init__(self, model): 11 | """Init 12 | 13 | # Arguments: 14 | model: torchvision.models. A pretrained model. 15 | handle: list. Handle list that register a hook function. 16 | relu_outputs: list. Forward output after relu. 17 | 18 | """ 19 | self.model = model 20 | self.handle = [] 21 | self.relu_outputs = [] 22 | 23 | def _register_conv_hook(self): 24 | 25 | """ 26 | Register hook function to save gradient w.r.t input image. 
27 | """ 28 | 29 | def _record_gradients(module, grad_in, grad_out): 30 | self.gradients = grad_in[0] 31 | 32 | for _, module in self.model.named_modules(): 33 | if isinstance(module, nn.modules.conv.Conv2d) and module.in_channels == 3: 34 | backward_handle = module.register_backward_hook(_record_gradients) 35 | self.handle.append(backward_handle) 36 | 37 | def _register_relu_hooks(self): 38 | 39 | """ 40 | Register hook function to save forward and backward relu result. 41 | """ 42 | 43 | # Save forward propagation output of the ReLU layer 44 | def _record_output(module, input_, output): 45 | self.relu_outputs.append(output) 46 | 47 | def _clip_gradients(module, grad_in, grad_out): 48 | # keep positive forward propagation output 49 | relu_output = self.relu_outputs.pop() 50 | relu_output[relu_output > 0] = 1 51 | 52 | # keep positive backward propagation gradient 53 | positive_grad_out = torch.clamp(grad_out[0], min=0.0) 54 | 55 | # generate modified guided gradient 56 | modified_grad_out = positive_grad_out * relu_output 57 | 58 | return (modified_grad_out, ) 59 | 60 | for _, module in self.model.named_modules(): 61 | if isinstance(module, nn.ReLU): 62 | forward_handle = module.register_forward_hook(_record_output) 63 | backward_handle = module.register_backward_hook(_clip_gradients) 64 | self.handle.append(forward_handle) 65 | self.handle.append(backward_handle) 66 | -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/backprop/guided_backprop.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | 3 | 4 | class Backprop(BaseProp): 5 | 6 | """ Generates vanilla or guided backprop gradients of a target class output w.r.t. an input image. 7 | 8 | # Arguments: 9 | model: torchvision.models. A pretrained model. 10 | guided: bool. If True, perform guided backpropagation. Defaults to False. 11 | 12 | # Return: 13 | Backprop Class. 
14 | """ 15 | 16 | def __init__(self, model, guided=False): 17 | super().__init__(model) 18 | self.model.eval() 19 | self.guided = guided 20 | self.gradients = None 21 | self._register_conv_hook() 22 | 23 | def calculate_gradients(self, 24 | input_, 25 | target_class=None, 26 | take_max=False, 27 | use_gpu=False): 28 | 29 | """ Calculate gradient. 30 | 31 | # Arguments 32 | input_: torch.Tensor. Preprocessed image with shape (1, C, H, W). 33 | target_class: int. Index of target class. Default to None and use the prediction result as target class. 34 | take_max: bool. Take the maximum across colour channels. Defaults to False. 35 | use_gpu. bool. Use GPU or not. Defaults to False. 36 | 37 | # Return: 38 | Gradient (torch.Tensor) with shape (C, H, W). If take max is True, with shape (1, H, W). 39 | """ 40 | 41 | if self.guided: 42 | self._register_relu_hooks() 43 | 44 | if torch.cuda.is_available() and use_gpu: 45 | self.model = self.model.to('cuda') 46 | input_ = input_.to('cuda') 47 | 48 | # Create a empty tensor to save gradients 49 | self.gradients = torch.zeros(input_.shape) 50 | 51 | output = self.model(input_) 52 | 53 | self.model.zero_grad() 54 | 55 | if output.shape == torch.Size([1]): 56 | target = None 57 | else: 58 | pred_class = output.argmax().item() 59 | 60 | # Create a Tensor with zero elements, set the element at pred class index to be 1 61 | target = torch.zeros(output.shape) 62 | 63 | # If target class is None, calculate gradient of predicted class. 64 | if target_class is None: 65 | target[0][pred_class] = 1 66 | else: 67 | target[0][target_class] = 1 68 | 69 | if torch.cuda.is_available() and use_gpu: 70 | target = target.to('cuda') 71 | 72 | # Calculate gradients w.r.t. 
input image 73 | output.backward(gradient=target) 74 | 75 | gradients = self.gradients.detach().cpu()[0] 76 | 77 | if take_max: 78 | gradients = gradients.max(dim=0, keepdim=True)[0] 79 | 80 | for module in self.handle: 81 | module.remove() 82 | 83 | return gradients 84 | -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/backprop/integrate_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def generate_integrated_grad(explainer, input_, target_class=None, n=25): 6 | 7 | """ Generates integrate gradients of given explainer. You can use this with both vanilla 8 | and guided backprop 9 | 10 | # Arguments: 11 | explainer: class. Backprop method. 12 | input_: torch.Tensor. Preprocessed image with shape (N, C, H, W). 13 | target_class: int. Index of target class. Default to None. 14 | n: int. Integrate steps. Default to 10. 15 | 16 | # Return: 17 | Integrated gradient (torch.Tensor) with shape (C, H, W). 18 | """ 19 | 20 | # Generate xbar images 21 | step_list = np.arange(n + 1) / n 22 | 23 | # Remove zero 24 | step_list = step_list[1:] 25 | 26 | # Create a empty tensor 27 | integrated_grads = torch.zeros(input_.size()).squeeze() 28 | 29 | for step in step_list: 30 | single_integrated_grad = explainer.calculate_gradients(input_*step, target_class) 31 | integrated_grads = integrated_grads + single_integrated_grad / n 32 | 33 | return integrated_grads 34 | -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/backprop/smooth_grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_smooth_grad(explainer, input_, target_class=None, n=50, mean=0, sigma_multiplier=4): 5 | 6 | """ Generates smooth gradients of given explainer. 7 | 8 | # Arguments 9 | explainer: class. Backprop method. 10 | input_: torch.Tensor. 
Preprocessed image with shape (1, C, H, W). 11 | target_class: int. Index of target class. Defaults to None. 12 | n: int. Amount of noisy images used to smooth gradient. Default to 10. 13 | mean: int. Mean value of normal distribution when generating noise. Default to 0. 14 | sigma_multiplier: int. Sigma multiplier when calculating std of noise. Default to 4. 15 | 16 | # Return: 17 | Smooth gradient (torch.Tensor) with shape (C, H, W). 18 | """ 19 | 20 | # Generate an empty image with shape (C, H, W) to save smooth gradient 21 | smooth_grad = torch.zeros(input_.size()).squeeze() 22 | 23 | sigma = sigma_multiplier / (torch.max(input_)-torch.min(input_)).item() 24 | 25 | x = 0 26 | while x < n: 27 | # Generate noise 28 | noise = input_.new(input_.size()).normal_(mean, sigma**2) 29 | 30 | # Add noise to the image 31 | noisy_img = input_ + noise 32 | 33 | # calculate gradient for noisy image 34 | grads = explainer.calculate_gradients(noisy_img, target_class) 35 | 36 | # accumulate gradient 37 | smooth_grad = smooth_grad + grads 38 | 39 | x += 1 40 | 41 | # average 42 | smooth_grad = smooth_grad / n 43 | 44 | return smooth_grad 45 | -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/cam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/gradient/cam/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/cam/basecam.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class BaseCAM(object): 4 | 5 | """ 6 | Base class for Class Activation Mapping. 7 | """ 8 | 9 | def __init__(self, model_dict): 10 | """Init 11 | 12 | # Arguments 13 | model_dict: dict. A dict with format model_dict = dict(arch=self.model, layer_name=target_layer_name). 
14 | 15 | """ 16 | 17 | layer_name = model_dict['layer_name'] 18 | 19 | self.model_arch = model_dict['arch'] 20 | self.model_arch.eval() 21 | self.gradients = dict() 22 | self.activations = dict() 23 | 24 | # save gradient 25 | def backward_hook(module, grad_input, grad_output): 26 | self.gradients['value'] = grad_output[0] 27 | return None 28 | 29 | # save activation map 30 | def forward_hook(module, input, output): 31 | self.activations['value'] = output 32 | return None 33 | 34 | target_layer = self.find_layer(self.model_arch, layer_name) 35 | target_layer.register_forward_hook(forward_hook) 36 | target_layer.register_backward_hook(backward_hook) 37 | 38 | def find_layer(self, arch, target_layer_name): 39 | 40 | if target_layer_name is None: 41 | if 'resnet' in str(type(arch)): 42 | target_layer_name = 'layer4' 43 | elif 'alexnet' in str(type(arch)) or 'vgg' in str(type(arch)) or 'squeezenet' in str(type(arch)) or 'densenet' in str(type(arch)): 44 | target_layer_name = 'features' 45 | else: 46 | raise Exception('Invalid layer name! 
Please specify layer name.', target_layer_name) 47 | 48 | hierarchy = target_layer_name.split('_') 49 | 50 | if hierarchy[0] not in arch._modules.keys(): 51 | raise Exception('Invalid layer name!', target_layer_name) 52 | 53 | target_layer = arch._modules[hierarchy[0]] 54 | 55 | if len(hierarchy) >= 2: 56 | if hierarchy[1] not in target_layer._modules.keys(): 57 | raise Exception('Invalid layer name!', target_layer_name) 58 | target_layer = target_layer._modules[hierarchy[1]] 59 | 60 | if len(hierarchy) >= 3: 61 | if hierarchy[2] not in target_layer._modules.keys(): 62 | raise Exception('Invalid layer name!', target_layer_name) 63 | target_layer = target_layer._modules[hierarchy[2]] 64 | 65 | if len(hierarchy) >= 4: 66 | if hierarchy[3] not in target_layer._modules.keys(): 67 | raise Exception('Invalid layer name!', target_layer_name) 68 | target_layer = target_layer._modules[hierarchy[3]] 69 | 70 | return target_layer 71 | 72 | def forward(self, input_, class_idx=None, retain_graph=False): 73 | return None 74 | 75 | def __call__(self, input_, class_idx=None, retain_graph=False): 76 | return self.forward(input_, class_idx, retain_graph) 77 | -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/cam/gradcam.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from .basecam import BaseCAM 3 | 4 | 5 | class GradCAM(BaseCAM): 6 | """ 7 | GradCAM, inherit from BaseCAM 8 | """ 9 | 10 | def __init__(self, model_dict): 11 | super().__init__(model_dict) 12 | 13 | def forward(self, input_, class_idx=None, retain_graph=False): 14 | """Generates GradCAM result. 15 | 16 | # Arguments 17 | input_: torch.Tensor. Preprocessed image with shape (1, C, H, W). 18 | class_idx: int. Index of target class. Defaults to be index of predicted class. 19 | 20 | # Return 21 | Result of GradCAM (torch.Tensor) with shape (1, H, W). 
22 | """ 23 | 24 | b, c, h, w = input_.size() 25 | logit = self.model_arch(input_) 26 | 27 | if class_idx is None: 28 | score = logit[:, logit.max(1)[-1]].squeeze() 29 | else: 30 | score = logit[:, class_idx].squeeze() 31 | 32 | self.model_arch.zero_grad() 33 | score.backward(retain_graph=retain_graph) 34 | 35 | gradients = self.gradients['value'] 36 | activations = self.activations['value'] 37 | b, k, u, v = gradients.size() 38 | 39 | alpha = gradients.view(b, k, -1).mean(2) 40 | weights = alpha.view(b, k, 1, 1) 41 | 42 | saliency_map = (weights * activations).sum(1, keepdim=True) 43 | saliency_map = F.relu(saliency_map) 44 | saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False) 45 | saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max() 46 | saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data 47 | 48 | return saliency_map -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/cam/gradcampp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .gradcam import GradCAM 4 | 5 | 6 | class GradCAMpp(GradCAM): 7 | """ 8 | GradCAM++, inherit from BaseCAM 9 | """ 10 | 11 | def __init__(self, model_dict): 12 | super(GradCAMpp, self).__init__(model_dict) 13 | 14 | def forward(self, input_image, class_idx=None, retain_graph=False): 15 | 16 | """Generates GradCAM++ result. 17 | 18 | # Arguments 19 | input_image: torch.Tensor. Preprocessed image with shape (1, C, H, W). 20 | class_idx: int. Index of target class. Defaults to be index of predicted class. 21 | 22 | # Return 23 | Result of GradCAM++ (torch.Tensor) with shape (1, H, W). 
24 | """ 25 | 26 | b, c, h, w = input_image.size() 27 | 28 | logit = self.model_arch(input_image) 29 | if class_idx is None: 30 | score = logit[:, logit.max(1)[-1]].squeeze() 31 | else: 32 | score = logit[:, class_idx].squeeze() 33 | 34 | self.model_arch.zero_grad() 35 | score.backward(retain_graph=retain_graph) 36 | gradients = self.gradients['value'] 37 | activations = self.activations['value'] 38 | b, k, u, v = gradients.size() 39 | 40 | alpha_num = gradients.pow(2) 41 | 42 | global_sum = activations.view(b, k, u * v).sum(-1, keepdim=True).view(b, k, 1, 1) 43 | alpha_denom = gradients.pow(2).mul(2) + global_sum.mul(gradients.pow(3)) 44 | 45 | alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom)) 46 | 47 | alpha = alpha_num.div(alpha_denom + 1e-7) 48 | positive_gradients = F.relu(score.exp() * gradients) 49 | weights = (alpha * positive_gradients).view(b, k, u * v).sum(-1).view(b, k, 1, 1) 50 | 51 | saliency_map = (weights * activations).sum(1, keepdim=True) 52 | saliency_map = F.relu(saliency_map) 53 | saliency_map = F.interpolate(saliency_map, size=(224, 224), mode='bilinear', align_corners=False) 54 | saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max() 55 | saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data 56 | 57 | return saliency_map -------------------------------------------------------------------------------- /xdeep/xlocal/gradient/cam/scorecam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .basecam import * 4 | 5 | 6 | class ScoreCAM(BaseCAM): 7 | """ 8 | ScoreCAM, inherit from BaseCAM 9 | """ 10 | 11 | def __init__(self, model_dict): 12 | super().__init__(model_dict) 13 | 14 | def forward(self, input, class_idx=None, retain_graph=False): 15 | 16 | """Generates ScoreCAM result. 17 | 18 | # Arguments 19 | input_: torch.Tensor. 
Preprocessed image with shape (1, C, H, W). 20 | class_idx: int. Index of target class. Defaults to be index of predicted class. 21 | 22 | # Return 23 | Result of GradCAM (torch.Tensor) with shape (1, H, W). 24 | """ 25 | 26 | b, c, h, w = input.size() 27 | 28 | logit = self.model_arch(input) 29 | 30 | if class_idx is None: 31 | predicted_class = logit.max(1)[-1] 32 | score = logit[:, predicted_class].squeeze() 33 | else: 34 | predicted_class = torch.LongTensor([class_idx]) 35 | score = logit[:, predicted_class].squeeze() 36 | 37 | self.model_arch.zero_grad() 38 | score.backward(retain_graph=retain_graph) 39 | gradients = self.gradients['value'] 40 | activations = self.activations['value'] 41 | b, k, u, v = gradients.size() 42 | score_saliency_map = torch.zeros((1, 1, h, w)) 43 | 44 | with torch.no_grad(): 45 | for i in range(k): 46 | saliency_map = torch.unsqueeze(activations[:, i, :, :], 1) 47 | saliency_map = F.relu(saliency_map) 48 | saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False) 49 | 50 | if (saliency_map.max() == saliency_map.min()).numpy().tolist() == 1: 51 | continue 52 | 53 | norm_saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min()) 54 | output = self.model_arch(input * norm_saliency_map) 55 | output = (output - output.mean()) / output.std() 56 | score = output[0][predicted_class] 57 | score = F.softmax(score) 58 | score_saliency_map += score * saliency_map 59 | 60 | score_saliency_map = F.relu(score_saliency_map) 61 | score_saliency_map_min, score_saliency_map_max = score_saliency_map.min(), score_saliency_map.max() 62 | 63 | if score_saliency_map_min == score_saliency_map_max: 64 | return None 65 | 66 | score_saliency_map = (score_saliency_map - score_saliency_map_min).div(score_saliency_map_max - score_saliency_map_min).data 67 | return score_saliency_map 68 | -------------------------------------------------------------------------------- 
/xdeep/xlocal/gradient/explainers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Aug 02 16:21:22 2019 3 | @author: Haofan Wang - github.com/haofanwang 4 | """ 5 | 6 | 7 | from xdeep.utils import * 8 | 9 | from xdeep.xlocal.gradient.backprop.guided_backprop import Backprop 10 | from xdeep.xlocal.gradient.backprop.integrate_grad import generate_integrated_grad 11 | from xdeep.xlocal.gradient.backprop.smooth_grad import generate_smooth_grad 12 | 13 | from xdeep.xlocal.gradient.cam.gradcam import GradCAM 14 | from xdeep.xlocal.gradient.cam.gradcampp import GradCAMpp 15 | from xdeep.xlocal.gradient.cam.scorecam import ScoreCAM 16 | 17 | 18 | class ImageInterpreter(object): 19 | 20 | """ Class for image explanation 21 | 22 | ImageInterpreter provides unified interface to call different visualization methods. 23 | The user can use it by several lines. It includes '__init__()' and 'explain()'. The 24 | users can specify the name of visualization method and target layer. If params are 25 | not specified, default params will be used. 26 | """ 27 | 28 | def __init__(self, model): 29 | """Init 30 | 31 | # Arguments 32 | model: torchvision.models. A pretrained model. 33 | result: torch.Tensor. Generated saliency map. 
34 | 35 | """ 36 | self.model = model 37 | self.result = None 38 | 39 | def vanilla_backprop(self, model): 40 | return Backprop(model, guided=False) 41 | 42 | def guided_backprop(self, model): 43 | return Backprop(model, guided=True) 44 | 45 | def smooth_grad(self, model, input_, guided=False): 46 | explainer = Backprop(model, guided=guided) 47 | return generate_smooth_grad(explainer, input_) 48 | 49 | def smooth_guided_grad(self, model, input_, guided=True): 50 | explainer = Backprop(model, guided=guided) 51 | return generate_smooth_grad(explainer, input_) 52 | 53 | def integrate_grad(self, model, input_, guided=False): 54 | explainer = Backprop(model, guided=guided) 55 | return generate_integrated_grad(explainer, input_) 56 | 57 | def integrate_guided_grad(self, model, input_, guided=True): 58 | explainer = Backprop(model, guided=guided) 59 | return generate_integrated_grad(explainer, input_) 60 | 61 | def gradcam(self, model_dict): 62 | return GradCAM(model_dict) 63 | 64 | def gradcampp(self, model_dict): 65 | return GradCAMpp(model_dict) 66 | 67 | def scorecam(self, model_dict): 68 | return ScoreCAM(model_dict) 69 | 70 | def explain(self, image, method_name, viz=True, target_layer_name=None, target_class=None, save_path=None): 71 | 72 | """Function to call different local visualization methods. 73 | 74 | # Arguments 75 | image: str or PIL.Image. The input image to ImageInterpreter. User can directly provide the path of input 76 | image, or provide PIL.Image foramt image. For example, image can be './test.jpg' or 77 | Image.open('./test.jpg').convert('RGB'). 78 | method_name: str. The name of interpreter method. Currently support for 'vanilla_backprop', 'guided_backprop', 79 | 'smooth_grad', 'smooth_guided_grad', 'integrate_grad', 'integrate_guided_grad', 'gradcam', 80 | 'gradcampp', 'scorecam'. 81 | viz: bool. Visualize or not. Defaults to True. 82 | target_layer_name: str. The layer to hook gradients and activation map. 
Defaults to the name of the latest 83 | activation map. User can also provide their target layer like 'features_29' or 'layer4' 84 | with respect to different network architectures. 85 | target_class: int. The index of target class. Default to be the index of predicted class. 86 | save_path: str. Path to save the saliency map. Default to be None. 87 | """ 88 | 89 | method_switch = {"vanilla_backprop": self.vanilla_backprop, 90 | "guided_backprop": self.guided_backprop, 91 | "smooth_grad": self.smooth_grad, 92 | "smooth_guided_grad": self.smooth_guided_grad, 93 | "integrate_grad": self.integrate_grad, 94 | "integrate_guided_grad": self.integrate_guided_grad, 95 | "gradcam": self.gradcam, 96 | "gradcampp": self.gradcampp, 97 | "scorecam": self.scorecam} 98 | 99 | if type(image) is str: 100 | image = load_image(image) 101 | 102 | norm_image = apply_transforms(image) 103 | 104 | if method_name not in method_switch.keys(): 105 | raise Exception("Non-support method!", method_name) 106 | 107 | if 'backprop' in method_name: 108 | explainer = method_switch[method_name](self.model) 109 | self.result = explainer.calculate_gradients(norm_image, target_class=target_class) 110 | elif 'cam' in method_name: 111 | model_dict = dict(arch=self.model, layer_name=target_layer_name, input_size=(224, 224)) 112 | explainer = method_switch[method_name](model_dict) 113 | self.result = explainer(norm_image, class_idx=target_class) 114 | else: 115 | self.result = method_switch[method_name](self.model, norm_image) 116 | 117 | if viz is True: 118 | visualize(norm_image, self.result, save_path=save_path) 119 | 120 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/.ipynb_checkpoints/Image_Tutorial-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import 
warnings\n", 10 | "warnings.filterwarnings('ignore')\n", 11 | "import tensorflow as tf\n", 12 | "slim = tf.contrib.slim\n", 13 | "import sys\n", 14 | "sys.path.append('/Users/zhangzijian/models/research/slim')\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from nets import inception\n", 17 | "from preprocessing import inception_preprocessing\n", 18 | "from datasets import imagenet\n", 19 | "import time" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 4, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import xdeep_image" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [] 37 | } 38 | ], 39 | "metadata": { 40 | "kernelspec": { 41 | "display_name": "Python 3", 42 | "language": "python", 43 | "name": "python3" 44 | }, 45 | "language_info": { 46 | "codemirror_mode": { 47 | "name": "ipython", 48 | "version": 3 49 | }, 50 | "file_extension": ".py", 51 | "mimetype": "text/x-python", 52 | "name": "python", 53 | "nbconvert_exporter": "python", 54 | "pygments_lexer": "ipython3", 55 | "version": "3.6.6" 56 | } 57 | }, 58 | "nbformat": 4, 59 | "nbformat_minor": 2 60 | } 61 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/cle/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/cle/__init__.py -------------------------------------------------------------------------------- 
/xdeep/xlocal/perturbation/exceptions.py: -------------------------------------------------------------------------------- 1 | class XDeepError(Exception): 2 | """Raise for XDeep errors""" 3 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/explainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base explainer for xdeep. 3 | """ 4 | from .exceptions import XDeepError 5 | 6 | 7 | class Explainer(object): 8 | """A base explainer.""" 9 | 10 | def __init__(self, predict_proba, class_names): 11 | """Init function. 12 | 13 | # Arguments 14 | predict_proba: Function. Classifier prediction probability function. 15 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 16 | """ 17 | self.labels = list() 18 | self.explainer = None 19 | self.explanation = None 20 | self.instance = None 21 | self.original_pred = None 22 | self.class_names = class_names 23 | self.predict_proba = predict_proba 24 | 25 | def set_parameters(self, **kwargs): 26 | """Parameter setter. 27 | 28 | # Arguments 29 | **kwargs: Other parameters that depends on which method you use. For more detail, please check xdeep documentation. 30 | """ 31 | pass 32 | 33 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 34 | """Explain function. 35 | 36 | # Arguments 37 | instance: One instance to be explained. 38 | top_labels: Integer. Number of labels you care about. 39 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 40 | **kwargs: Other parameters that depends on which method you use. For more detail, please check xdeep documentation. 
41 | """ 42 | self.instance = instance 43 | try: 44 | proba = self.predict_proba([instance])[0] 45 | except: 46 | proba = self.predict_proba(instance)[0] 47 | self.original_pred = proba 48 | if top_labels is not None: 49 | self.labels = proba.argsort()[-top_labels:] 50 | else: 51 | self.labels = labels 52 | 53 | def show_explanation(self, **kwargs): 54 | """Visualization of explanation. 55 | 56 | # Arguments 57 | **kwargs: Other parameters that depends on which method you use. For more detail, please check xdeep documentation. 58 | """ 59 | if not hasattr(self, 'explanation'): 60 | raise XDeepError("Please call explain function first.") 61 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/image_explainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/image_explainers/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/image_explainers/anchor_image.py: -------------------------------------------------------------------------------- 1 | # The implementation of anchor refers the original authors' codes in GitHub https://github.com/marcotcr/anchor. 2 | # The Copyright of algorithm anchor is reserved for (c) 2018, Marco Tulio Correia Ribeiro. 3 | import copy 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | from skimage.segmentation import mark_boundaries 8 | from anchor.anchor_image import AnchorImage 9 | from ..explainer import Explainer 10 | 11 | 12 | class XDeepAnchorImageExplainer(Explainer): 13 | 14 | def __init__(self, predict_proba, class_names): 15 | """Init function. 16 | 17 | # Arguments 18 | predict_proba: Function. Classifier prediction probability function. 19 | class_names: List. 
A list of class names, ordered according to whatever the classifier is using. 20 | """ 21 | Explainer.__init__(self, predict_proba, class_names) 22 | # Initialize explainer 23 | self.set_parameters() 24 | 25 | def set_parameters(self, **kwargs): 26 | """Parameter setter for anchor_text. 27 | 28 | # Arguments 29 | **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor 30 | """ 31 | self.explainer = AnchorImage(**kwargs) 32 | 33 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 34 | """Generate explanation for a prediction using certain method. 35 | Anchor does not use top_labels and labels. 36 | 37 | # Arguments 38 | instance: Array. An image to be explained. 39 | **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor. 40 | """ 41 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 42 | try: 43 | self.labels = self.predict_proba([instance]).argsort()[0][-1:] 44 | except: 45 | self.labels = self.predict_proba(instance).argsort()[0][-1:] 46 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, **kwargs) 47 | 48 | def __show_image_no_axis(self, image, boundaries=None, save=None): 49 | """Show image with no axis 50 | 51 | # Arguments 52 | image: Array. The image you want to show. 53 | boundaries: Boundaries 54 | save: Str. Save path. 55 | """ 56 | fig = plt.figure() 57 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 58 | ax.set_axis_off() 59 | fig.add_axes(ax) 60 | if boundaries is not None: 61 | ax.imshow(mark_boundaries(image, boundaries)) 62 | else: 63 | ax.imshow(image) 64 | if save is not None: 65 | plt.savefig(save) 66 | plt.show() 67 | 68 | def show_explanation(self, deprocess=None): 69 | """Visualization of explanation of anchor_image. 70 | 71 | # Arguments 72 | deprocess: Function. A function to deprocess the image. 
73 | """ 74 | Explainer.show_explanation(self) 75 | exp = self.explanation 76 | segments = exp[0] 77 | exp = exp[1] 78 | explainer = self.explainer 79 | predict_fn = self.predict_proba 80 | image = self.instance 81 | 82 | print() 83 | print("Anchor Explanation") 84 | print() 85 | if deprocess is not None: 86 | image = deprocess(copy.deepcopy(image)) 87 | temp = copy.deepcopy(image) 88 | temp_img = copy.deepcopy(temp) 89 | temp[:] = np.average(temp) 90 | for x in exp: 91 | temp[segments == x[0]] = temp_img[segments == x[0]] 92 | print('Prediction ', predict_fn(np.expand_dims(image, 0))[0].argmax()) 93 | assert hasattr(exp, '__len__') 94 | if len(exp) == 0: 95 | print("Fail to find Anchor") 96 | else: 97 | print('Confidence ', exp[-1][2]) 98 | self.__show_image_no_axis(temp) 99 | if len(exp[-1][3]) != 0: 100 | print('Counter Examples:') 101 | for e in exp[-1][3]: 102 | data = e[:-1] 103 | temp = explainer.dummys[e[-1]].copy() 104 | for x in data.nonzero()[0]: 105 | temp[segments == x] = image[segments == x] 106 | self.__show_image_no_axis(temp) 107 | print('Prediction = ', predict_fn(np.expand_dims(temp, 0))[0].argmax()) 108 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/image_explainers/cle_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.colors import LinearSegmentedColormap 4 | 5 | from ..cle.cle_image import CLEImageExplainer 6 | from ..explainer import Explainer 7 | from ..exceptions import XDeepError 8 | 9 | plt.rcParams.update({'font.size': 12}) 10 | 11 | 12 | class XDeepCLEImageExplainer(Explainer): 13 | 14 | def __init__(self, predict_proba, class_names): 15 | """Init function. 16 | 17 | # Arguments 18 | predict_proba: Function. Classifier prediction probability function. 19 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 
20 | """ 21 | Explainer.__init__(self, predict_proba, class_names) 22 | # Initialize explainer 23 | self.set_parameters() 24 | 25 | def set_parameters(self, **kwargs): 26 | """Parameter setter for cle_image. 27 | 28 | # Arguments 29 | **kwarg: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 30 | """ 31 | self.explainer = CLEImageExplainer(**kwargs) 32 | 33 | def explain(self, instance, top_labels=None, labels=(1,), care_segments=None, spans=(2,), include_original_feature=True, **kwargs): 34 | """Generate explanation for a prediction using certain method. 35 | 36 | # Arguments 37 | instance: Array. An image to be explained. 38 | top_labels: Integer. Number of labels you care about. 39 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 40 | **kwarg: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 41 | """ 42 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 43 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, labels=self.labels, top_labels=top_labels, 44 | care_segments=care_segments, spans=spans, 45 | include_original_feature=include_original_feature, **kwargs) 46 | 47 | def show_explanation(self, num_features=5, deprocess=None, positive_only=True): 48 | """Visualization of explanation of lime_image. 49 | 50 | # Arguments 51 | num_features: Integer. How many features you care about. 52 | deprocess: Function. A function to deprocess the image. 53 | positive_only: Boolean. Whether only show feature with positive weight. 
54 | """ 55 | if not isinstance(num_features, int) or num_features < 1: 56 | raise XDeepError("Number of features should be positive integer.") 57 | Explainer.show_explanation(self) 58 | exp = self.explanation 59 | labels = self.labels 60 | 61 | print() 62 | print("CLE Explanation") 63 | print() 64 | 65 | colors = list() 66 | for l in np.linspace(1, 0, 100): 67 | colors.append((245 / 255, 39 / 255, 87 / 255, l)) 68 | for l in np.linspace(0, 1, 100): 69 | colors.append((24 / 255, 196 / 255, 93 / 255, l)) 70 | colormap = LinearSegmentedColormap.from_list("CLE", colors) 71 | 72 | # Plot 73 | fig, axes = plt.subplots(nrows=len(labels), ncols=num_features+1, 74 | figsize=((num_features+1)*3,len(labels)*3)) 75 | 76 | assert hasattr(labels, '__len__') 77 | for n in range(len(labels)): 78 | # Inverse 79 | label = labels[-n-1] 80 | result = exp.intercept[label] 81 | local_exp = exp.local_exp[label] 82 | for item in local_exp: 83 | result += item[1] 84 | print("Explanation for label {}:".format(self.class_names[label])) 85 | print("Local Prediction: {:.3f}".format(result)) 86 | print("Original Prediction: {:.3f}".format(self.original_pred[label])) 87 | print() 88 | img, mask, fs, ws, ms = exp.get_image_and_mask( 89 | label, positive_only=positive_only, 90 | num_features=num_features, hide_rest=False 91 | ) 92 | if deprocess is not None: 93 | img = deprocess(img) 94 | 95 | mask = np.zeros_like(mask, dtype=np.float64) 96 | for index in range(len(ws)): 97 | mask[ms[index] == 1] += ws[index] 98 | 99 | max_val = max(np.max(ws), np.max(mask), abs(np.min(ws)), abs(np.min(mask))) 100 | min_val = -max_val 101 | 102 | if len(labels) == 1: 103 | axs = axes 104 | else: 105 | axs = axes[n] 106 | axs[0].axis('off') 107 | axs[0].set_title("{}".format(self.class_names[label])) 108 | axs[0].imshow(img, alpha=0.9) 109 | axs[0].imshow(mask, cmap=colormap, vmin=min_val, vmax=max_val) 110 | 111 | for idx in range(1, num_features+1): 112 | m = ms[idx-1] * ws[idx-1] 113 | title = "seg." 
114 | for item in fs[idx-1]: 115 | title += str(item)+"&" 116 | title = title[:-1] 117 | axs[idx].set_title(title) 118 | axs[idx].imshow(img, alpha=0.5) 119 | axs[idx].imshow(m, cmap=colormap, vmin=min_val, vmax=max_val) 120 | axs[idx].axis('off') 121 | plt.show() 122 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/image_explainers/lime_image.py: -------------------------------------------------------------------------------- 1 | # The implementation of LIME refers the original authors' codes in GitHub https://github.com/limetext/lime. 2 | # The Copyright of algorithm LIME is reserved for (c) 2016, Marco Tulio Correia Ribeiro. 3 | import matplotlib.pyplot as plt 4 | from skimage.segmentation import mark_boundaries 5 | 6 | from lime.lime_image import LimeImageExplainer 7 | 8 | from ..explainer import Explainer 9 | 10 | plt.rcParams.update({'font.size': 12}) 11 | 12 | 13 | class XDeepLimeImageExplainer(Explainer): 14 | 15 | def __init__(self, predict_proba, class_names): 16 | """Init function. 17 | 18 | # Arguments 19 | predict_proba: Function. Classifier prediction probability function. 20 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 21 | """ 22 | Explainer.__init__(self, predict_proba, class_names) 23 | # Initialize explainer 24 | self.set_parameters() 25 | 26 | def set_parameters(self, **kwargs): 27 | """Parameter setter for lime_text. 28 | 29 | # Arguments 30 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 31 | """ 32 | self.explainer = LimeImageExplainer(**kwargs) 33 | 34 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 35 | """Generate explanation for a prediction using certain method. 36 | 37 | # Arguments 38 | instance: Array. An image to be explained. 39 | top_labels: Integer. Number of labels you care about. 40 | labels: Tuple. 
Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 41 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 42 | """ 43 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 44 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, 45 | top_labels=top_labels, labels=self.labels, **kwargs) 46 | 47 | def show_explanation(self, deprocess=None, positive_only=True, num_features=5, hide_rest=False): 48 | """Visualization of explanation of lime_image. 49 | 50 | # Arguments 51 | deprocess: Function. A function to deprocess the image. 52 | positive_only: Boolean. Whether only show feature with positive weight. 53 | num_features: Integer. Numbers of feature you care about. 54 | hide_rest: Boolean. Whether to hide rest of the image. 55 | """ 56 | Explainer.show_explanation(self) 57 | exp = self.explanation 58 | labels = self.labels 59 | 60 | print() 61 | print("LIME Explanation") 62 | print() 63 | 64 | fig, axs = plt.subplots(nrows=1, ncols=len(labels), figsize=(len(labels)*3, 3)) 65 | 66 | assert hasattr(labels, '__len__') 67 | for idx in range(len(labels)): 68 | # Inverse 69 | label = labels[-idx-1] 70 | result = exp.intercept[label] 71 | local_exp = exp.local_exp[label] 72 | for item in local_exp: 73 | result += item[1] 74 | print("Explanation for label {}:".format(self.class_names[label])) 75 | print("Local Prediction: {:.3f}".format(result)) 76 | print("Original Prediction: {:.3f}".format(self.original_pred[label])) 77 | print() 78 | temp, mask = exp.get_image_and_mask( 79 | label, positive_only=positive_only, 80 | num_features=num_features, hide_rest=hide_rest 81 | ) 82 | if deprocess is not None: 83 | temp = deprocess(temp) 84 | if len(labels) == 1: 85 | axs.imshow(mark_boundaries(temp, mask), alpha=0.9) 86 | axs.axis('off') 87 | axs.set_title("{}".format(self.class_names[label])) 88 | else: 89 | 
axs[idx].imshow(mark_boundaries(temp, mask), alpha=0.9) 90 | axs[idx].axis('off') 91 | axs[idx].set_title("{}".format(self.class_names[label])) 92 | plt.show() -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/image_explainers/shap_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from PIL import Image 4 | from matplotlib.colors import LinearSegmentedColormap 5 | 6 | import shap 7 | 8 | from ..explainer import Explainer 9 | 10 | 11 | class XDeepShapImageExplainer(Explainer): 12 | 13 | def __init__(self, predict_proba, class_names, n_segment, segment): 14 | """Init function. 15 | 16 | # Arguments 17 | predict_proba: Function. Classifier prediction probability function. 18 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 19 | n_segment: Integer. Number of segments in the image. 20 | segment: Array. An array with 2 dimensions, segments_slic of the image. 21 | """ 22 | Explainer.__init__(self, predict_proba, class_names) 23 | self.n_segment = n_segment 24 | self.segment = segment 25 | 26 | @staticmethod 27 | def _mask_image(zs, segmentation, image, background=None): 28 | """Mask image function for shap.""" 29 | if background is None: 30 | background = image.mean((0, 1)) 31 | out = np.zeros((zs.shape[0], image.shape[0], image.shape[1], image.shape[2])) 32 | for i in range(zs.shape[0]): 33 | out[i, :, :, :] = image 34 | for j in range(zs.shape[1]): 35 | if zs[i, j] == 0: 36 | out[i][segmentation == j, :] = background 37 | return out 38 | 39 | def set_parameters(self, n_segment, segment): 40 | """Parameter setter for shap_image. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 41 | 42 | # Arguments 43 | n_segment: Integer. number of segments in the image. 44 | segment: Array. 
An array with 2 dimensions, segments_slic of the image. 45 | """ 46 | self.n_segment = n_segment 47 | self.segment = segment 48 | 49 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 50 | """Generate a KernelSHAP explanation in which each feature is one image segment. 51 | 52 | # Arguments 53 | instance: Array. An image to be explained. 54 | top_labels: Integer. Number of labels you care about. 55 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 56 | **kwargs: Forwarded to both shap.KernelExplainer(...) and its shap_values(...) call. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 57 | """ 58 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 59 | 60 | def f(z): # rows of z switch segments on/off; segments where z == 0 are filled with 255 before prediction 61 | masked = self._mask_image(z, self.segment, self.instance, 255) 62 | return self.predict_proba(masked) 63 | 64 | self.explainer = shap.KernelExplainer(f, np.zeros((1, self.n_segment)), **kwargs) 65 | self.explanation = self.explainer.shap_values(np.ones((1, self.n_segment)), **kwargs) 66 | 67 | def __fill_segmentation(self, values, segmentation): 68 | """Spread per-segment values over the pixels of a segmentation map. 69 | 70 | # Arguments 71 | values: Array. One weight per segment id. 72 | segmentation: Array. An array with 2 dimensions, segments_slic of the image. 73 | 74 | # Return 75 | An array of segmentation.shape where pixels of segment i hold values[i]; pixels of any other segment id stay 0. 76 | """ 77 | out = np.zeros(segmentation.shape) 78 | for i in range(len(values)): 79 | out[segmentation == i] = values[i] 80 | return out 81 | 82 | def show_explanation(self, deprocess=None): 83 | """Plot the image followed by one shap heat map per label. 84 | 85 | # Arguments 86 | deprocess: Function. A function to deprocess the image. Its output is multiplied by 255 and cast to uint8 for display, so it is presumably expected to lie in [0, 1].
87 | """ 88 | Explainer.show_explanation(self) 89 | shap_values = self.explanation 90 | labels = self.labels 91 | segments_slic = self.segment 92 | img = self.instance 93 | 94 | if deprocess is not None: 95 | img = deprocess(img) 96 | img = Image.fromarray(np.uint8(img*255)) 97 | colors = [] # diverging colormap: colour (245, 39, 87) with alpha 1 -> 0 for negative values, then (24, 196, 93) with alpha 0 -> 1 for positive values 98 | for l in np.linspace(1, 0, 100): 99 | colors.append((245 / 255, 39 / 255, 87 / 255, l)) 100 | for l in np.linspace(0, 1, 100): 101 | colors.append((24 / 255, 196 / 255, 93 / 255, l)) 102 | colormap = LinearSegmentedColormap.from_list("shap", colors) 103 | # Plot our explanations 104 | fig, axes = plt.subplots(nrows=1, ncols=len(labels)+1, figsize=(12, 4)) 105 | inds = labels 106 | axes[0].imshow(img) 107 | axes[0].axis('off') 108 | max_val = np.max([np.max(np.abs(shap_values[i][:, :-1])) for i in range(len(shap_values))]) # symmetric colour limit shared by all labels; NOTE(review): [:, :-1] leaves the last segment out of this maximum - confirm intended 109 | assert hasattr(labels, '__len__') 110 | for i in range(len(labels)): 111 | m = self.__fill_segmentation(shap_values[inds[i]][0], segments_slic) 112 | axes[i + 1].set_title(self.class_names[labels[i]]) 113 | axes[i + 1].imshow(img, alpha=0.5) 114 | axes[i + 1].imshow(m, cmap=colormap, vmin=-max_val, vmax=max_val) 115 | axes[i + 1].axis('off') 116 | plt.show() 117 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/tabular_explainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/tabular_explainers/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/tabular_explainers/anchor_tabular.py: -------------------------------------------------------------------------------- 1 | # The implementation of anchor refers the original authors' codes in GitHub https://github.com/marcotcr/anchor.
2 | # The Copyright of algorithm anchor is reserved for (c) 2018, Marco Tulio Correia Ribeiro. 3 | import numpy as np 4 | 5 | from anchor.anchor_tabular import AnchorTabularExplainer 6 | 7 | from ..explainer import Explainer 8 | from ..exceptions import XDeepError 9 | 10 | 11 | class XDeepAnchorTabularExplainer(Explainer): 12 | 13 | def __init__(self, class_names, feature_names, data, categorical_names=None): 14 | """Init function. 15 | 16 | # Arguments 17 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 18 | feature_names: List. A list of names (strings) corresponding to the columns in the training data. 19 | data: Array. Full data including train data, validation_data and test data. 20 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x. 21 | """ 22 | Explainer.__init__(self, None, class_names) # predict_proba is supplied later via set_anchor_predict_proba 23 | self.feature_names = feature_names 24 | self.data = data 25 | self.categorical_names = categorical_names 26 | # Initialize explainer 27 | self.set_parameters() 28 | 29 | def set_parameters(self, **kwargs): 30 | """Parameter setter for anchor_tabular. 31 | 32 | # Arguments 33 | **kwargs: Optional overrides for class_names, feature_names, data and categorical_names; any other keys are ignored. For more detail, please check https://github.com/marcotcr/anchor. 34 | """ 35 | class_names = kwargs.pop("class_names", self.class_names) 36 | feature_names = kwargs.pop("feature_names", self.feature_names) 37 | data = kwargs.pop("data", self.data) 38 | categorical_names = kwargs.pop("categorical_names", self.categorical_names) 39 | self.explainer = AnchorTabularExplainer(class_names, feature_names, data=data, 40 | categorical_names=categorical_names) 41 | 42 | def get_anchor_encoder(self, train_data, train_labels, validation_data, validation_labels, discretizer='quartile'): 43 | """Get encoder for tabular data. 44 | 45 | To use 'anchor' to explain a tabular classifier,
you need this encoder to encode your data and to train another model. 46 | 47 | # Arguments 48 | train_data: Array. Train data. 49 | train_labels: Array. Train labels. 50 | validation_data: Array. Validation set. 51 | validation_labels: Array. Validation labels. 52 | discretizer: Str. Discretizer for data. Please choose between 'quartile' and 'decile'. 53 | 54 | # Return 55 | An encoder object (self.explainer.encoder after fit) which has a 'transform' method. 56 | """ 57 | self.explainer.fit(train_data, train_labels, validation_data, validation_labels, 58 | discretizer=discretizer) 59 | return self.explainer.encoder 60 | 61 | def set_anchor_predict_proba(self, predict_proba): 62 | """Predict function setter. 63 | 64 | # Arguments 65 | predict_proba: Function. A classifier prediction probability function. 66 | """ 67 | self.predict_proba = predict_proba 68 | 69 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 70 | """Generate an anchor explanation for one row of tabular data. 71 | Anchor does not use top_labels and labels; self.labels is set to the top predicted class instead. 72 | 73 | # Arguments 74 | instance: Array. One row of tabular data to be explained. 75 | **kwargs: Forwarded to AnchorTabularExplainer.explain_instance. For more detail, please check https://github.com/marcotcr/anchor. 76 | """ 77 | if self.predict_proba is None: 78 | raise XDeepError("Please call set_anchor_predict_proba to pass in your new predict function.") 79 | 80 | def predict_label(x): 81 | return np.argmax(self.predict_proba(x), axis=1) 82 | self.instance = instance 83 | try: 84 | self.labels = self.predict_proba(self.explainer.encoder.transform([instance])).argsort()[0][-1:] 85 | except: # NOTE(review): bare except also swallows unrelated errors; presumably meant to retry when instance is already 2-D - consider narrowing 86 | self.labels = self.predict_proba(self.explainer.encoder.transform(instance)).argsort()[0][-1:] 87 | self.explanation = self.explainer.explain_instance(instance, predict_label, **kwargs) 88 | 89 | def show_explanation(self, show_in_note_book=True, verbose=True): 90 | """Print the anchor_tabular explanation: optional examples, then prediction, anchor, precision and coverage. 91 | 92 | # Arguments 93 | show_in_note_book: Boolean.
Whether to show in a Jupyter notebook; currently has no effect (the branch is a no-op). 94 | verbose: Boolean. Whether to print examples and counter-examples. 95 | """ 96 | Explainer.show_explanation(self) 97 | exp = self.explanation 98 | 99 | print() 100 | print("Anchor Explanation") 101 | print() 102 | if verbose: 103 | print() 104 | print('Examples where anchor applies and model predicts same to instance:') 105 | print() 106 | print(exp.examples(only_same_prediction=True)) 107 | print() 108 | print('Examples where anchor applies and model predicts different with instance:') 109 | print() 110 | print(exp.examples(only_different_prediction=True)) 111 | 112 | label = self.labels[0] 113 | print('Prediction: {}'.format(label)) 114 | print('Anchor: %s' % (' AND '.join(exp.names()))) 115 | print('Precision: %.2f' % exp.precision()) 116 | print('Coverage: %.2f' % exp.coverage()) 117 | if show_in_note_book: 118 | pass # NOTE(review): notebook rendering is not implemented for tabular anchor 119 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/tabular_explainers/cle_tabular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from ..cle.cle_tabular import CLETabularExplainer 5 | from ..explainer import Explainer 6 | 7 | 8 | class XDeepCLETabularExplainer(Explainer): 9 | 10 | def __init__(self, predict_proba, class_names, feature_names, train, 11 | categorical_features=None, categorical_names=None): 12 | """Init function. 13 | 14 | # Arguments 15 | predict_proba: Function. A classifier prediction probability function. 16 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 17 | feature_names: List. A list of names (strings) corresponding to the columns in the training data. 18 | train: Array. Train data. 19 | categorical_features: List. A list of indices (ints) corresponding to the categorical columns. Everything else will be considered continuous. Values in these columns MUST be integers.
20 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x. 21 | """ 22 | Explainer.__init__(self, predict_proba, class_names) 23 | self.train = train 24 | self.feature_names = feature_names 25 | self.categorical_features = categorical_features 26 | self.categorical_names = categorical_names 27 | # Initialize explainer 28 | self.set_parameters() 29 | 30 | def set_parameters(self, **kwargs): 31 | """Parameter setter for cle_tabular. 32 | 33 | # Arguments 34 | **kwargs: Optional overrides for train, class_names, feature_names, categorical_features and categorical_names; remaining keys are forwarded to CLETabularExplainer. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 35 | """ 36 | train = kwargs.pop("train", self.train) 37 | class_names = kwargs.pop("class_names", self.class_names) 38 | feature_names = kwargs.pop("feature_names", self.feature_names) 39 | categorical_features = kwargs.pop("categorical_features", self.categorical_features) 40 | categorical_names = kwargs.pop("categorical_names", self.categorical_names) 41 | self.explainer = CLETabularExplainer(train, class_names=class_names, 42 | feature_names=feature_names, categorical_features=categorical_features, 43 | categorical_names=categorical_names, **kwargs) 44 | 45 | def explain(self, instance, top_labels=None, labels=(1,), care_cols=None, spans=(2,), include_original_feature=True, **kwargs): 46 | """Generate a CLE explanation for one row of tabular data. 47 | 48 | # Arguments 49 | instance: Array. One row of tabular data to be explained. 50 | top_labels: Integer. Number of labels you care about. 51 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 52 | care_cols, spans, include_original_feature: Forwarded unchanged to CLETabularExplainer.explain_instance. **kwargs: Also forwarded to explain_instance. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
53 | """ 54 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 55 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, labels=self.labels, 56 | top_labels=top_labels, 57 | care_cols=care_cols, spans=spans, 58 | include_original_feature=include_original_feature, **kwargs) 59 | 60 | def show_explanation(self, span=3, plot=True): 61 | """Print (and optionally plot) the cle_tabular explanation for each label. 62 | 63 | # Arguments 64 | span: Integer. Number of features printed per row. 65 | plot: Boolean. Whether to plot a bar chart per label. NOTE(review): entries of exp.local_exp[label] are overwritten in place with (name, weight) pairs, so a second call on the same explanation will likely raise at the `>= D` comparison / parts[index] lookup; plt.subplots is also called once per label inside the loop - confirm both are intended. 66 | """ 67 | Explainer.show_explanation(self) 68 | exp = self.explanation 69 | labels = self.labels 70 | 71 | print() 72 | print("CLE Explanation") 73 | print("Instance: {}".format(self.instance)) 74 | print() 75 | 76 | words = exp.domain_mapper.exp_feature_names 77 | D = len(words) 78 | 79 | # parts lists feature-index combinations; a feature id k >= D maps to parts[k - D] when original features are included, otherwise to parts[k]. 80 | parts = self.explainer.all_combinations 81 | 82 | assert hasattr(labels, '__len__') 83 | for i in range(len(labels)): 84 | label = labels[i] 85 | result = exp.intercept[label] 86 | local_exp = exp.local_exp[label] 87 | for item in local_exp: 88 | result += item[1] 89 | print("Explanation for label {}:".format(self.class_names[label])) 90 | print("Local Prediction: {:.3f}".format(result)) 91 | print("Original Prediction: {:.3f}".format(self.original_pred[label])) 92 | print() 93 | local_exp = exp.local_exp[label] 94 | for idx in range(len(local_exp)): 95 | if self.explainer.include_original_feature: 96 | if local_exp[idx][0] >= D: 97 | index = local_exp[idx][0] - D 98 | key = "" 99 | for item in parts[index]: 100 | key += words[item] + " AND " 101 | key = key[:-5] 102 | local_exp[idx] = (key, local_exp[idx][1]) 103 | else: 104 | local_exp[idx] = (words[local_exp[idx][0]], local_exp[idx][1]) 105 | else: 106 | index = local_exp[idx][0] 107 | key = "" 108 | for item in parts[index]: 109 | key += words[item] + " AND " 110 | key = key[:-5] 111 | local_exp[idx] = (key,
local_exp[idx][1]) 112 | 113 | for idx in range(len(local_exp)): 114 | print(" {:30} : {:.3f} |".format(local_exp[idx][0], local_exp[idx][1]), end="") 115 | if idx % span == span - 1: 116 | print() 117 | print() 118 | if len(local_exp) % span != 0: 119 | print() 120 | 121 | if plot: 122 | fig, axs = plt.subplots(nrows=1, ncols=len(labels)) 123 | vals = [x[1] for x in local_exp] 124 | names = [x[0] for x in local_exp] 125 | vals.reverse() 126 | names.reverse() 127 | pos = np.arange(len(vals)) + .5 128 | colors = ['green' if x > 0 else 'red' for x in vals] 129 | if len(labels) == 1: 130 | ax = axs 131 | else: 132 | ax = axs[i] 133 | ax.barh(pos, vals, align='center', color=colors) 134 | ax.set_yticks(pos) 135 | ax.set_yticklabels(names) 136 | title = 'Local explanation for class %s' % self.class_names[label] 137 | ax.set_title(title) 138 | if plot: 139 | plt.show() 140 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/tabular_explainers/lime_tabular.py: -------------------------------------------------------------------------------- 1 | # The implementation of LIME refers the original authors' codes in GitHub https://github.com/limetext/lime. 2 | # The Copyright of algorithm LIME is reserved for (c) 2016, Marco Tulio Correia Ribeiro. 3 | import matplotlib.pyplot as plt 4 | 5 | from lime.lime_tabular import LimeTabularExplainer 6 | 7 | from ..explainer import Explainer 8 | 9 | 10 | class XDeepLimeTabularExplainer(Explainer): 11 | 12 | def __init__(self, predict_proba, class_names, feature_names, train, 13 | categorical_features=None, categorical_names=None): 14 | """Init function. 15 | 16 | # Arguments 17 | predict_proba: Function. A classifier prediction probability function. 18 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 19 | feature_names: List. A list of names (strings) corresponding to the columns in the training data. 20 | train: Array. Train data.
21 | categorical_features: List. A list of indices (ints) corresponding to the categorical columns. Everything else will be considered continuous. Values in these columns MUST be integers. 22 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x. 23 | """ 24 | Explainer.__init__(self, predict_proba, class_names) 25 | self.train = train 26 | self.feature_names = feature_names 27 | self.categorical_features = categorical_features 28 | self.categorical_names = categorical_names 29 | # Initialize explainer 30 | self.set_parameters() 31 | 32 | def set_parameters(self, **kwargs): 33 | """Parameter setter for lime_tabular. 34 | 35 | # Arguments 36 | **kwargs: Optional overrides for train, class_names, feature_names, categorical_features and categorical_names; remaining keys are forwarded to LimeTabularExplainer. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 37 | """ 38 | train = kwargs.pop("train", self.train) 39 | class_names = kwargs.pop("class_names", self.class_names) 40 | feature_names = kwargs.pop("feature_names", self.feature_names) 41 | categorical_features = kwargs.pop("categorical_features", self.categorical_features) 42 | categorical_names = kwargs.pop("categorical_names", self.categorical_names) 43 | self.explainer = LimeTabularExplainer(train, class_names=class_names, 44 | feature_names=feature_names, categorical_features=categorical_features, 45 | categorical_names=categorical_names, **kwargs) 46 | 47 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 48 | """Generate a LIME explanation for one row of tabular data. 49 | 50 | # Arguments 51 | instance: Array. One row of tabular data to be explained. 52 | top_labels: Integer. Number of labels you care about. 53 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 54 | **kwargs: Forwarded to LimeTabularExplainer.explain_instance. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
55 | """ 56 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 57 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, 58 | top_labels=top_labels, labels=self.labels, **kwargs) 59 | 60 | def show_explanation(self, span=3): 61 | """Print the lime_tabular explanation for each label and draw LIME's pyplot figure. 62 | 63 | # Arguments 64 | span: Integer. Number of features printed per row. 65 | """ 66 | Explainer.show_explanation(self) 67 | exp = self.explanation 68 | labels = self.labels 69 | 70 | print() 71 | print("LIME Explanation") 72 | print("Instance: {}".format(self.instance)) 73 | print() 74 | 75 | assert hasattr(labels, '__len__') 76 | for label in labels: 77 | result = exp.intercept[label] 78 | local_exp = exp.local_exp[label] 79 | for item in local_exp: 80 | result += item[1] 81 | print("Explanation for label {}:".format(self.class_names[label])) 82 | print("Local Prediction: {:.3f}".format(result)) 83 | print("Original Prediction: {:.3f}".format(self.original_pred[label])) 84 | print() 85 | exp_list = exp.as_list(label=label) 86 | for idx in range(len(exp_list)): 87 | print(" {:20} : {:.3f} |".format(exp_list[idx][0], exp_list[idx][1]), end="") 88 | if idx % span == span - 1: 89 | print() 90 | print() 91 | exp.as_pyplot_figure(label=label) 92 | plt.show() 93 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/tabular_explainers/shap_tabular.py: -------------------------------------------------------------------------------- 1 | from IPython.display import display 2 | from scipy.sparse.csr import csr_matrix # NOTE(review): scipy.sparse.csr is deprecated in recent SciPy; scipy.sparse.csr_matrix is the public import path 3 | 4 | import shap 5 | 6 | from ..explainer import Explainer 7 | 8 | 9 | class XDeepShapTabularExplainer(Explainer): 10 | 11 | def __init__(self, predict_proba, class_names, train): 12 | """Init function. 13 | 14 | # Arguments 15 | predict_proba: Function. Classifier prediction probability function. 16 | class_names: List.
A list of class names, ordered according to whatever the classifier is using. 17 | train: Array. Train data, used as the background dataset of shap.KernelExplainer. 18 | """ 19 | Explainer.__init__(self, predict_proba, class_names) 20 | self.train = train 21 | # Initialize explainer 22 | self.set_parameters() 23 | 24 | def set_parameters(self, **kwargs): 25 | """Parameter setter for shap. 26 | 27 | # Arguments 28 | **kwargs: Optional overrides for predict_proba and train; remaining keys are forwarded to shap.KernelExplainer. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 29 | """ 30 | predict_proba = kwargs.pop("predict_proba", self.predict_proba) 31 | train = kwargs.pop("train", self.train) 32 | self.explainer = shap.KernelExplainer(predict_proba, train, **kwargs) 33 | 34 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 35 | """Compute KernelSHAP values for the given instance. 36 | 37 | # Arguments 38 | instance: Array. One row (or several rows) of tabular data to be explained. 39 | top_labels: Integer. Number of labels you care about. 40 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 41 | **kwargs: Forwarded to shap.KernelExplainer.shap_values. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 42 | """ 43 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 44 | self.explanation = self.explainer.shap_values(instance, **kwargs) 45 | 46 | def show_explanation(self, show_in_note_book=True): 47 | """Print and plot the shap_tabular explanation. 48 | 49 | # Arguments 50 | show_in_note_book: Boolean. Whether to display force plots (single row) or a summary plot (several rows) in a Jupyter notebook.
51 | """ 52 | Explainer.show_explanation(self) 53 | shap_values = self.explanation 54 | expected_value = self.explainer.expected_value 55 | class_names = self.class_names 56 | labels = self.labels 57 | instance = self.instance 58 | 59 | shap.initjs() 60 | print() 61 | print("Shap Explanation") 62 | print() 63 | 64 | assert hasattr(labels, '__len__') 65 | if len(instance.shape) == 1 or instance.shape[0] == 1: # single row: print values per label and show force plots; otherwise a summary plot 66 | for item in labels: 67 | print("Shap value for label {}:".format(class_names[item])) 68 | print(shap_values[item]) 69 | 70 | if show_in_note_book: 71 | for item in labels: 72 | if isinstance(instance, csr_matrix): 73 | display(shap.force_plot(expected_value[item], shap_values[item], instance.A)) # .A gives a dense array of the sparse row 74 | else: 75 | display(shap.force_plot(expected_value[item], shap_values[item], instance)) 76 | else: 77 | if show_in_note_book: 78 | shap.summary_plot(shap_values, instance) 79 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/text_explainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/text_explainers/__init__.py -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/text_explainers/anchor_text.py: -------------------------------------------------------------------------------- 1 | # The implementation of anchor refers the original authors' codes in GitHub https://github.com/marcotcr/anchor. 2 | # The Copyright of algorithm anchor is reserved for (c) 2018, Marco Tulio Correia Ribeiro. 3 | import spacy 4 | import numpy as np 5 | 6 | from anchor.anchor_text import AnchorText 7 | 8 | from ..explainer import Explainer 9 | 10 | 11 | class XDeepAnchorTextExplainer(Explainer): 12 | 13 | def __init__(self, predict_proba, class_names): 14 | """Init function.
15 | 16 | # Arguments 17 | predict_proba: Function. Classifier prediction probability function. 18 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 19 | """ 20 | Explainer.__init__(self, predict_proba, class_names) 21 | # Initialize explainer 22 | self.set_parameters() 23 | 24 | def set_parameters(self, nlp=None, **kwargs): 25 | """Parameter setter for anchor_text. 26 | 27 | # Arguments 28 | nlp: Object. A spacy model object. If None, spacy.load('en_core_web_sm') is tried; if that fails, self.explainer is not (re)assigned. 29 | **kwargs: Optional class_names override; remaining keys are forwarded to AnchorText. For more detail, please check https://github.com/marcotcr/anchor 30 | """ 31 | if nlp is None: 32 | print("Use default nlp = spacy.load('en_core_web_sm') to initialize anchor") 33 | try: 34 | nlp = spacy.load('en_core_web_sm') 35 | except OSError: 36 | print("Cannot load en_core_web_sm") 37 | print("Please run the command 'python -m spacy download en' as admin, then rerun your code.") 38 | 39 | class_names = kwargs.pop("class_names", self.class_names) 40 | if nlp is not None: 41 | self.explainer = AnchorText(nlp, class_names=class_names, **kwargs) 42 | 43 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 44 | """Generate an anchor explanation for a raw text string. 45 | Anchor does not use top_labels and labels. 46 | 47 | # Arguments 48 | instance: Str. A raw text string to be explained. 49 | **kwargs: Forwarded to AnchorText.explain_instance. For more detail, please check https://github.com/marcotcr/anchor.
50 | """ 51 | def predict_label(x): 52 | return np.argmax(self.predict_proba(x), axis=1) 53 | 54 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 55 | try: 56 | self.labels = self.predict_proba([instance]).argsort()[0][-1:] 57 | except: # NOTE(review): bare except also swallows unrelated errors - consider narrowing 58 | self.labels = self.predict_proba(instance).argsort()[0][-1:] 59 | if self.explainer is not None: # NOTE(review): guards the case where set_parameters could not load a spacy model; assumes Explainer.__init__ initialises self.explainer - confirm 60 | self.explanation = self.explainer.explain_instance(instance, predict_label, **kwargs) 61 | 62 | def show_explanation(self, show_in_note_book=True, verbose=True): 63 | """Print the anchor_text explanation and optionally render it in a notebook. 64 | 65 | # Arguments 66 | show_in_note_book: Boolean. Whether to call exp.show_in_notebook(). 67 | verbose: Boolean. Whether to print examples and counter-examples. 68 | """ 69 | Explainer.show_explanation(self) 70 | exp = self.explanation 71 | instance = self.instance 72 | 73 | print() 74 | print("Anchor Explanation") 75 | print("Instance: {}".format(instance)) 76 | print() 77 | if verbose: 78 | print() 79 | print('Examples where anchor applies and model predicts same to instance:') 80 | print() 81 | print('\n'.join([x[0] for x in exp.examples(only_same_prediction=True)])) 82 | print() 83 | print('Examples where anchor applies and model predicts different with instance:') 84 | print() 85 | print('\n'.join([x[0] for x in exp.examples(only_different_prediction=True)])) 86 | 87 | label = self.labels[0] 88 | print('Prediction: {}'.format(label)) 89 | print('Anchor: %s' % (' AND '.join(exp.names()))) 90 | print('Precision: %.2f' % exp.precision()) 91 | print('Coverage: %.2f' % exp.coverage()) 92 | if show_in_note_book: 93 | exp.show_in_notebook() 94 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/text_explainers/cle_text.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from ..cle.cle_text import CLETextExplainer 5 | 6 | from ..explainer
import Explainer 7 | 8 | 9 | class XDeepCLETextExplainer(Explainer): 10 | 11 | def __init__(self, predict_proba, class_names): 12 | """Init function. 13 | 14 | # Arguments 15 | predict_proba: Function. Classifier prediction probability function. 16 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 17 | """ 18 | Explainer.__init__(self, predict_proba, class_names) 19 | # Initialize explainer 20 | self.set_parameters() 21 | 22 | def set_parameters(self, **kwargs): 23 | """Parameter setter for cle_text. 24 | 25 | # Arguments 26 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 27 | """ 28 | class_names = kwargs.pop("class_names", self.class_names) 29 | self.explainer = CLETextExplainer(class_names=class_names, **kwargs) 30 | 31 | def explain(self, instance, top_labels=None, labels=(1,), care_words=None, spans=(2,), include_original_feature=True, **kwargs): 32 | """Generate explanation for a prediction using certain method. 33 | 34 | # Arguments 35 | instance: Str. A raw text string to be explained. 36 | top_labels: Integer. Number of labels you care about. 37 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 38 | **kwarg: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 39 | """ 40 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 41 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, labels=self.labels, top_labels=top_labels, 42 | care_words=care_words, spans=spans, include_original_feature=include_original_feature, 43 | **kwargs) 44 | 45 | def show_explanation(self, span=3, plot=True): 46 | """Visualization of explanation of cle_text. 47 | 48 | # Arguments 49 | span: Integer. Each row shows how many features. 50 | plot: Boolean. Whether plots a figure. 
51 | """ 52 | Explainer.show_explanation(self) 53 | exp = self.explanation 54 | labels = self.labels 55 | 56 | print() 57 | print("CLE Explanation") 58 | print("Instance: {}".format(self.instance)) 59 | print() 60 | 61 | words = exp.domain_mapper.indexed_string.inverse_vocab 62 | D = len(words) 63 | 64 | # parts is a list which contains all the combination of features. 65 | parts = self.explainer.all_combinations 66 | 67 | assert hasattr(labels, '__len__') 68 | for i in range(len(labels)): 69 | label = labels[i] 70 | result = exp.intercept[label] 71 | local_exp = exp.local_exp[label] 72 | for item in local_exp: 73 | result += item[1] 74 | print("Explanation for label {}:".format(self.class_names[label])) 75 | print("Local Prediction: {:.3f}".format(result)) 76 | print("Original Prediction: {:.3f}".format(self.original_pred[label])) 77 | print() 78 | for idx in range(len(local_exp)): 79 | if self.explainer.include_original_feature: 80 | if local_exp[idx][0] >= D: 81 | index = local_exp[idx][0] - D 82 | key = "" 83 | for item in parts[index]: 84 | key += words[item]+" AND " 85 | key = key[:-5] 86 | local_exp[idx] = (key, local_exp[idx][1]) 87 | else: 88 | local_exp[idx] = (words[local_exp[idx][0]], local_exp[idx][1]) 89 | else: 90 | index = local_exp[idx][0] 91 | key = "" 92 | for item in parts[index]: 93 | key += words[item] + " AND " 94 | key = key[:-5] 95 | local_exp[idx] = (key, local_exp[idx][1]) 96 | 97 | for idx in range(len(local_exp)): 98 | print(" {:20} : {:.3f} |".format(local_exp[idx][0], local_exp[idx][1]), end="") 99 | if idx % span == span - 1: 100 | print() 101 | print() 102 | if len(local_exp) % span != 0: 103 | print() 104 | 105 | if plot: 106 | fig, axs = plt.subplots(nrows=1, ncols=len(labels)) 107 | vals = [x[1] for x in local_exp] 108 | names = [x[0] for x in local_exp] 109 | vals.reverse() 110 | names.reverse() 111 | pos = np.arange(len(vals)) + .5 112 | colors = ['green' if x > 0 else 'red' for x in vals] 113 | if len(labels) == 1: 114 | ax = 
axs 115 | else: 116 | ax = axs[i] 117 | 118 | ax.barh(pos, vals, align='center', color=colors) 119 | ax.set_yticks(pos) 120 | ax.set_yticklabels(names) 121 | title = 'Local explanation for class %s' % self.class_names[label] 122 | ax.set_title(title) 123 | if plot: 124 | plt.show() -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/text_explainers/lime_text.py: -------------------------------------------------------------------------------- 1 | # The implementation of LIME refers the original authors' codes in GitHub https://github.com/limetext/lime. 2 | # The Copyright of algorithm LIME is reserved for (c) 2016, Marco Tulio Correia Ribeiro. 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | from lime.lime_text import LimeTextExplainer 7 | from ..explainer import Explainer 8 | 9 | 10 | class XDeepLimeTextExplainer(Explainer): 11 | 12 | def __init__(self, predict_proba, class_names): 13 | """Init function. 14 | 15 | # Arguments 16 | predict_proba: Function. Classifier prediction probability function. 17 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 18 | """ 19 | Explainer.__init__(self, predict_proba, class_names) 20 | # Initialize explainer 21 | self.set_parameters() 22 | 23 | def set_parameters(self, **kwargs): 24 | """Parameter setter for lime_text. 25 | 26 | # Arguments 27 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 28 | """ 29 | class_names = kwargs.pop("class_names", self.class_names) 30 | self.explainer = LimeTextExplainer(class_names=class_names, **kwargs) 31 | 32 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 33 | """Generate explanation for a prediction using certain method. 34 | 35 | # Arguments 36 | instance: Str. A raw text string to be explained. 37 | top_labels: Integer. Number of labels you care about. 38 | labels: Tuple. 
Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 39 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html. 40 | """ 41 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels) 42 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, 43 | top_labels=top_labels, labels=self.labels, **kwargs) 44 | 45 | def show_explanation(self, span=3): 46 | """Visualization of explanation of lime_text. 47 | 48 | # Arguments 49 | span: Integer. Each row shows how many features. 50 | """ 51 | Explainer.show_explanation(self) 52 | exp = self.explanation 53 | labels = self.labels 54 | 55 | print() 56 | print("LIME Explanation") 57 | print("Instance: {}".format(self.instance)) 58 | print() 59 | 60 | assert hasattr(labels, '__len__') 61 | for label in labels: 62 | result = exp.intercept[label] 63 | local_exp = exp.local_exp[label] 64 | for item in local_exp: 65 | result += item[1] 66 | print("Explanation for label {}:".format(self.class_names[label])) 67 | print("Local Prediction: {:.3f}".format(result)) 68 | print("Original Prediction: {:.3f}".format(self.original_pred[label])) 69 | print() 70 | exp_list = exp.as_list(label=label) 71 | for idx in range(len(exp_list)): 72 | print(" {:20} : {:.3f} |".format(exp_list[idx][0], exp_list[idx][1]), end="") 73 | if idx % span == span - 1: 74 | print() 75 | print() 76 | exp.as_pyplot_figure(label=label) 77 | plt.show() 78 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/text_explainers/shap_text.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse.csr import csr_matrix 2 | from IPython.display import display 3 | 4 | import shap 5 | 6 | from ..explainer import Explainer 7 | 8 | 9 | class XDeepShapTextExplainer(Explainer): 10 | 11 | def __init__(self, predict_vectorized, class_names, 
vectorizer, train): 12 | """Init function. The data pass to 'shap' cannot be str but vector.As a result, you need to pass in the vectorizer. 13 | 14 | # Arguments 15 | predict_vectorized: Function. Classifier prediction probability function. 16 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 17 | vectorizer: Vectorizer. A vectorizer which has 'transform' function that transforms list of str to vector. 18 | train: Array. Train data, in this case a list of str. 19 | """ 20 | Explainer.__init__(self, predict_vectorized, class_names) 21 | self.vectorizer = vectorizer 22 | self.train = train 23 | # Initialize explainer 24 | self.set_parameters() 25 | 26 | def set_parameters(self, **kwargs): 27 | """Parameter setter for shap. The data pass to 'shap' cannot be str but vector.As a result, you need to pass in the vectorizer. 28 | 29 | # Arguments 30 | **kwargs: Shap kernel explainer parameter setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 31 | """ 32 | predict_vectorized = kwargs.pop('predict_vectorized', self.predict_proba) 33 | vectorizer = kwargs.pop('vectorizer', self.vectorizer) 34 | vector = kwargs.pop('train', self.train) 35 | if not isinstance(vector, csr_matrix): 36 | vector = vectorizer.transform(vector) 37 | self.explainer = shap.KernelExplainer(predict_vectorized, vector, **kwargs) 38 | 39 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs): 40 | """Generate explanation for a prediction using certain method. 41 | 42 | # Arguments 43 | instance: Str. A raw text string to be explained. 44 | top_labels: Integer. Number of labels you care about. 45 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 46 | **kwarg: Parameters setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py. 
47 | """ 48 | if not isinstance(instance, csr_matrix): 49 | vector = self.vectorizer.transform([instance]) 50 | else: 51 | vector = instance 52 | self.instance = vector 53 | proba = self.predict_proba(self.instance) 54 | self.original_pred = proba[0] 55 | if top_labels is not None: 56 | self.labels = proba.argsort()[0][-top_labels:] 57 | else: 58 | self.labels = labels 59 | self.explanation = self.explainer.shap_values(vector, **kwargs) 60 | 61 | def show_explanation(self, show_in_note_book=True): 62 | """Visualization of explanation of shap. 63 | 64 | # Arguments 65 | show_in_note_book: Booleam. Whether show in jupyter notebook. 66 | """ 67 | Explainer.show_explanation(self) 68 | shap_values = self.explanation 69 | expected_value = self.explainer.expected_value 70 | class_names = self.class_names 71 | labels = self.labels 72 | instance = self.instance 73 | 74 | shap.initjs() 75 | print() 76 | print("Shap Explanation") 77 | print() 78 | 79 | if len(instance.shape) == 1 or instance.shape[0] == 1: 80 | for label in labels: 81 | print("Shap value for label {}:".format(class_names[label])) 82 | print(shap_values[label]) 83 | 84 | if show_in_note_book: 85 | for label in labels: 86 | if isinstance(instance, csr_matrix): 87 | display(shap.force_plot(expected_value[label], shap_values[label], instance.A)) 88 | else: 89 | display(shap.force_plot(expected_value[label], shap_values[label], instance)) 90 | else: 91 | if show_in_note_book: 92 | shap.summary_plot(shap_values, instance) 93 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/xdeep_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base wrapper for xdeep. 3 | """ 4 | from .exceptions import XDeepError 5 | 6 | 7 | class Wrapper(object): 8 | """A wrapper to wrap all methods.""" 9 | 10 | def __init__(self, predict_proba, class_names): 11 | """Init function. 12 | 13 | # Arguments 14 | predict_proba: Function. 
Classifier prediction probability function, which takes a list of d strings and outputs a (d, k) numpy array with prediction probabilities, where k is the number of classes. 15 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 16 | """ 17 | self.methods = list() 18 | self.explainers = dict() 19 | 20 | self.predict_proba = predict_proba 21 | self.class_names = class_names 22 | 23 | def __initialization(self): 24 | """Initialize default explainers.""" 25 | pass 26 | 27 | def __check_avaliability(self, method, exp_required=False): 28 | """Check if the method is valid. 29 | 30 | # Arguments 31 | method: Str. The method you want to use. 32 | exp_required: Boolean. Whether check the existence of explanation. 33 | """ 34 | if method not in self.methods: 35 | raise XDeepError("Please input correct explain_method from {}.".format(self.methods)) 36 | if method not in self.explainers: 37 | raise XDeepError("Sorry. You need to initialize {}_explainer first.".format(method)) 38 | if exp_required: 39 | exp = self.explainers[method].explanation 40 | if exp is None: 41 | raise XDeepError("Sorry. You haven't got {} explanation yet.".format(method)) 42 | 43 | def set_parameters(self, method, **kwargs): 44 | """Parameter setter. 45 | 46 | # Arguments 47 | method: Str. The method you want to use. 48 | **kwargs: Other parameters that depends on which method you use. For more detail, please check xdeep documentation. 49 | """ 50 | self.__check_avaliability(method) 51 | self.explainers[method].set_parameters(**kwargs) 52 | 53 | def explain(self, method, instance, top_labels=None, labels=(1,), **kwargs): 54 | """Generate explanation for a prediction using certain method. 55 | 56 | # Arguments 57 | method: Str. The method you want to use. 58 | instance: Instance to be explained. 59 | top_labels: Integer. Number of labels you care about. 60 | labels: Tuple. 
Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 61 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method. 62 | """ 63 | # Check validation of parameters 64 | self.__check_avaliability(method) 65 | self.explainers[method].explain(instance, top_labels=top_labels, labels=labels, **kwargs) 66 | 67 | def get_explanation(self, method): 68 | """Explanation getter. 69 | 70 | # Arguments 71 | method: Str. The method you want to use. 72 | 73 | # Return 74 | An Explanation object with the corresponding method. 75 | """ 76 | # Check validation of parameters 77 | self.__check_avaliability(method, exp_required=True) 78 | exp = self.explainers[method].explanation 79 | return exp 80 | 81 | def show_explanation(self, method, **kwargs): 82 | """Visualization of explanation of the corresponding method. 83 | 84 | # Arguments 85 | method: Str. The method you want to use. 86 | **kwargs: parameters setter. For more detail, please check xdeep documentation. 87 | """ 88 | # Check validation of parameters 89 | self.__check_avaliability(method, exp_required=True) 90 | self.explainers[method].show_explanation(**kwargs) 91 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/xdeep_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for explaining image classifiers. 3 | """ 4 | 5 | from .xdeep_base import Wrapper 6 | from .image_explainers.lime_image import XDeepLimeImageExplainer 7 | from .image_explainers.cle_image import XDeepCLEImageExplainer 8 | from .image_explainers.anchor_image import XDeepAnchorImageExplainer 9 | from .image_explainers.shap_image import XDeepShapImageExplainer 10 | 11 | 12 | class ImageExplainer(Wrapper): 13 | """Integrated explainer which explains text classifiers.""" 14 | 15 | def __init__(self, predict_proba, class_names): 16 | """Init function. 
17 | 18 | # Arguments 19 | predict_proba: Function. Classifier prediction probability function. 20 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 21 | """ 22 | Wrapper.__init__(self, predict_proba, class_names=class_names) 23 | self.methods = ['lime', 'cle', 'anchor', 'shap'] 24 | self.__initialization() 25 | 26 | def __initialization(self): 27 | """Initializer. Use default parameters to initialize 'lime' and 'anchor' explainer.""" 28 | print("Initialize default 'lime', 'cle', 'anchor' explainers. " + 29 | "Please explicitly initialize 'shap' explainer before use it.") 30 | self.explainers['lime'] = XDeepLimeImageExplainer(self.predict_proba, self.class_names) 31 | self.explainers['cle'] = XDeepCLEImageExplainer(self.predict_proba, self.class_names) 32 | self.explainers['anchor'] = XDeepAnchorImageExplainer(self.predict_proba, self.class_names) 33 | 34 | def initialize_shap(self, n_segment, segment): 35 | """Init function. 36 | 37 | # Arguments 38 | n_segment: Integer. Number of segments in the image. 39 | segment: Array. An array with 2 dimensions, segments_slic of the image. 40 | """ 41 | self.explainers['shap'] = XDeepShapImageExplainer(self.predict_proba, self.class_names, n_segment, segment) 42 | 43 | def explain(self, method, instance, top_labels=2, labels=(1,), **kwargs): 44 | """Generate explanation for a prediction using certain method. 45 | 46 | # Arguments 47 | method: Str. The method you want to use. 48 | instance: Instance to be explained. 49 | top_labels: Integer. Number of labels you care about. 50 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 51 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method. 
52 | """ 53 | Wrapper.explain(self, method, instance, top_labels=top_labels, labels=labels, **kwargs) 54 | 55 | def show_explanation(self, method, **kwargs): 56 | """Visualization of explanation of the corresponding method. 57 | 58 | # Arguments 59 | method: Str. The method you want to use. 60 | **kwargs: parameters setter. For more detail, please check xdeep documentation. 61 | """ 62 | Wrapper.show_explanation(self, method, **kwargs) 63 | 64 | def get_explanation(self, method): 65 | """Explanation getter. 66 | 67 | # Arguments 68 | method: Str. The method you want to use. 69 | 70 | # Return 71 | An Explanation object with the corresponding method. 72 | """ 73 | Wrapper.get_explanation(self, method) 74 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/xdeep_tabular.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for explaining tabular classifiers. 3 | """ 4 | 5 | from .xdeep_base import Wrapper 6 | from .exceptions import XDeepError 7 | from .tabular_explainers.lime_tabular import XDeepLimeTabularExplainer 8 | from .tabular_explainers.cle_tabular import XDeepCLETabularExplainer 9 | from .tabular_explainers.anchor_tabular import XDeepAnchorTabularExplainer 10 | from .tabular_explainers.shap_tabular import XDeepShapTabularExplainer 11 | 12 | 13 | class TabularExplainer(Wrapper): 14 | """Integrated explainer which explains text classifiers.""" 15 | 16 | def __init__(self, predict_proba, class_names, feature_names, train, 17 | categorical_features=None, categorical_names=None): 18 | """Init function. 19 | 20 | # Arguments 21 | predict_proba: Function. Classifier prediction probability function. 22 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 23 | feature_names: List. A list of names (strings) corresponding to the columns in the training data. 24 | train: Array. 
Train data, in this case a list of tabular data. 25 | categorical_features: List. A list of indices (ints) corresponding to the categorical columns. Everything else will be considered continuous. Values in these columns MUST be integers. 26 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x. 27 | """ 28 | Wrapper.__init__(self, predict_proba, class_names=class_names) 29 | self.methods = ['lime', 'cle', 'anchor', 'shap'] 30 | self.feature_names = feature_names 31 | self.train = train 32 | self.categorical_features = categorical_features 33 | self.categorical_names = categorical_names 34 | self.__initialization() 35 | 36 | def __initialization(self): 37 | """Initializer. Use default parameters to initialize 'lime' and 'shap' explainer.""" 38 | print("Initialize default 'lime', 'cle', 'shap' explainers. " 39 | "Please explicitly initialize 'anchor' explainer before use it.") 40 | self.explainers['lime'] = XDeepLimeTabularExplainer(self.predict_proba, self.class_names, self.feature_names, 41 | self.train, categorical_features=self.categorical_features, 42 | categorical_names=self.categorical_names) 43 | self.explainers['cle'] = XDeepCLETabularExplainer(self.predict_proba, self.class_names, self.feature_names, 44 | self.train, categorical_features=self.categorical_features, 45 | categorical_names=self.categorical_names) 46 | self.explainers['shap'] = XDeepShapTabularExplainer(self.predict_proba, self.class_names, self.train) 47 | 48 | def initialize_anchor(self, data): 49 | """Explicit shap initializer. 50 | 51 | # Arguments 52 | data: Array. Full data including train data, validation_data and test data. 53 | """ 54 | print("If you want to use 'anchor' to explain tabular classifier. 
" 55 | "You need to get this encoder to encode your data, and train another model.") 56 | self.explainers['anchor'] = XDeepAnchorTabularExplainer(self.class_names, self.feature_names, data, 57 | categorical_names=self.categorical_names) 58 | 59 | def get_anchor_encoder(self, train_data, train_labels, validation_data, validation_labels, discretizer='quartile'): 60 | """Get encoder for tabular data. 61 | 62 | If you want to use 'anchor' to explain tabular classifier. You need to get this encoder to encode your data, and train another model. 63 | 64 | # Arguments 65 | train_data: Array. Train data. 66 | train_labels: Array. Train labels. 67 | validation_data: Array. Validation set. 68 | validation_labels: Array. Validation labels. 69 | discretizer: Str. Discretizer for data. Please choose between 'quartile' and 'decile'. 70 | 71 | # Return 72 | A encoder object which has function 'transform'. 73 | """ 74 | if 'anchor' not in self.explainers: 75 | raise XDeepError("Please initialize anchor explainer first.") 76 | return self.explainers['anchor'].get_anchor_encoder(train_data, train_labels, 77 | validation_data, validation_labels, discretizer='quartile') 78 | 79 | def set_anchor_predict_proba(self, predict_proba): 80 | """Anchor predict function setter. 81 | Because you will get an encoder and train new model, you will need to update the predict function. 82 | 83 | # Arguments 84 | predict_proba: Function. A new classifier prediction probability function. 85 | """ 86 | self.explainers['anchor'].set_anchor_predict_proba(predict_proba) 87 | 88 | def explain(self, method, instance, top_labels=None, labels=(1,), **kwargs): 89 | """Generate explanation for a prediction using certain method. 90 | 91 | # Arguments 92 | method: Str. The method you want to use. 93 | instance: Instance to be explained. 94 | top_labels: Integer. Number of labels you care about. 95 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 
96 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method. 97 | """ 98 | Wrapper.explain(self, method, instance, top_labels=top_labels, labels=labels, **kwargs) 99 | 100 | 101 | def show_explanation(self, method, **kwargs): 102 | """Visualization of explanation of the corresponding method. 103 | 104 | # Arguments 105 | method: Str. The method you want to use. 106 | **kwargs: parameters setter. For more detail, please check xdeep documentation. 107 | """ 108 | Wrapper.show_explanation(self, method, **kwargs) 109 | 110 | def get_explanation(self, method): 111 | """Explanation getter. 112 | 113 | # Arguments 114 | method: Str. The method you want to use. 115 | 116 | # Return 117 | An Explanation object with the corresponding method. 118 | """ 119 | Wrapper.get_explanation(self, method) 120 | -------------------------------------------------------------------------------- /xdeep/xlocal/perturbation/xdeep_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for explaining text classifiers. 3 | """ 4 | 5 | from .xdeep_base import Wrapper 6 | from .text_explainers.lime_text import XDeepLimeTextExplainer 7 | from .text_explainers.cle_text import XDeepCLETextExplainer 8 | from .text_explainers.anchor_text import XDeepAnchorTextExplainer 9 | from .text_explainers.shap_text import XDeepShapTextExplainer 10 | 11 | 12 | class TextExplainer(Wrapper): 13 | """Integrated explainer which explains text classifiers.""" 14 | 15 | def __init__(self, predict_proba, class_names): 16 | """Init function. 17 | 18 | # Arguments 19 | predict_proba: Function. Classifier prediction probability function. 20 | class_names: List. A list of class names, ordered according to whatever the classifier is using. 
21 | """ 22 | Wrapper.__init__(self, predict_proba, class_names=class_names) 23 | self.methods = ['lime', 'cle', 'anchor', 'shap'] 24 | self.__initialization() 25 | 26 | def __initialization(self): 27 | """Initializer. Use default parameters to initialize 'lime' and 'anchor' explainer.""" 28 | print("Initialize default 'lime', 'cle', 'anchor' explainers. " + 29 | "Please explicitly initialize 'shap' explainer before use it.") 30 | self.explainers['lime'] = XDeepLimeTextExplainer(self.predict_proba, self.class_names) 31 | self.explainers['cle'] = XDeepCLETextExplainer(self.predict_proba, self.class_names) 32 | self.explainers['anchor'] = XDeepAnchorTextExplainer(self.predict_proba, self.class_names) 33 | 34 | def initialize_shap(self, predict_vectorized, vectorizer, train): 35 | """Explicit shap initializer. 36 | 37 | # Arguments 38 | predict_vectorized: Function. Classifier prediction probability function which need vector passed in. 39 | vectorizer: Vectorizer. A vectorizer which has 'transform' function that transforms list of str to vector. 40 | train: Array. Train data, in this case a list of str. 41 | """ 42 | self.explainers['shap'] = XDeepShapTextExplainer(predict_vectorized, self.class_names, vectorizer, train) 43 | 44 | def explain(self, method, instance, top_labels=None, labels=(1,), **kwargs): 45 | """Generate explanation for a prediction using certain method. 46 | 47 | # Arguments 48 | method: Str. The method you want to use. 49 | instance: Instance to be explained. 50 | top_labels: Integer. Number of labels you care about. 51 | labels: Tuple. Labels you care about, if top_labels is not none, it will be replaced by the predicted top labels. 52 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method. 
53 | """ 54 | Wrapper.explain(self, method, instance, top_labels=top_labels, labels=labels, **kwargs) 55 | 56 | def show_explanation(self, method, **kwargs): 57 | """Visualization of explanation of the corresponding method. 58 | 59 | # Arguments 60 | method: Str. The method you want to use. 61 | **kwargs: parameters setter. For more detail, please check xdeep documentation. 62 | """ 63 | Wrapper.show_explanation(self, method, **kwargs) 64 | 65 | def get_explanation(self, method): 66 | """Explanation getter. 67 | 68 | # Arguments 69 | method: Str. The method you want to use. 70 | 71 | # Return 72 | An Explanation object with the corresponding method. 73 | """ 74 | Wrapper.get_explanation(self, method) 75 | --------------------------------------------------------------------------------