├── .gitattributes
├── .gitignore
├── .travis.yml
├── LICENSE.md
├── README.md
├── docs
├── about.md
├── global_explainers.md
├── gradient_backprop.md
├── gradient_cam.md
├── gradient_explainer.md
├── index.md
├── tutorial_global.md
├── tutorial_local.md
├── xdeep_image.md
├── xdeep_tabular.md
└── xdeep_text.md
├── result_img
└── ensemble_fig.png
├── setup.py
├── tests
├── __init__.py
├── xglobal
│ ├── __init__.py
│ ├── images
│ │ └── jay.jpg
│ ├── results
│ │ ├── deepdream.jpg
│ │ ├── filter.jpg
│ │ ├── inverted
│ │ │ ├── Inverted_Iteration_0.jpg
│ │ │ ├── Inverted_Iteration_1.jpg
│ │ │ ├── Inverted_Iteration_2.jpg
│ │ │ ├── Inverted_Iteration_3.jpg
│ │ │ ├── Inverted_Iteration_4.jpg
│ │ │ ├── Inverted_Iteration_5.jpg
│ │ │ ├── Inverted_Iteration_6.jpg
│ │ │ ├── Inverted_Iteration_7.jpg
│ │ │ ├── Inverted_Iteration_8.jpg
│ │ │ └── Inverted_Iteration_9.jpg
│ │ ├── layer.jpg
│ │ └── logit.jpg
│ └── test_global.py
├── xlocal_gradient
│ ├── __init__.py
│ ├── images
│ │ └── ILSVRC2012_val_00000073.JPEG
│ ├── results
│ │ ├── GuidedBP_ILSVRC2012_val_00000073.JPEG
│ │ └── gradcam_ILSVRC2012_val_00000073.JPEG
│ ├── test_explainers.py
│ ├── test_gradcam.py
│ └── test_guided_backprop.py
└── xlocal_perturbation
│ ├── __init__.py
│ ├── data
│ ├── Eskimo dog, husky.JPEG
│ ├── Great Pyrenees.JPEG
│ ├── Ibizan hound.JPEG
│ ├── adult
│ │ └── adult.data
│ ├── bee eater.JPEG
│ ├── car mirror.JPEG
│ ├── eggnog.JPEG
│ ├── kitten.jpg
│ ├── oystercatcher.jpg
│ ├── rt-polaritydata
│ │ ├── rt-polaritydata.README.1.0.txt
│ │ └── rt-polaritydata
│ │ │ ├── rt-polarity.neg
│ │ │ └── rt-polarity.pos
│ ├── sea_snake.JPEG
│ ├── sky.jpg
│ ├── tusker.JPEG
│ └── violin.JPEG
│ ├── image_util
│ ├── imagenet_lsvrc_2015_synsets.txt
│ └── imagenet_metadata.txt
│ ├── test_image.py
│ ├── test_tabular.py
│ └── test_text.py
├── tutorial
├── notebook
│ ├── image - inceptionV3-checkpoint.ipynb
│ ├── image - tutorial-checkpoint.ipynb
│ ├── tabular - RF-checkpoint.ipynb
│ ├── tabular - tutorial-checkpoint.ipynb
│ ├── text - Bert-checkpoint.ipynb
│ ├── text - LR - CountVetorizer-checkpoint.ipynb
│ ├── text - LSTM - Word2Vec-checkpoint.ipynb
│ ├── text - RF - CountVectorizer-checkpoint.ipynb
│ ├── text - SVC - CountVetorizer-checkpoint.ipynb
│ └── text - tutorial-checkpoint.ipynb
├── xglobal
│ ├── images
│ │ └── jay.jpg
│ ├── results
│ │ ├── deepdream.jpg
│ │ ├── filter.jpg
│ │ ├── inverted
│ │ │ ├── Inverted_Iteration_0.jpg
│ │ │ ├── Inverted_Iteration_1.jpg
│ │ │ ├── Inverted_Iteration_2.jpg
│ │ │ ├── Inverted_Iteration_3.jpg
│ │ │ ├── Inverted_Iteration_4.jpg
│ │ │ ├── Inverted_Iteration_5.jpg
│ │ │ ├── Inverted_Iteration_6.jpg
│ │ │ ├── Inverted_Iteration_7.jpg
│ │ │ ├── Inverted_Iteration_8.jpg
│ │ │ └── Inverted_Iteration_9.jpg
│ │ ├── layer.jpg
│ │ └── logit.jpg
│ └── test_global.py
├── xlocal_gradient
│ ├── images
│ │ └── ILSVRC2012_val_00000073.JPEG
│ ├── results
│ │ ├── GuidedBP_ILSVRC2012_val_00000073.JPEG
│ │ └── gradcam_ILSVRC2012_val_00000073.JPEG
│ ├── test_explainers.py
│ ├── test_gradcam.py
│ └── test_guided_backprop.py
└── xlocal_perturbation
│ ├── data
│ ├── Eskimo dog, husky.JPEG
│ ├── Great Pyrenees.JPEG
│ ├── Ibizan hound.JPEG
│ ├── adult
│ │ └── adult.data
│ ├── bee eater.JPEG
│ ├── car mirror.JPEG
│ ├── eggnog.JPEG
│ ├── kitten.jpg
│ ├── oystercatcher.jpg
│ ├── rt-polaritydata
│ │ ├── rt-polaritydata.README.1.0.txt
│ │ └── rt-polaritydata
│ │ │ ├── rt-polarity.neg
│ │ │ └── rt-polarity.pos
│ ├── sea_snake.JPEG
│ ├── sky.jpg
│ ├── tusker.JPEG
│ └── violin.JPEG
│ ├── image - tutorial.html
│ ├── image - tutorial.ipynb
│ ├── image - tutorial.py
│ ├── image_util
│ ├── .DS_Store
│ ├── imagenet_lsvrc_2015_synsets.txt
│ ├── imagenet_metadata.txt
│ ├── inception_preprocessing.py
│ └── inception_v3.py
│ ├── tabular - tutorial.html
│ ├── tabular - tutorial.ipynb
│ ├── tabular - tutorial.py
│ ├── text - tutorial.html
│ ├── text - tutorial.ipynb
│ └── text - tutorial.py
└── xdeep
├── __init__.py
├── utils
├── __init__.py
├── imagenet.py
└── resources
│ ├── __init__.py
│ └── imagenet_class_index.json
├── xglobal
├── __init__.py
├── global_explainers.py
└── methods
│ ├── __init__.py
│ ├── activation.py
│ └── inverted_representation.py
└── xlocal
├── __init__.py
├── gradient
├── __init__.py
├── backprop
│ ├── __init__.py
│ ├── base.py
│ ├── guided_backprop.py
│ ├── integrate_grad.py
│ └── smooth_grad.py
├── cam
│ ├── __init__.py
│ ├── basecam.py
│ ├── gradcam.py
│ ├── gradcampp.py
│ └── scorecam.py
└── explainers.py
└── perturbation
├── .ipynb_checkpoints
├── Image_Tutorial-checkpoint.ipynb
├── Tutorial - image-checkpoint.ipynb
├── Tutorial - tabular-checkpoint.ipynb
└── Tutorial - text-checkpoint.ipynb
├── __init__.py
├── cle
├── __init__.py
├── cle_image.py
├── cle_tabular.py
└── cle_text.py
├── exceptions.py
├── explainer.py
├── image_explainers
├── __init__.py
├── anchor_image.py
├── cle_image.py
├── lime_image.py
└── shap_image.py
├── tabular_explainers
├── __init__.py
├── anchor_tabular.py
├── cle_tabular.py
├── lime_tabular.py
└── shap_tabular.py
├── text_explainers
├── __init__.py
├── anchor_text.py
├── cle_text.py
├── lime_text.py
└── shap_text.py
├── xdeep_base.py
├── xdeep_image.py
├── xdeep_tabular.py
└── xdeep_text.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | tests/xlocal_perturbation/model/inception_v3.ckpt filter=lfs diff=lfs merge=lfs -text
2 | tutorial/xlocal_perturbation/model/inception_v3.ckpt filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | before_script:
4 | - pip install -e .[tests]
5 | - pip install python-coveralls
6 | - pip install pytest-cover
7 | script:
8 | - py.test tests/
9 | after_success:
10 | - coveralls
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Data Analytics Lab at Texas A&M University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # XDeep
2 | ## -- *for interpretable deep learning developers*
3 |
4 | [](https://www.codacy.com/manual/nacoyang/xdeep?utm_source=github.com&utm_medium=referral&utm_content=datamllab/xdeep&utm_campaign=Badge_Grade)
5 | [](https://travis-ci.com/datamllab/xdeep)
6 |
7 | XDeep is an open source Python library for **Interpretable Machine Learning**. It is developed by [DATA Lab](http://faculty.cs.tamu.edu/xiahu/index.html) at Texas A&M University. The goal of XDeep is to provide easily accessible interpretation tools for people who want to figure out how deep models work. XDeep provides a variety of methods to interpret a model both locally and globally.
8 |
9 | ## Installation
10 |
11 | To install the package, please use the pip installation as follows (https://pypi.org/project/x-deep/):
12 |
13 | pip install x-deep
14 |
15 | **Note**: currently, XDeep is only compatible with **Python 3.6**.
16 |
17 | ## Example
18 |
19 | Here is a short example of using the package.
20 |
21 | ```python
22 | import torchvision.models as models
23 | from xdeep.xlocal.gradient.explainers import *
24 |
25 | # load the input image & target deep model
26 | image = load_image('input.jpg')
27 | model = models.vgg16(pretrained=True)
28 |
29 | # build the xdeep explainer
30 | model_explainer = ImageInterpreter(model)
31 |
32 | # generate the local interpretation
33 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True)
34 | ```
35 |
36 | For a detailed tutorial, please check the docs directory of this repository [here](https://github.com/datamllab/xdeep/tree/master/docs).
37 |
38 | ## Sample Results
39 |
40 |
41 |
42 | ## Cite this work
43 |
44 | Fan Yang, Zijian Zhang, Haofan Wang, Yuening Li, and Xia Hu. "XDeep: An Interpretation Tool for Deep Neural Networks." arXiv:1911.01005, 2019. ([Download](https://arxiv.org/abs/1911.01005))
45 |
46 | Biblatex entry:
47 |
48 | @article{yang2019xdeep,
49 | title={XDeep: An Interpretation Tool for Deep Neural Networks},
50 | author={Yang, Fan and Zhang, Zijian and Wang, Haofan and Li, Yuening and Hu, Xia},
51 | journal={arXiv preprint arXiv:1911.01005},
52 | year={2019}
53 | }
54 |
55 | ## DISCLAIMER
56 |
57 | Please note that this is a **pre-release** version of XDeep, which is still undergoing final testing before its official release. The website, its software, and all contents found on it are provided on an “as is” and “as available” basis.
58 | XDeep does **not** give any warranties, whether express or implied, as to the suitability or usability of the website, its software, or any of its content. XDeep will **not** be liable for any loss, whether such loss is direct, indirect, special, or consequential, suffered by any party as a result of their use of the libraries or content. Any usage of the libraries is done at the user’s own risk, and the user will be solely responsible for any damage to any computer system or loss of data that results from such activities.
59 | Should you encounter any bugs, glitches, lack of functionality, or other problems on the website, please let us know immediately so we can rectify these accordingly.
60 | Your help in this regard is greatly appreciated.
61 |
62 | ## Contact
63 |
64 | Please send an email to *nacoyang@tamu.edu* if you have any issues with this project. Thanks.
65 |
66 | ## Acknowledgements
67 |
68 | The authors gratefully acknowledge the XAI program of the Defense Advanced Research Projects Agency (DARPA) administered through grant N66001-17-2-4031; the Texas A&M College of Engineering, and Texas A&M.
69 |
--------------------------------------------------------------------------------
/docs/about.md:
--------------------------------------------------------------------------------
1 | # About
2 |
3 | This package is developed by [DATA LAB](http://faculty.cs.tamu.edu/xiahu/) at Texas A&M University.
4 |
5 | ## Main Contributors
6 |
7 | [**Fan Yang**](./):
8 | Overall package design.
9 |
10 | [**Zijian Zhang**](./):
11 | Developer for perturbation-based local explanation part.
12 |
13 | [**Haofan Wang**](./):
14 | Developer for gradient-based local explanation and global explanation part.
15 |
16 | [**Yuening Li**](./):
17 | Code review and testing.
18 |
19 | [**Xia "Ben" Hu**](./):
20 | Advisor.
21 |
--------------------------------------------------------------------------------
/docs/global_explainers.md:
--------------------------------------------------------------------------------
1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xglobal/global_explainers.py#L11)
2 | ## GlobalImageInterpreter class
3 |
4 | ```python
5 | xdeep.xglobal.global_explainers.GlobalImageInterpreter(model)
6 | ```
7 |
8 | Class for global image explanations.
9 |
10 | GlobalImageInterpreter provides a unified interface for calling different visualization methods.
11 | It can be used in just a few lines of code and exposes '__init__()' and 'explain()'. Users
12 | can specify the name of the visualization method and the target layer. If parameters are
13 | not specified, default values are used.
14 |
15 |
16 | ---
17 | ## GlobalImageInterpreter methods
18 |
19 | ### explain
20 |
21 |
22 | ```python
23 | explain(method_name=None, target_layer=None, target_filter=None, input_=None, num_iter=10, save_path=None)
24 | ```
25 |
26 |
27 | Function to call different visualization methods.
28 |
29 | __Arguments__
30 |
31 | - __method_name__: str. The name of the interpretation method. Currently supported global explanation methods are 'filter',
32 |               'layer', 'logit', 'deepdream', and 'inverted'.
33 | - __target_layer__: torch.nn.Linear or torch.nn.Conv2d. The target layer.
34 | - __target_filter__: int or list. Index of the filter or filters.
35 | - __input___: str or Tensor. Path of the input image or a normalized tensor. Defaults to None.
36 | - __num_iter__: int. Number of iterations. Defaults to 10.
37 | - __save_path__: str. Path to save generated explanations. Defaults to None.
38 |
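For example, the following minimal sketch (adapted from the global tutorial in docs/tutorial_global.md) visualizes a single filter of a pretrained VGG16; the layer name, filter index, and save path are illustrative:

```python
import torchvision.models as models
from xdeep.xglobal.global_explainers import GlobalImageInterpreter

# wrap a pretrained model with the global interpreter
model = models.vgg16(pretrained=True)
explainer = GlobalImageInterpreter(model)

# maximize the activation of filter 20 in layer 'features_24' and save the result
explainer.explain(method_name='filter', target_layer='features_24',
                  target_filter=20, num_iter=10, save_path='results/filter.jpg')
```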
39 | ---
40 | ### inverted_feature
41 |
42 |
43 | ```python
44 | inverted_feature(model)
45 | ```
46 |
47 | ---
48 | ### maximize_activation
49 |
50 |
51 | ```python
52 | maximize_activation(model)
53 | ```
54 |
--------------------------------------------------------------------------------
/docs/gradient_backprop.md:
--------------------------------------------------------------------------------
1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/backprop/base.py#L5)
2 | ## BaseProp class
3 |
4 | ```python
5 | xdeep.xlocal.gradient.backprop.base.BaseProp(model)
6 | ```
7 |
8 |
9 | Base class for backpropagation.
10 |
11 | ----
12 |
13 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/backprop/guided_backprop.py#L4)
14 | ## Backprop class
15 |
16 | ```python
17 | xdeep.xlocal.gradient.backprop.guided_backprop.Backprop(model, guided=False)
18 | ```
19 |
20 | Generates vanilla or guided backprop gradients of a target class output w.r.t. an input image.
21 |
22 | __Arguments:__
23 |
24 | - __model__: torchvision.models. A pretrained model.
25 | - __guided__: bool. If True, perform guided backpropagation. Defaults to False.
26 |
27 | __Return:__
28 |
29 | Backprop Class.
30 |
31 |
32 | ---
33 | ## Backprop methods
34 |
35 | ### calculate_gradients
36 |
37 |
38 | ```python
39 | calculate_gradients(input_, target_class=None, take_max=False, use_gpu=False)
40 | ```
41 |
42 |
43 | Calculate gradient.
44 |
45 | __Arguments__
46 |
47 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W).
48 | - __target_class__: int. Index of the target class. Defaults to None, in which case the predicted class is used.
49 | - __take_max__: bool. Take the maximum across colour channels. Defaults to False.
50 | - __use_gpu__: bool. Whether to use the GPU. Defaults to False.
51 |
52 |
53 | __Return:__
54 |
55 | Gradient (torch.Tensor) with shape (C, H, W). If take_max is True, the shape is (1, H, W).
56 |
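A minimal usage sketch, mirroring tests/xlocal_gradient/test_guided_backprop.py (the image and save paths are illustrative):

```python
import torchvision.models as models
from xdeep.utils import *  # load_image, apply_transforms, visualize
from xdeep.xlocal.gradient.backprop.guided_backprop import Backprop

# preprocess an image, run guided backprop, and visualize the resulting gradients
norm_image = apply_transforms(load_image('images/ILSVRC2012_val_00000073.JPEG'))
model = models.vgg16(pretrained=True)
guided_bp = Backprop(model, guided=True)
gradients = guided_bp.calculate_gradients(norm_image, target_class=None, take_max=False)
visualize(norm_image, gradients, save_path='results/GuidedBP.JPEG')
```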
57 | ----
58 |
59 | ### generate_integrated_grad
60 |
61 |
62 | ```python
63 | xdeep.xlocal.gradient.backprop.integrate_grad.generate_integrated_grad(explainer, input_, target_class=None, n=25)
64 | ```
65 |
66 |
67 | Generates integrated gradients for the given explainer. This can be used with both vanilla
68 | and guided backprop.
69 |
70 | __Arguments:__
71 |
72 | - __explainer__: class. Backprop method.
73 | - __input___: torch.Tensor. Preprocessed image with shape (N, C, H, W).
74 | - __target_class__: int. Index of target class. Default to None.
75 | - __n__: int. Number of integration steps. Defaults to 25.
76 |
77 | __Return:__
78 |
79 | Integrated gradient (torch.Tensor) with shape (C, H, W).
80 |
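A minimal sketch of how this can be called, reusing a Backprop explainer (the image path and parameters are illustrative):

```python
import torchvision.models as models
from xdeep.utils import *  # load_image, apply_transforms
from xdeep.xlocal.gradient.backprop.guided_backprop import Backprop
from xdeep.xlocal.gradient.backprop.integrate_grad import generate_integrated_grad

# accumulate gradients over 25 interpolation steps (integrated gradients)
norm_image = apply_transforms(load_image('images/jay.jpg'))
explainer = Backprop(models.vgg16(pretrained=True), guided=True)
integrated = generate_integrated_grad(explainer, norm_image, target_class=None, n=25)
```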
81 | ----
82 |
83 | ### generate_smooth_grad
84 |
85 |
86 | ```python
87 | xdeep.xlocal.gradient.backprop.smooth_grad.generate_smooth_grad(explainer, input_, target_class=None, n=50, mean=0, sigma_multiplier=4)
88 | ```
89 |
90 |
91 | Generates smooth gradients for the given explainer.
92 |
93 | __Arguments__
94 |
95 | - __explainer__: class. Backprop method.
96 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W).
97 | - __target_class__: int. Index of target class. Defaults to None.
98 | - __n__: int. Number of noisy images used to smooth the gradient. Defaults to 50.
99 | - __mean__: int. Mean of the normal distribution used to generate noise. Defaults to 0.
100 | - __sigma_multiplier__: int. Sigma multiplier used when calculating the std of the noise. Defaults to 4.
101 |
102 |
103 | __Return:__
104 |
105 | Smooth gradient (torch.Tensor) with shape (C, H, W).
106 |
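A corresponding sketch for smooth gradients (again, the image path is illustrative):

```python
import torchvision.models as models
from xdeep.utils import *  # load_image, apply_transforms
from xdeep.xlocal.gradient.backprop.guided_backprop import Backprop
from xdeep.xlocal.gradient.backprop.smooth_grad import generate_smooth_grad

# average gradients over 50 noisy copies of the input to reduce visual noise
norm_image = apply_transforms(load_image('images/jay.jpg'))
explainer = Backprop(models.vgg16(pretrained=True))
smoothed = generate_smooth_grad(explainer, norm_image, n=50, sigma_multiplier=4)
```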
--------------------------------------------------------------------------------
/docs/gradient_cam.md:
--------------------------------------------------------------------------------
1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/basecam.py#L3)
2 | ## BaseCAM class
3 |
4 | ```python
5 | xdeep.xlocal.gradient.cam.basecam.BaseCAM(model_dict)
6 | ```
7 |
8 |
9 | Base class for Class Activation Mapping.
10 |
11 |
12 | ---
13 | ## BaseCAM methods
14 |
15 | ### find_layer
16 |
17 |
18 | ```python
19 | find_layer(arch, target_layer_name)
20 | ```
21 |
22 | ---
23 | ### forward
24 |
25 |
26 | ```python
27 | forward(input_, class_idx=None, retain_graph=False)
28 | ```
29 |
30 | ----
31 |
32 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/gradcam.py#L5)
33 | ## GradCAM class
34 |
35 | ```python
36 | xdeep.xlocal.gradient.cam.gradcam.GradCAM(model_dict)
37 | ```
38 |
39 |
40 | GradCAM, inherits from BaseCAM.
41 |
42 |
43 | ---
44 | ## GradCAM methods
45 |
46 | ### find_layer
47 |
48 |
49 | ```python
50 | find_layer(arch, target_layer_name)
51 | ```
52 |
53 | ---
54 | ### forward
55 |
56 |
57 | ```python
58 | forward(input_, class_idx=None, retain_graph=False)
59 | ```
60 |
61 |
62 | Generates GradCAM result.
63 |
64 | __Arguments__
65 |
66 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W).
67 | - __class_idx__: int. Index of the target class. Defaults to the index of the predicted class.
68 |
69 | __Return__
70 |
71 | Result of GradCAM (torch.Tensor) with shape (1, H, W).
72 |
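A minimal sketch following tests/xlocal_gradient/test_gradcam.py (the image and save paths are illustrative):

```python
import torchvision.models as models
from xdeep.utils import *  # load_image, apply_transforms, visualize
from xdeep.xlocal.gradient.cam.gradcam import GradCAM

# hook a late feature layer of VGG16 ('features_29') and run Grad-CAM
norm_image = apply_transforms(load_image('images/ILSVRC2012_val_00000073.JPEG'))
model_dict = dict(arch=models.vgg16(pretrained=True), layer_name='features_29',
                  input_size=(224, 224))
gradcam = GradCAM(model_dict)
cam_map = gradcam(norm_image)  # invokes forward() with the default (predicted) class
visualize(norm_image, cam_map, save_path='results/gradcam.JPEG')
```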
73 | ----
74 |
75 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/gradcampp.py#L6)
76 | ## GradCAMpp class
77 |
78 | ```python
79 | xdeep.xlocal.gradient.cam.gradcampp.GradCAMpp(model_dict)
80 | ```
81 |
82 |
83 | GradCAM++, inherits from BaseCAM.
84 |
85 |
86 | ---
87 | ## GradCAMpp methods
88 |
89 | ### find_layer
90 |
91 |
92 | ```python
93 | find_layer(arch, target_layer_name)
94 | ```
95 |
96 | ---
97 | ### forward
98 |
99 |
100 | ```python
101 | forward(input, class_idx=None, retain_graph=False)
102 | ```
103 |
104 |
105 | Generates GradCAM++ result.
106 |
107 | __Arguments__
108 |
109 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W).
110 | - __class_idx__: int. Index of the target class. Defaults to the index of the predicted class.
111 |
112 | __Return__
113 |
114 | Result of GradCAM++ (torch.Tensor) with shape (1, H, W).
115 |
116 | ----
117 |
118 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/cam/scorecam.py#L6)
119 | ## ScoreCAM class
120 |
121 | ```python
122 | xdeep.xlocal.gradient.cam.scorecam.ScoreCAM(model_dict)
123 | ```
124 |
125 |
126 | ScoreCAM, inherits from BaseCAM.
127 |
128 |
129 | ---
130 | ## ScoreCAM methods
131 |
132 | ### find_layer
133 |
134 |
135 | ```python
136 | find_layer(arch, target_layer_name)
137 | ```
138 |
139 | ---
140 | ### forward
141 |
142 |
143 | ```python
144 | forward(input, class_idx=None, retain_graph=False)
145 | ```
146 |
147 |
148 | Generates ScoreCAM result.
149 |
150 | __Arguments__
151 |
152 | - __input___: torch.Tensor. Preprocessed image with shape (1, C, H, W).
153 | - __class_idx__: int. Index of the target class. Defaults to the index of the predicted class.
154 |
155 | __Return__
156 |
157 | Result of ScoreCAM (torch.Tensor) with shape (1, H, W).
158 |
--------------------------------------------------------------------------------
/docs/gradient_explainer.md:
--------------------------------------------------------------------------------
1 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/gradient/explainers.py#L18)
2 | ## ImageInterpreter class
3 |
4 | ```python
5 | xdeep.xlocal.gradient.explainers.ImageInterpreter(model)
6 | ```
7 |
8 | Class for image explanation.
9 |
10 | ImageInterpreter provides a unified interface for calling different visualization methods.
11 | It can be used in just a few lines of code and exposes '__init__()' and 'explain()'. Users
12 | can specify the name of the visualization method and the target layer. If parameters are
13 | not specified, default values are used.
14 |
15 |
16 | ---
17 | ## ImageInterpreter methods
18 |
19 | ### explain
20 |
21 |
22 | ```python
23 | explain(image, method_name, viz=True, target_layer_name=None, target_class=None, save_path=None)
24 | ```
25 |
26 |
27 | Function to call different local visualization methods.
28 |
29 | __Arguments__
30 |
31 | - __image__: str or PIL.Image. The input image to ImageInterpreter. Users can provide either the path of the input
32 |            image or a PIL.Image object. For example, image can be './test.jpg' or
33 |            Image.open('./test.jpg').convert('RGB').
34 | - __method_name__: str. The name of the interpretation method. Currently supported methods are 'vallina_backprop',
35 |            'guided_backprop', 'smooth_grad', 'smooth_guided_grad', 'integrate_grad', 'integrate_guided_grad',
36 |            'gradcam', 'gradcampp', and 'scorecam'.
37 | - __viz__: bool. Whether to visualize the result. Defaults to True.
38 | - __target_layer_name__: str. The layer from which gradients and activation maps are hooked. Defaults to the last
39 |            activation map. Users can also provide a target layer such as 'features_29' or 'layer4',
40 |            depending on the network architecture.
41 | - __target_class__: int. The index of the target class. Defaults to the index of the predicted class.
42 | - __save_path__: str. Path to save the saliency map. Defaults to None.
43 |
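For instance, mirroring the README example (the file paths are illustrative):

```python
import torchvision.models as models
from xdeep.xlocal.gradient.explainers import *

# build the interpreter once, then request different saliency methods by name
image = load_image('input.jpg')
model_explainer = ImageInterpreter(models.vgg16(pretrained=True))
model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29',
                        viz=True, save_path='results/gradcam.jpg')
```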
44 | ---
45 | ### gradcam
46 |
47 |
48 | ```python
49 | gradcam(model_dict)
50 | ```
51 |
52 | ---
53 | ### gradcampp
54 |
55 |
56 | ```python
57 | gradcampp(model_dict)
58 | ```
59 |
60 | ---
61 | ### guided_backprop
62 |
63 |
64 | ```python
65 | guided_backprop(model)
66 | ```
67 |
68 | ---
69 | ### integrate_grad
70 |
71 |
72 | ```python
73 | integrate_grad(model, input_, guided=False)
74 | ```
75 |
76 | ---
77 | ### integrate_guided_grad
78 |
79 |
80 | ```python
81 | integrate_guided_grad(model, input_, guided=True)
82 | ```
83 |
84 | ---
85 | ### scorecam
86 |
87 |
88 | ```python
89 | scorecam(model_dict)
90 | ```
91 |
92 | ---
93 | ### smooth_grad
94 |
95 |
96 | ```python
97 | smooth_grad(model, input_, guided=False)
98 | ```
99 |
100 | ---
101 | ### smooth_guided_grad
102 |
103 |
104 | ```python
105 | smooth_guided_grad(model, input_, guided=True)
106 | ```
107 |
108 | ---
109 | ### vallina_backprop
110 |
111 |
112 | ```python
113 | vallina_backprop(model)
114 | ```
115 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Welcome to XDeep
2 |
3 | XDeep is an open source software library for Interpretable Machine Learning. It is developed by [DATA Lab](http://faculty.cs.tamu.edu/xiahu/index.html) at Texas A&M University. The ultimate goal of XDeep is to provide easily accessible interpretation tools to people who want to figure out why a model makes the predictions it does. XDeep provides a variety of methods to interpret a model locally and globally.
4 |
5 | ## Installation
6 |
7 | To install the package, please use the pip installation as follows:
8 |
9 | pip install x-deep
10 |
11 | **Note**: currently, XDeep is only compatible with **Python 3.6**.
12 |
13 | ## Interpretation Methods
14 |
15 | * Local
16 | - Gradient
17 | 1. CAM\-Based
18 | - Grad-CAM
19 | - Grad-CAM++
20 | - Score-CAM
21 | 2. Gradient\-Based
22 | - Vanilla Backpropagation
23 | - Guided Backpropagation
24 | - Smooth Gradient
25 | - Integrate Gradient
26 |
27 | - Perturbation
28 | 1. LIME
29 | 2. Anchors
30 | 3. SHAP
31 | 4. CLE
32 |
33 | * Global
34 | - Maximum Activation(filter, layer, logit, deepdream)
35 | - Inverted Features
36 |
37 | ## Example
38 |
39 | Here is a short example of using the package.
40 |
41 | import sklearn.datasets, sklearn.ensemble, sklearn.model_selection
42 | from xdeep.xlocal.perturbation.xdeep_tabular import TabularExplainer
43 |
44 | iris = sklearn.datasets.load_iris()
45 | train, test, labels_train, labels_test = sklearn.model_selection.train_test_split(
46 |     iris.data, iris.target, train_size=0.80)
47 | rf = sklearn.ensemble.RandomForestClassifier(n_estimators=500).fit(train, labels_train)
48 |
49 | explainer = TabularExplainer(rf.predict_proba, train,
50 |     feature_names=iris.feature_names, class_names=iris.target_names)
51 | explainer.set_lime_parameters(discretize_continuous=True)
52 | explainer.explain('lime', test[0])
53 | explainer.show_explanation('lime')
54 |
55 | For detailed tutorials, please check [here](./tutorial_local) for local methods and [here](./tutorial_global) for global methods.
56 |
--------------------------------------------------------------------------------
/docs/tutorial_global.md:
--------------------------------------------------------------------------------
1 | # Tutorial
2 |
3 | ## Global \- Maximum Activation
4 |
5 | ### GlobalExplainer
6 |
7 | from xdeep.xglobal.global_explainers import *
8 | import torchvision.models as models
9 |
10 | # load model
11 | model = models.vgg16(pretrained=True)
12 |
13 | # load global image interpreter
14 | model_explainer = GlobalImageInterpreter(model)
15 |
16 | # generate global explanations via activation maximization
17 | model_explainer.explain(method_name='filter', target_layer='features_24', target_filter=20, num_iter=10, save_path='results/filter.jpg')
18 |
19 | model_explainer.explain(method_name='layer', target_layer='features_24', num_iter=10, save_path='results/layer.jpg')
20 |
21 | model_explainer.explain(method_name='logit', target_layer='features_24', target_filter=20, num_iter=10, save_path='results/logit.jpg')
22 |
23 | model_explainer.explain(method_name='deepdream', target_layer='features_24', target_filter=20, input_='images/jay.jpg', num_iter=10, save_path='results/deepdream.jpg')
24 |
25 | # generate global explanation via inverted representation
26 | model_explainer.explain(method_name='inverted', target_layer='features_24', input_='images/jay.jpg', num_iter=10)
27 |
--------------------------------------------------------------------------------
/docs/tutorial_local.md:
--------------------------------------------------------------------------------
1 | # Tutorial
2 |
3 | ## Local \- Gradient
4 |
5 | import torchvision.models as models
6 | from xdeep.xlocal.gradient.explainers import *
7 |
8 | # load image
9 | image = load_image('tutorial/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG')
10 |
11 | # load model
12 | model = models.vgg16(pretrained=True)
13 |
14 | # load embedded image interpreter
15 | model_explainer = ImageInterpreter(model)
16 |
17 | # generate saliency map of method specified by 'method_name'
18 | model_explainer.explain(image, method_name='vallina_backprop', viz=True, save_path='results/bp.jpg')
19 |
20 | model_explainer.explain(image, method_name='guided_backprop', viz=True, save_path='results/guided.jpg')
21 |
22 | model_explainer.explain(image, method_name='smooth_grad', viz=True, save_path='results/smooth_grad.jpg')
23 |
24 | model_explainer.explain(image, method_name='integrate_grad', viz=True, save_path='results/integrate.jpg')
25 |
26 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True, save_path='results/gradcam.jpg')
27 |
28 | model_explainer.explain(image, method_name='gradcampp', target_layer_name='features_29', viz=True, save_path='results/gradcampp.jpg')
29 |
30 | model_explainer.explain(image, method_name='scorecam', target_layer_name='features_29', viz=True, save_path='results/scorecam.jpg')
31 |
32 | ## Local \- Perturbation
33 |
34 | ### Text
35 |
36 | from xdeep.xlocal.perturbation import xdeep_text
37 |
38 | # Load data and model
39 |
40 | model = load_model()
41 | dataset = load_data()
42 |
43 | # Start explaining
44 |
45 | text = dataset.test[0]
46 |
47 | explainer = xdeep_text.TextExplainer(model.predict_proba, dataset.class_names)
48 |
49 | explainer.explain('lime', text)
50 | explainer.show_explanation('lime')
51 |
52 | explainer.explain('cle', text)
53 | explainer.show_explanation('cle')
54 |
55 | explainer.explain('anchor', text)
56 | explainer.show_explanation('anchor')
57 |
58 | # for SHAP please check the Jupyter Notebook tutorial on GitHub.
59 |
60 | ### Tabular
61 |
62 | from xdeep.xlocal.perturbation import xdeep_tabular
63 |
64 | # Load data and model
65 |
66 | model = load_model()
67 | dataset = load_data()
68 |
69 | explainer = xdeep_tabular.TabularExplainer(
70 | model.predict_proba, dataset.train, dataset.class_names, dataset.feature_names,
71 | categorical_features=dataset.categorical_features,
72 | categorical_names=dataset.categorical_names)
73 |
74 | explainer.set_parameters('lime', discretize_continuous=True)
75 | explainer.explain('lime', dataset.test[0])
76 | explainer.show_explanation('lime')
77 |
78 | explainer.set_parameters('cle', discretize_continuous=True)
79 | explainer.explain('cle', dataset.test[0])
80 | explainer.show_explanation('cle')
81 |
82 | explainer.explain('shap', dataset.test[0])
83 | explainer.show_explanation('shap')
84 |
85 | # for Anchor please check the Jupyter Notebook tutorial on GitHub.
86 |
87 | ### Image
88 |
89 | from xdeep.xlocal.perturbation import xdeep_image
90 |
91 | # Load data and model
92 |
93 | model = load_model()
94 | dataset = load_data()
95 |
96 | # Start explaining
97 |
98 | explainer = xdeep_image.ImageExplainer(model.predict_proba, dataset.class_names)
99 |
100 | # deprocess function (if the image range is [-1, 1], transform it back to [0, 1])
101 | def f(x):
102 | return x / 2 + 0.5
103 |
104 | image = dataset.test[0]
105 |
106 | explainer.explain('lime', image, top_labels=3)
107 |
108 | explainer.show_explanation('lime', deprocess=f, positive_only=False)
109 |
110 | explainer.explain('cle', image, top_labels=3)
111 |
112 | explainer.show_explanation('cle', deprocess=f, positive_only=False)
113 |
114 | explainer.explain('anchor', image)
115 |
116 | explainer.show_explanation('anchor')
117 |
118 | from skimage.segmentation import slic
119 | segments_slic = slic(image, n_segments=50, compactness=30, sigma=3)
120 | explainer.initialize_shap(n_segment=50, segment=segments_slic)
121 | explainer.explain('shap', image, nsamples=200)
122 | explainer.show_explanation('shap', deprocess=f)
--------------------------------------------------------------------------------
/docs/xdeep_text.md:
--------------------------------------------------------------------------------
1 | [[source]](.)
2 | ## TextExplainer class
3 |
4 | ```python
5 | xdeep.xlocal.perturbation.xdeep_text.TextExplainer(predict_proba, class_names)
6 | ```
7 |
8 | Integrated explainer which explains text classifiers.
9 |
10 | ---
11 | ## TextExplainer methods
12 |
13 | ### explain
14 |
15 |
16 | ```python
17 | explain(method, instance, top_labels=None, labels=(1,))
18 | ```
19 |
20 |
21 | Generate explanation for a prediction using certain method.
22 |
23 | __Arguments__
24 |
25 | - __method__: Str. The method you want to use.
26 | - __instance__: Instance to be explained.
27 | - __top_labels__: Integer. Number of labels you care about.
28 | - __labels__: Tuple. Labels you care about. If top_labels is not None, this will be replaced by the predicted top labels.
29 | - __**kwargs__: Parameters setter. For more detail, please check 'explain_instance' in corresponding method.
30 |
31 | ---
32 | ### get_explanation
33 |
34 |
35 | ```python
36 | get_explanation(method)
37 | ```
38 |
39 |
40 | Explanation getter.
41 |
42 | __Arguments__
43 |
44 | - __method__: Str. The method you want to use.
45 |
46 | __Return__
47 |
48 | An Explanation object with the corresponding method.
49 |
50 | ---
51 | ### initialize_shap
52 |
53 |
54 | ```python
55 | initialize_shap(predict_vectorized, vectorizer, train)
56 | ```
57 |
58 |
59 | Explicit shap initializer.
60 |
61 | __Arguments__
62 |
63 | - __predict_vectorized__: Function. Classifier prediction-probability function that takes vectors as input.
64 | - __vectorizer__: Vectorizer. A vectorizer with a 'transform' function that converts a list of strings into vectors.
65 | - __train__: Array. Training data; in this case, a list of strings.
66 |
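A minimal sketch of wiring SHAP into the text explainer, assuming a scikit-learn CountVectorizer pipeline (the tiny corpus, classifier, and class names below are purely illustrative):

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from xdeep.xlocal.perturbation.xdeep_text import TextExplainer

# toy corpus: vectorize, fit a classifier, then wire SHAP into the text explainer
train_texts = ["a gripping and well acted film", "a dull , lifeless mess",
               "smart and funny", "tedious from start to finish"]
labels = [1, 0, 1, 0]
vectorizer = CountVectorizer()
X = vectorizer.fit_transform(train_texts)
clf = LogisticRegression().fit(X, labels)

def predict_texts(texts):
    # prediction-probability function on raw strings, used by the other explainers
    return clf.predict_proba(vectorizer.transform(texts))

explainer = TextExplainer(predict_texts, ['neg', 'pos'])
# SHAP needs a function that accepts vectors, plus the vectorizer and training texts
explainer.initialize_shap(clf.predict_proba, vectorizer, train_texts)
```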
67 | ---
68 | ### set_parameters
69 |
70 |
71 | ```python
72 | set_parameters(method)
73 | ```
74 |
75 |
76 | Parameter setter.
77 |
78 | __Arguments__
79 |
80 | - __method__: Str. The method you want to use.
81 | - __**kwargs__: Other parameters, which depend on the method you use. For more detail, please check the xdeep documentation.
82 |
83 | ---
84 | ### show_explanation
85 |
86 |
87 | ```python
88 | show_explanation(method)
89 | ```
90 |
91 |
92 | Visualization of explanation of the corresponding method.
93 |
94 | __Arguments__
95 |
96 | - __method__: Str. The method you want to use.
97 | - __**kwargs__: parameters setter. For more detail, please check xdeep documentation.
98 |
99 | ----
100 |
101 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/perturbation/text_explainers/lime_text.py#L8)
102 | ## XDeepLimeTextExplainer class
103 |
104 | ```python
105 | xdeep.xlocal.perturbation.text_explainers.lime_text.XDeepLimeTextExplainer(predict_proba, class_names)
106 | ```
107 |
108 |
109 | ---
110 | ## XDeepLimeTextExplainer methods
111 |
112 | ### explain
113 |
114 |
115 | ```python
116 | explain(instance, top_labels=None, labels=(1,))
117 | ```
118 |
119 |
120 | Generate explanation for a prediction using certain method.
121 |
122 | __Arguments__
123 |
124 | - __instance__: Str. A raw text string to be explained.
125 | - __top_labels__: Integer. Number of labels you care about.
126 | - __labels__: Tuple. Labels you care about. If top_labels is not None, this will be replaced by the predicted top labels.
127 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
128 |
129 | ---
130 | ### set_parameters
131 |
132 |
133 | ```python
134 | set_parameters()
135 | ```
136 |
137 |
138 | Parameter setter for lime_text.
139 |
140 | __Arguments__
141 |
142 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
143 |
144 | ---
145 | ### show_explanation
146 |
147 |
148 | ```python
149 | show_explanation(span=3)
150 | ```
151 |
152 |
153 | Visualization of explanation of lime_text.
154 |
155 | __Arguments__
156 |
157 | - __span__: Integer. Number of features shown per row.
158 |
159 | ----
160 |
161 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/perturbation/text_explainers/cle_text.py#L9)
162 | ## XDeepCLETextExplainer class
163 |
164 | ```python
165 | xdeep.xlocal.perturbation.text_explainers.cle_text.XDeepCLETextExplainer(predict_proba, class_names)
166 | ```
167 |
168 |
169 | ---
170 | ## XDeepCLETextExplainer methods
171 |
172 | ### explain
173 |
174 |
175 | ```python
176 | explain(instance, top_labels=None, labels=(1,), care_words=None, spans=(2,), include_original_feature=True)
177 | ```
178 |
179 |
180 | Generate explanation for a prediction using certain method.
181 |
182 | __Arguments__
183 |
184 | - __instance__: Str. A raw text string to be explained.
185 | - __top_labels__: Integer. Number of labels you care about.
186 | - __labels__: Tuple. Labels you care about. If top_labels is not None, this will be replaced by the predicted top labels.
187 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
188 |
189 | ---
190 | ### set_parameters
191 |
192 |
193 | ```python
194 | set_parameters()
195 | ```
196 |
197 |
198 | Parameter setter for cle_text.
199 |
200 | __Arguments__
201 |
202 | - __**kwargs__: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
203 |
204 | ---
205 | ### show_explanation
206 |
207 |
208 | ```python
209 | show_explanation(span=3, plot=True)
210 | ```
211 |
212 |
213 | Visualization of explanation of cle_text.
214 |
215 | __Arguments__
216 |
217 | - __span__: Integer. Number of features shown per row.
218 | - __plot__: Boolean. Whether to plot a figure.
219 |
220 | ----
221 |
222 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/perturbation/text_explainers/anchor_text.py#L9)
223 | ## XDeepAnchorTextExplainer class
224 |
225 | ```python
226 | xdeep.xlocal.perturbation.text_explainers.anchor_text.XDeepAnchorTextExplainer(predict_proba, class_names)
227 | ```
228 |
229 |
230 | ---
231 | ## XDeepAnchorTextExplainer methods
232 |
233 | ### explain
234 |
235 |
236 | ```python
237 | explain(instance, top_labels=None, labels=(1,))
238 | ```
239 |
240 |
241 | Generate explanation for a prediction using certain method.
242 | Anchor does not use top_labels and labels.
243 |
244 | __Arguments__
245 |
246 | - __instance__: Str. A raw text string to be explained.
247 | - __**kwargs__: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor.
248 |
249 | ---
250 | ### set_parameters
251 |
252 |
253 | ```python
254 | set_parameters(nlp=None)
255 | ```
256 |
257 |
258 | Parameter setter for anchor_text.
259 |
260 | __Arguments__
261 |
262 | - __nlp__: Object. A spacy model object.
263 | - __**kwargs__: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor
264 |
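For example, a spaCy English pipeline could be supplied as follows (the 'en_core_web_sm' model is just one common choice, and `anchor_explainer` stands for an XDeepAnchorTextExplainer instance):

```python
import spacy

# supply the spaCy model object that anchor_text requires
nlp = spacy.load('en_core_web_sm')
anchor_explainer.set_parameters(nlp=nlp)
```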
265 | ---
266 | ### show_explanation
267 |
268 |
269 | ```python
270 | show_explanation(show_in_note_book=True, verbose=True)
271 | ```
272 |
273 |
274 | Visualization of explanation of anchor_text.
275 |
276 | __Arguments__
277 |
278 | - __show_in_note_book__: Boolean. Whether to show the explanation in a Jupyter notebook.
279 | - __verbose__: Boolean. Whether to print out examples and counterexamples.
280 |
281 | ----
282 |
283 | [[source]](https://github.com/datamllab/xdeep/blob/master/xdeep/xlocal/perturbation/text_explainers/shap_text.py#L10)
284 | ## XDeepShapTextExplainer class
285 |
286 | ```python
287 | xdeep.xlocal.perturbation.text_explainers.shap_text.XDeepShapTextExplainer(predict_vectorized, class_names, vectorizer, train)
288 | ```
289 |
290 |
291 | ---
292 | ## XDeepShapTextExplainer methods
293 |
294 | ### explain
295 |
296 |
297 | ```python
298 | explain(instance, top_labels=None, labels=(1,))
299 | ```
300 |
301 |
302 | Generate explanation for a prediction using certain method.
303 |
304 | __Arguments__
305 |
306 | - __instance__: Str. A raw text string to be explained.
307 | - __top_labels__: Integer. Number of labels you care about.
308 | - __labels__: Tuple. Labels you care about. If top_labels is not None, this will be replaced by the predicted top labels.
309 | - __**kwargs__: Parameters setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
310 |
311 | ---
312 | ### set_parameters
313 |
314 |
315 | ```python
316 | set_parameters()
317 | ```
318 |
319 |
320 | Parameter setter for shap. The data passed to 'shap' cannot be strings; it must be vectors. As a result, you need to pass in the vectorizer.
321 |
322 | __Arguments__
323 |
324 | - __**kwargs__: Shap kernel explainer parameter setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
325 |
326 | ---
327 | ### show_explanation
328 |
329 |
330 | ```python
331 | show_explanation(show_in_note_book=True)
332 | ```
333 |
334 |
335 | Visualization of explanation of shap.
336 |
337 | __Arguments__
338 |
339 | - __show_in_note_book__: Boolean. Whether to show the explanation in a Jupyter notebook.
340 |
--------------------------------------------------------------------------------
/result_img/ensemble_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/result_img/ensemble_fig.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="x-deep",
8 | version="0.0.8",
9 | author="DATA LAB at Texas A&M University",
10 | author_email="hu@cse.tamu.edu",
11 | description="XDeep is an open-source package for Interpretable Machine Learning.",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/datamllab/xdeep",
15 | packages=setuptools.find_packages(),
16 | install_requires=[
17 | 'anchor-exp',
18 | 'shap',
19 | 'torch',
20 | 'torchvision',
21 | 'matplotlib',
22 | 'importlib_resources',
23 | ],
24 | extras_require={
25 | 'tests': ['IPython',
26 | 'tensorflow < 2.0.0',
27 | 'pytest',
28 | ],
29 | },
30 | classifiers=[
31 | "Programming Language :: Python :: 3.6",
32 | "License :: OSI Approved :: MIT License",
33 | "Operating System :: OS Independent",
34 | ],
35 | )
36 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/__init__.py
--------------------------------------------------------------------------------
/tests/xglobal/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/__init__.py
--------------------------------------------------------------------------------
/tests/xglobal/images/jay.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/images/jay.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/deepdream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/deepdream.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/filter.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/filter.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_0.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_1.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_2.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_3.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_4.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_5.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_6.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_7.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_8.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/inverted/Inverted_Iteration_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/inverted/Inverted_Iteration_9.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/layer.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/layer.jpg
--------------------------------------------------------------------------------
/tests/xglobal/results/logit.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xglobal/results/logit.jpg
--------------------------------------------------------------------------------
/tests/xglobal/test_global.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | from xdeep.xglobal.global_explainers import *
3 | from xdeep.xglobal.methods.inverted_representation import *
4 |
5 |
6 | def test_explainer():
7 |
8 | model = models.vgg16(pretrained=True)
9 | model_explainer = GlobalImageInterpreter(model)
10 |
11 | model_explainer.explain(method_name='filter', target_layer='features_24', target_filter=20, num_iter=10,
12 | save_path='tests/xglobal/results/filter.jpg')
13 | model_explainer.explain(method_name='layer', target_layer='features_24', num_iter=10, save_path='tests/xglobal/results/layer.jpg')
14 | model_explainer.explain(method_name='logit', target_layer='features_24', target_filter=20, num_iter=10,
15 | save_path='tests/xglobal/results/logit.jpg')
16 | model_explainer.explain(method_name='deepdream', target_layer='features_24', target_filter=20,
17 | input_='tests/xglobal/images/jay.jpg', num_iter=10, save_path='tests/xglobal/results/deepdream.jpg')
18 | model_explainer.explain(method_name='inverted', target_layer='features_24', input_='tests/xglobal/images/jay.jpg', num_iter=10)
19 |
20 |
21 | def test_filter():
22 | model = models.vgg16(pretrained=True)
23 | layer = model.features[24]
24 | filters = [45, 271, 363, 409]
25 | g_ascent = GradientAscent(model.features)
26 | g_ascent.visualize(layer, filters, num_iter=30, title='filter visualization', save_path='tests/xglobal/results/filter.jpg')
27 |
28 |
29 | def test_layer():
30 | model = models.vgg16(pretrained=True)
31 | layer = model.features[24]
32 | g_ascent = GradientAscent(model.features)
33 | g_ascent.visualize(layer, num_iter=100, title='layer visualization', save_path='tests/xglobal/results/layer.jpg')
34 |
35 |
36 | def test_deepdream():
37 | model = models.vgg16(pretrained=True)
38 | layer = model.features[24]
39 | g_ascent = GradientAscent(model.features)
40 | g_ascent.visualize(layer=layer, filter_idxs=33, input_='tests/xglobal/images/jay.jpg', title='deepdream',num_iter=50, save_path='tests/xglobal/results/deepdream.jpg')
41 |
42 |
43 | def test_logit():
44 | model = models.vgg16(pretrained=True)
45 | layer = model.classifier[-1]
46 | logit = 17
47 | g_ascent = GradientAscent(model)
48 | g_ascent.visualize(layer, logit, num_iter=30, title='logit visualization', save_path='tests/xglobal/results/logit.jpg')
49 |
50 |
51 | def test_inverted():
52 | image_path = 'tests/xglobal/images/jay.jpg'
53 | image = load_image(image_path)
54 | norm_image = apply_transforms(image)
55 |
56 | model = models.vgg16(pretrained=True)
57 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224))
58 | IR = InvertedRepresentation(model_dict)
59 |
60 | IR.visualize(norm_image)
61 |
--------------------------------------------------------------------------------
/tests/xlocal_gradient/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/__init__.py
--------------------------------------------------------------------------------
/tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_gradient/test_explainers.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | from xdeep.xlocal.gradient.explainers import *
3 |
4 | image = load_image('tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG')
5 | model = models.vgg16(pretrained=True)
6 | model_explainer = ImageInterpreter(model)
7 |
8 | def test_backprop():
9 | model_explainer.explain(image, method_name='vanilla_backprop', viz=True, save_path='tests/xlocal_gradient/results/bp.jpg')
10 | model_explainer.explain(image, method_name='guided_backprop', viz=True, save_path='tests/xlocal_gradient/results/guided.jpg')
11 |
12 | def test_grad():
13 | model_explainer.explain(image, method_name='smooth_grad', viz=True, save_path='tests/xlocal_gradient/results/smooth_grad.jpg')
14 | model_explainer.explain(image, method_name='integrate_grad', viz=True, save_path='tests/xlocal_gradient/results/integrate.jpg')
15 | model_explainer.explain(image, method_name='smooth_guided_grad', viz=True, save_path='tests/xlocal_gradient/results/smooth_guided_grad.jpg')
16 | model_explainer.explain(image, method_name='integrate_guided_grad', viz=True, save_path='tests/xlocal_gradient/results/integrate_guided.jpg')
17 |
18 | def test_cam():
19 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True, save_path='tests/xlocal_gradient/results/gradcam.jpg')
20 | model_explainer.explain(image, method_name='gradcampp', target_layer_name='features_29', viz=True, save_path='tests/xlocal_gradient/results/gradcampp.jpg')
21 | model_explainer.explain(image, method_name='scorecam', target_layer_name='features_29', viz=True, save_path='tests/xlocal_gradient/results/scorecam.jpg')
22 |
--------------------------------------------------------------------------------
/tests/xlocal_gradient/test_gradcam.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 |
3 | from xdeep.utils import *
4 | from xdeep.xlocal.gradient.cam.gradcam import *
5 |
6 |
7 | def test_gradcam():
8 |
9 | image_path = 'tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG'
10 | image_name = image_path.split('/')[-1]  # keep only the file name
11 | image = load_image(image_path)
12 | norm_image = apply_transforms(image)
13 |
14 | model = models.vgg16(pretrained=True)
15 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224))
16 |
17 | gradcam = GradCAM(model_dict)
18 | output = gradcam(norm_image)
19 | visualize(norm_image, output, save_path='tests/xlocal_gradient/results/gradcam_' + image_name)
20 |
--------------------------------------------------------------------------------
/tests/xlocal_gradient/test_guided_backprop.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 |
3 | from xdeep.utils import *
4 | from xdeep.xlocal.gradient.backprop.guided_backprop import *
5 |
6 |
7 | def test_bp():
8 |
9 | image_path = 'tests/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG'
10 | image_name = image_path.split('/')[-1]  # keep only the file name
11 | image = load_image(image_path)
12 | norm_image = apply_transforms(image)
13 |
14 | model = models.vgg16(pretrained=True)
15 | Guided_BP = Backprop(model, guided=True)
16 | output = Guided_BP.calculate_gradients(norm_image)
17 | visualize(norm_image, output, save_path='tests/xlocal_gradient/results/GuidedBP_' + image_name)
18 |
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/__init__.py
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/Eskimo dog, husky.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/Eskimo dog, husky.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/Great Pyrenees.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/Great Pyrenees.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/Ibizan hound.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/Ibizan hound.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/bee eater.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/bee eater.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/car mirror.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/car mirror.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/eggnog.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/eggnog.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/kitten.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/kitten.jpg
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/oystercatcher.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/oystercatcher.jpg
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata.README.1.0.txt:
--------------------------------------------------------------------------------
1 |
2 | =======
3 |
4 | Introduction
5 |
6 | This README v1.0 (June, 2005) for the v1.0 sentence polarity dataset comes
7 | from the URL
8 | http://www.cs.cornell.edu/people/pabo/movie-review-data .
9 |
10 | =======
11 |
12 | Citation Info
13 |
14 | This data was first used in Bo Pang and Lillian Lee,
15 | ``Seeing stars: Exploiting class relationships for sentiment categorization
16 | with respect to rating scales.'', Proceedings of the ACL, 2005.
17 |
18 | @InProceedings{Pang+Lee:05a,
19 | author = {Bo Pang and Lillian Lee},
20 | title = {Seeing stars: Exploiting class relationships for sentiment
21 | categorization with respect to rating scales},
22 | booktitle = {Proceedings of the ACL},
23 | year = 2005
24 | }
25 |
26 | =======
27 |
28 | Data Format Summary
29 |
30 | - rt-polaritydata.tar.gz: contains this readme and two data files that
31 | were used in the experiments described in Pang/Lee ACL 2005.
32 |
33 | Specifically:
34 | * rt-polarity.pos contains 5331 positive snippets
35 | * rt-polarity.neg contains 5331 negative snippets
36 |
37 | Each line in these two files corresponds to a single snippet (usually
38 | containing roughly one single sentence); all snippets are down-cased.
39 | The snippets were labeled automatically, as described below (see
40 | section "Label Decision").
41 |
42 | Note: The original source files from which the data in
43 | rt-polaritydata.tar.gz was derived can be found in the subjective
44 | part (Rotten Tomatoes pages) of subjectivity_html.tar.gz (released
45 | with subjectivity dataset v1.0).
46 |
47 |
48 | =======
49 |
50 | Label Decision
51 |
52 | We assumed snippets (from Rotten Tomatoes webpages) for reviews marked with
53 | ``fresh'' are positive, and those for reviews marked with ``rotten'' are
54 | negative.
55 |
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/sea_snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/sea_snake.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/sky.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/sky.jpg
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/tusker.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/tusker.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/data/violin.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tests/xlocal_perturbation/data/violin.JPEG
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/test_image.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models as models
3 | import numpy as np
4 | # from skimage.segmentation import slic
5 | import xdeep.xlocal.perturbation.xdeep_image as xdeep_image
6 | from xdeep.utils import *
7 |
8 | image = load_image('tests/xlocal_perturbation/data/violin.JPEG')
9 | image = np.asarray(image)
10 | model = models.vgg16(pretrained=True)
11 |
12 |
13 | def test_image_data():
14 | with torch.no_grad():
15 | names = create_readable_names_for_imagenet_labels()
16 |
17 | class_names = []
18 | for item in names:
19 | class_names.append(names[item])
20 |
21 | def predict_fn(x):
22 | return model(torch.from_numpy(x.reshape(-1, 3, 224, 224)).float()).numpy()
23 |
24 | explainer = xdeep_image.ImageExplainer(predict_fn, class_names)
25 |
26 | explainer.explain('lime', image, top_labels=1, num_samples=100)
27 | explainer.show_explanation('lime', positive_only=False)
28 |
29 | explainer.explain('cle', image, top_labels=1, num_samples=100)
30 | explainer.show_explanation('cle', positive_only=False)
31 |
32 | # explainer.explain('anchor', image, threshold=0.7, coverage_samples=5000)
33 | # explainer.show_explanation('anchor')
34 |
35 | # segments_slic = slic(image, n_segments=10, compactness=30, sigma=3)
36 | # explainer.initialize_shap(n_segment=10, segment=segments_slic)
37 | # explainer.explain('shap',image, nsamples=10)
38 | # explainer.show_explanation('shap')
39 |
40 |
41 | def create_readable_names_for_imagenet_labels():
42 | """Create a dict mapping label id to human readable string.
43 |
44 | Returns:
45 | labels_to_names: dictionary where keys are integers from 0 to 1000
46 | and values are human-readable names.
47 |
48 | We retrieve a synset file, which contains a list of valid synset labels used
49 | by the ILSVRC competition, one synset per line, e.g.
50 | # n01440764
51 | # n01443537
52 | We also retrieve a synset_to_human_file, which contains a mapping from synsets
53 | to human-readable names for every synset in Imagenet. These are stored in a
54 | tsv format, as follows:
55 | # n02119247 black fox
56 | # n02119359 silver fox
57 | We assign each synset (in alphabetical order) an integer, starting from 1
58 | (since 0 is reserved for the background class).
59 |
60 | Code is based on
61 | https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py#L463
62 | """
63 |
64 | filename = "tests/xlocal_perturbation/image_util/imagenet_lsvrc_2015_synsets.txt"
65 | synset_list = [s.strip() for s in open(filename).readlines()]
66 | num_synsets_in_ilsvrc = len(synset_list)
67 | if not num_synsets_in_ilsvrc == 1000:
68 | raise AssertionError()
69 |
70 | filename = "tests/xlocal_perturbation/image_util/imagenet_metadata.txt"
71 | synset_to_human_list = open(filename).readlines()
72 | num_synsets_in_all_imagenet = len(synset_to_human_list)
73 | if not num_synsets_in_all_imagenet == 21842:
74 | raise AssertionError()
75 |
76 | synset_to_human = {}
77 | for s in synset_to_human_list:
78 | parts = s.strip().split('\t')
79 | if not len(parts) == 2:
80 | raise AssertionError()
81 | synset = parts[0]
82 | human = parts[1]
83 | synset_to_human[synset] = human
84 |
85 | label_index = 1
86 | labels_to_names = {0: 'background'}
87 | for synset in synset_list:
88 | name = synset_to_human[synset]
89 | labels_to_names[label_index] = name
90 | label_index += 1
91 |
92 | return labels_to_names
93 |
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/test_tabular.py:
--------------------------------------------------------------------------------
1 | from sklearn.ensemble import RandomForestClassifier
2 | from anchor import utils
3 | import xdeep.xlocal.perturbation.xdeep_tabular as xdeep_tabular
4 |
5 | # Please download the dataset at
6 | # https://archive.ics.uci.edu/ml/datasets/adult
7 | # Then reset the path.
8 |
9 | def test_tabular_data():
10 | dataset_folder = 'tests/xlocal_perturbation/data/'
11 | dataset = utils.load_dataset('adult', balance=True, dataset_folder=dataset_folder)
12 |
13 | c = RandomForestClassifier(n_estimators=50, n_jobs=5)
14 | c.fit(dataset.train, dataset.labels_train)
15 |
16 | explainer = xdeep_tabular.TabularExplainer(c.predict_proba, ['<=50K', '>50K'], dataset.feature_names, dataset.train[0:50],
17 | categorical_features=dataset.categorical_features, categorical_names=dataset.categorical_names)
18 |
19 | explainer.set_parameters('lime', discretize_continuous=True)
20 | explainer.explain('lime', dataset.test[0])
21 | explainer.show_explanation('lime')
22 |
23 | explainer.set_parameters('cle', discretize_continuous=True)
24 | explainer.explain('cle', dataset.test[0])
25 | explainer.show_explanation('cle')
26 |
27 | explainer.explain('shap', dataset.test[0])
28 | explainer.show_explanation('shap')
29 |
30 | explainer.initialize_anchor(dataset.data)
31 | encoder = explainer.get_anchor_encoder(dataset.train[0:50], dataset.labels_train, dataset.validation, dataset.labels_validation)
32 | c_new = RandomForestClassifier(n_estimators=50, n_jobs=5)
33 | c_new.fit(encoder.transform(dataset.train), dataset.labels_train)
34 | explainer.set_anchor_predict_proba(c_new.predict_proba)
35 | explainer.explain('anchor', dataset.test[0])
36 | explainer.show_explanation('anchor')
37 |
38 |
--------------------------------------------------------------------------------
/tests/xlocal_perturbation/test_text.py:
--------------------------------------------------------------------------------
1 | import os
2 | import spacy
3 | import numpy as np
4 | from sklearn import model_selection
5 | from sklearn.linear_model import LogisticRegression
6 | from sklearn.pipeline import make_pipeline
7 | from sklearn.feature_extraction.text import CountVectorizer
8 | import xdeep.xlocal.perturbation.xdeep_text as xdeep_text
9 |
10 | # Please download dataset at
11 | # http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz
12 | # Then set the path accordingly.
13 |
14 | def load_polarity(path='tests/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata'):
15 | data = []
16 | labels = []
17 | f_names = ['rt-polarity.neg', 'rt-polarity.pos']
18 | for (l, f) in enumerate(f_names):
19 | for line in open(os.path.join(path, f), 'rb'):
20 | try:
21 | line.decode('utf8')
22 | except UnicodeDecodeError:
23 | continue
24 | data.append(line.strip())
25 | labels.append(l)
26 | return data, labels
27 |
28 | def test_text_data():
29 | data, labels = load_polarity()
30 | train, test, train_labels, test_labels = model_selection.train_test_split(data, labels, test_size=.2, random_state=42)
31 | train, val, train_labels, val_labels = model_selection.train_test_split(train, train_labels, test_size=.1, random_state=42)
32 | train_labels = np.array(train_labels)
33 | test_labels = np.array(test_labels)
34 | val_labels = np.array(val_labels)
35 |
36 | vectorizer = CountVectorizer(min_df=1)
37 | vectorizer.fit(train)
38 | train_vectors = vectorizer.transform(train)
39 | test_vectors = vectorizer.transform(test)
40 | val_vectors = vectorizer.transform(val)
41 |
42 | x = LogisticRegression()
43 | x.fit(train_vectors, train_labels)
44 | c = make_pipeline(vectorizer, x)
45 |
46 | explainer = xdeep_text.TextExplainer(c.predict_proba, ['negative', 'positive'])
47 |
48 | text = 'This is a good movie .'
49 |
50 | explainer.explain('lime', text)
51 | explainer.show_explanation('lime')
52 |
53 | explainer.explain('cle', text)
54 | explainer.show_explanation('cle')
55 |
56 | try:
57 | nlp = spacy.load('en_core_web_sm')
58 | explainer.explain('anchor', text)
59 | explainer.show_explanation('anchor')
60 | except OSError:
61 | pass
62 |
63 | explainer.initialize_shap(x.predict_proba, vectorizer, train[0:10])
64 | explainer.explain('shap', text)
65 | explainer.show_explanation('shap')
66 |
--------------------------------------------------------------------------------
/tutorial/notebook/tabular - tutorial-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 2
6 | }
7 |
--------------------------------------------------------------------------------
/tutorial/xglobal/images/jay.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/images/jay.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/deepdream.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/deepdream.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/filter.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/filter.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_0.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_1.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_2.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_3.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_4.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_5.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_6.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_7.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_8.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/inverted/Inverted_Iteration_9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/inverted/Inverted_Iteration_9.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/layer.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/layer.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/results/logit.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xglobal/results/logit.jpg
--------------------------------------------------------------------------------
/tutorial/xglobal/test_global.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | from xdeep.xglobal.global_explainers import *
3 | from xdeep.xglobal.methods.inverted_representation import *
4 |
5 |
6 | def test_explainer():
7 |
8 | model = models.vgg16(pretrained=True)
9 | model_explainer = GlobalImageInterpreter(model)
10 |
11 | model_explainer.explain(method_name='filter', target_layer='features_24', target_filter=20, num_iter=10,
12 | save_path='results/filter.jpg')
13 | model_explainer.explain(method_name='layer', target_layer='features_24', num_iter=10, save_path='results/layer.jpg')
14 | model_explainer.explain(method_name='logit', target_layer='features_24', target_filter=20, num_iter=10,
15 | save_path='results/logit.jpg')
16 | model_explainer.explain(method_name='deepdream', target_layer='features_24', target_filter=20,
17 | input_='images/jay.jpg', num_iter=10, save_path='results/deepdream.jpg')
18 | model_explainer.explain(method_name='inverted', target_layer='features_24', input_='images/jay.jpg', num_iter=10)
19 |
20 |
21 | def test_filter():
22 | model = models.vgg16(pretrained=True)
23 | layer = model.features[24]
24 | filters = [45, 271, 363, 409]
25 | g_ascent = GradientAscent(model.features)
26 | g_ascent.visualize(layer, filters, num_iter=30, title='filter visualization', save_path='results/filter.jpg')
27 |
28 |
29 | def test_layer():
30 | model = models.vgg16(pretrained=True)
31 | layer = model.features[24]
32 | g_ascent = GradientAscent(model.features)
33 | g_ascent.visualize(layer, num_iter=100, title='layer visualization', save_path='results/layer.jpg')
34 |
35 |
36 | def test_deepdream():
37 | model = models.vgg16(pretrained=True)
38 | layer = model.features[24]
39 | g_ascent = GradientAscent(model.features)
40 | g_ascent.visualize(layer=layer, filter_idxs=33, input_='images/jay.jpg', title='deepdream',num_iter=50, save_path='results/deepdream.jpg')
41 |
42 |
43 | def test_logit():
44 | model = models.vgg16(pretrained=True)
45 | layer = model.classifier[-1]
46 | logit = 17
47 | g_ascent = GradientAscent(model)
48 | g_ascent.visualize(layer, logit, num_iter=30, title='logit visualization', save_path='results/logit.jpg')
49 |
50 |
51 | def test_inverted():
52 | image_path = 'images/jay.jpg'
53 | image = load_image(image_path)
54 | norm_image = apply_transforms(image)
55 |
56 | model = models.vgg16(pretrained=True)
57 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224))
58 | IR = InvertedRepresentation(model_dict)
59 |
60 | IR.visualize(norm_image)
61 |
--------------------------------------------------------------------------------
/tutorial/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_gradient/images/ILSVRC2012_val_00000073.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_gradient/results/GuidedBP_ILSVRC2012_val_00000073.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_gradient/results/gradcam_ILSVRC2012_val_00000073.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_gradient/test_explainers.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 | from xdeep.xlocal.gradient.explainers import *
3 |
4 |
5 | def test_explainer():
6 |
7 | image = load_image('images/ILSVRC2012_val_00000073.JPEG')
8 | model = models.vgg16(pretrained=True)
9 | model_explainer = ImageInterpreter(model)
10 |
11 | # generate explanations
12 | model_explainer.explain(image, method_name='vanilla_backprop', viz=True, save_path='results/bp.jpg')
13 | model_explainer.explain(image, method_name='guided_backprop', viz=True, save_path='results/guided.jpg')
14 | model_explainer.explain(image, method_name='smooth_grad', viz=True, save_path='results/smooth_grad.jpg')
15 | model_explainer.explain(image, method_name='integrate_grad', viz=True, save_path='results/integrate.jpg')
16 | model_explainer.explain(image, method_name='smooth_guided_grad', viz=True, save_path='results/smooth_guided_grad.jpg')
17 | model_explainer.explain(image, method_name='integrate_guided_grad', viz=True, save_path='results/integrate_guided.jpg')
18 | model_explainer.explain(image, method_name='gradcam', target_layer_name='features_29', viz=True, save_path='results/gradcam.jpg')
19 | model_explainer.explain(image, method_name='gradcampp', target_layer_name='features_29', viz=True, save_path='results/gradcampp.jpg')
20 | model_explainer.explain(image, method_name='scorecam', target_layer_name='features_29', viz=True, save_path='results/scorecam.jpg')
21 |
--------------------------------------------------------------------------------
/tutorial/xlocal_gradient/test_gradcam.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 |
3 | from xdeep.utils import *
4 | from xdeep.xlocal.gradient.cam.gradcam import *
5 |
6 |
7 | def test_gradcam():
8 |
9 | image_path = 'images/ILSVRC2012_val_00000073.JPEG'
10 | image_name = image_path.split('/')[1]
11 | image = load_image(image_path)
12 | norm_image = apply_transforms(image)
13 |
14 | model = models.vgg16(pretrained=True)
15 | model_dict = dict(arch=model, layer_name='features_29', input_size=(224, 224))
16 |
17 | gradcam = GradCAM(model_dict)
18 | output = gradcam(norm_image)
19 | visualize(norm_image, output, save_path='results/gradcam_' + image_name)
20 |
--------------------------------------------------------------------------------
/tutorial/xlocal_gradient/test_guided_backprop.py:
--------------------------------------------------------------------------------
1 | import torchvision.models as models
2 |
3 | from xdeep.utils import *
4 | from xdeep.xlocal.gradient.backprop.guided_backprop import *
5 |
6 |
7 | def test_bp():
8 |
9 | image_path = 'images/ILSVRC2012_val_00000073.JPEG'
10 | image_name = image_path.split('/')[1]
11 | image = load_image(image_path)
12 | norm_image = apply_transforms(image)
13 |
14 | model = models.vgg16(pretrained=True)
15 | Guided_BP = Backprop(model, guided=True)
16 | output = Guided_BP.calculate_gradients(norm_image)
17 | visualize(norm_image, output, save_path='results/GuidedBP_' + image_name)
18 |
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/Eskimo dog, husky.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/Eskimo dog, husky.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/Great Pyrenees.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/Great Pyrenees.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/Ibizan hound.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/Ibizan hound.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/bee eater.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/bee eater.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/car mirror.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/car mirror.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/eggnog.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/eggnog.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/kitten.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/kitten.jpg
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/oystercatcher.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/oystercatcher.jpg
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata.README.1.0.txt:
--------------------------------------------------------------------------------
1 |
2 | =======
3 |
4 | Introduction
5 |
6 | This README v1.0 (June, 2005) for the v1.0 sentence polarity dataset comes
7 | from the URL
8 | http://www.cs.cornell.edu/people/pabo/movie-review-data .
9 |
10 | =======
11 |
12 | Citation Info
13 |
14 | This data was first used in Bo Pang and Lillian Lee,
15 | ``Seeing stars: Exploiting class relationships for sentiment categorization
16 | with respect to rating scales.'', Proceedings of the ACL, 2005.
17 |
18 | @InProceedings{Pang+Lee:05a,
19 | author = {Bo Pang and Lillian Lee},
20 | title = {Seeing stars: Exploiting class relationships for sentiment
21 | categorization with respect to rating scales},
22 | booktitle = {Proceedings of the ACL},
23 | year = 2005
24 | }
25 |
26 | =======
27 |
28 | Data Format Summary
29 |
30 | - rt-polaritydata.tar.gz: contains this readme and two data files that
31 | were used in the experiments described in Pang/Lee ACL 2005.
32 |
33 | Specifically:
34 | * rt-polarity.pos contains 5331 positive snippets
35 | * rt-polarity.neg contains 5331 negative snippets
36 |
37 | Each line in these two files corresponds to a single snippet (usually
38 | containing roughly one single sentence); all snippets are down-cased.
39 | The snippets were labeled automatically, as described below (see
40 | section "Label Decision").
41 |
42 | Note: The original source files from which the data in
43 | rt-polaritydata.tar.gz was derived can be found in the subjective
44 | part (Rotten Tomatoes pages) of subjectivity_html.tar.gz (released
45 | with subjectivity dataset v1.0).
46 |
47 |
48 | =======
49 |
50 | Label Decision
51 |
52 | We assumed snippets (from Rotten Tomatoes webpages) for reviews marked with
53 | ``fresh'' are positive, and those for reviews marked with ``rotten'' are
54 | negative.
55 |
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.neg
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/rt-polaritydata/rt-polaritydata/rt-polarity.pos
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/sea_snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/sea_snake.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/sky.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/sky.jpg
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/tusker.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/tusker.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/data/violin.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/data/violin.JPEG
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/image - tutorial.py:
--------------------------------------------------------------------------------
1 | # Please set the path to the 'slim' folder according to your own setup.
2 | # Download slim at https://github.com/tensorflow/models/tree/master/research/slim
3 |
4 | import os
5 | import torch
6 | import torchvision.models as models
7 | import tensorflow as tf
8 | import numpy as np
9 | from image_util import inception_preprocessing, inception_v3
10 | from skimage.segmentation import slic
11 | import xdeep.xlocal.perturbation.xdeep_image as xdeep_image
12 | from xdeep.utils import *
13 |
14 | def test_image_data_torch():
15 | with torch.no_grad():
16 | image = load_image('./xlocal_perturbation/data/violin.JPEG')
17 | image = np.asarray(image)
18 | model = models.vgg16(pretrained=True)
19 |
20 | names = create_readable_names_for_imagenet_labels()
21 |
22 | class_names = []
23 | for item in names:
24 | class_names.append(names[item])
25 |
26 | def predict_fn(x):
27 | return model(torch.from_numpy(x.reshape(-1, 3, 224, 224)).float()).numpy()
28 |
29 | explainer = xdeep_image.ImageExplainer(predict_fn, class_names)
30 |
31 | explainer.explain('lime', image, top_labels=1, num_samples=30)
32 | explainer.show_explanation('lime', positive_only=False)
33 |
34 | explainer.explain('cle', image, top_labels=1, num_samples=30)
35 | explainer.show_explanation('cle', positive_only=False)
36 |
37 | explainer.explain('anchor', image, threshold=0.7, coverage_samples=5000)
38 | explainer.show_explanation('anchor')
39 |
40 | segments_slic = slic(image, n_segments=30, compactness=30, sigma=3)
41 | explainer.initialize_shap(n_segment=30, segment=segments_slic)
42 | explainer.explain('shap',image,nsamples=30)
43 | explainer.show_explanation('shap')
44 |
45 | def test_image_data_tf():
46 | slim = tf.contrib.slim
47 | tf.reset_default_graph()
48 | session = tf.Session()
49 |
50 | names = create_readable_names_for_imagenet_labels()
51 |
52 | processed_images = tf.placeholder(tf.float32, shape=(None, 299, 299, 3))
53 |
54 | with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
55 | logits, _ = inception_v3.inception_v3(processed_images, num_classes=1001, is_training=False)
56 | probabilities = tf.nn.softmax(logits)
57 |
58 | # Please correctly set the model path.
59 | # Download the model at https://github.com/tensorflow/models/tree/master/research/slim
60 | checkpoints_dir = 'model'
61 | init_fn = slim.assign_from_checkpoint_fn(
62 | os.path.join(checkpoints_dir, 'inception_v3.ckpt'),
63 | slim.get_model_variables('InceptionV3'))
64 | init_fn(session)
65 |
66 | def predict_fn(images):
67 | return session.run(probabilities, feed_dict={processed_images: images})
68 |
69 | def f(x):
70 | return x / 2 + 0.5
71 |
72 | class_names = []
73 | for item in names:
74 | class_names.append(names[item])
75 |
76 | def transform_img_fn(path_list):
77 | out = []
78 | image_size = 299
79 | for f in path_list:
80 | with open(f,'rb') as img:
81 | image_raw = tf.image.decode_jpeg(img.read(), channels=3)
82 | image = inception_preprocessing.preprocess_image(image_raw, image_size, image_size, is_training=False)
83 | out.append(image)
84 | return session.run([out])[0]
85 |
86 | images = transform_img_fn(['./data/violin.JPEG'])
87 | image = images[0]
88 |
89 | explainer = xdeep_image.ImageExplainer(predict_fn, class_names)
90 |
91 | explainer.explain('lime', image, top_labels=3)
92 | explainer.show_explanation('lime', deprocess=f, positive_only=False)
93 |
94 | explainer.explain('cle', image, top_labels=3)
95 | explainer.show_explanation('cle', deprocess=f, positive_only=False)
96 |
97 | explainer.explain('anchor', image)
98 | explainer.show_explanation('anchor')
99 |
100 | segments_slic = slic(image, n_segments=50, compactness=30, sigma=3)
101 | explainer.initialize_shap(n_segment=50, segment=segments_slic)
102 | explainer.explain('shap',image,nsamples=400)
103 | explainer.show_explanation('shap',deprocess=f)
104 |
105 | def create_readable_names_for_imagenet_labels():
106 | """Create a dict mapping label id to human readable string.
107 |
108 | Returns:
109 | labels_to_names: dictionary where keys are integers from 0 to 1000
110 | and values are human-readable names.
111 |
112 | We retrieve a synset file, which contains a list of valid synset labels used
113 | by the ILSVRC competition, one synset per line, e.g.
114 | # n01440764
115 | # n01443537
116 | We also retrieve a synset_to_human_file, which contains a mapping from synsets
117 | to human-readable names for every synset in Imagenet. These are stored in a
118 | tsv format, as follows:
119 | # n02119247 black fox
120 | # n02119359 silver fox
121 | We assign each synset (in alphabetical order) an integer, starting from 1
122 | (since 0 is reserved for the background class).
123 |
124 | Code is based on
125 | https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py#L463
126 | """
127 |
128 | filename = "./image_util/imagenet_lsvrc_2015_synsets.txt"
129 | synset_list = [s.strip() for s in open(filename).readlines()]
130 | num_synsets_in_ilsvrc = len(synset_list)
131 | if not num_synsets_in_ilsvrc == 1000:
132 | raise AssertionError()
133 |
134 | filename = "./image_util/imagenet_metadata.txt"
135 | synset_to_human_list = open(filename).readlines()
136 | num_synsets_in_all_imagenet = len(synset_to_human_list)
137 | if not num_synsets_in_all_imagenet == 21842:
138 | raise AssertionError()
139 |
140 | synset_to_human = {}
141 | for s in synset_to_human_list:
142 | parts = s.strip().split('\t')
143 | if not len(parts) == 2:
144 | raise AssertionError()
145 | synset = parts[0]
146 | human = parts[1]
147 | synset_to_human[synset] = human
148 |
149 | label_index = 1
150 | labels_to_names = {0: 'background'}
151 | for synset in synset_list:
152 | name = synset_to_human[synset]
153 | labels_to_names[label_index] = name
154 | label_index += 1
155 |
156 | return labels_to_names
157 |
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/image_util/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/tutorial/xlocal_perturbation/image_util/.DS_Store
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/tabular - tutorial.py:
--------------------------------------------------------------------------------
1 | from sklearn.ensemble import RandomForestClassifier
2 | from anchor import utils
3 | import xdeep.xlocal.perturbation.xdeep_tabular as xdeep_tabular
4 |
5 | # Please download the dataset at
6 | # https://archive.ics.uci.edu/ml/datasets/adult
7 | # Then set the path accordingly, e.g. 'data/'.
8 |
9 | def test_tabular_data():
10 | dataset_folder = 'data/'
11 | dataset = utils.load_dataset('adult', balance=True, dataset_folder=dataset_folder)
12 |
13 | c = RandomForestClassifier(n_estimators=50, n_jobs=5)
14 | c.fit(dataset.train, dataset.labels_train)
15 |
16 | explainer = xdeep_tabular.TabularExplainer(c.predict_proba, ['<=50K', '>50K'], dataset.feature_names, dataset.train[0:50],
17 | categorical_features=dataset.categorical_features, categorical_names=dataset.categorical_names)
18 |
19 | explainer.set_parameters('lime', discretize_continuous=True)
20 | explainer.explain('lime', dataset.test[0])
21 | explainer.show_explanation('lime')
22 |
23 | explainer.set_parameters('cle', discretize_continuous=True)
24 | explainer.explain('cle', dataset.test[0])
25 | explainer.show_explanation('cle')
26 |
27 | explainer.explain('shap', dataset.test[0])
28 | explainer.show_explanation('shap')
29 |
30 | explainer.initialize_anchor(dataset.data)
31 | encoder = explainer.get_anchor_encoder(dataset.train[0:50], dataset.labels_train, dataset.validation, dataset.labels_validation)
32 | c_new = RandomForestClassifier(n_estimators=50, n_jobs=5)
33 | c_new.fit(encoder.transform(dataset.train), dataset.labels_train)
34 | explainer.set_anchor_predict_proba(c_new.predict_proba)
35 | explainer.explain('anchor', dataset.test[0])
36 | explainer.show_explanation('anchor')
37 |
--------------------------------------------------------------------------------
/tutorial/xlocal_perturbation/text - tutorial.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import spacy
4 | from sklearn import model_selection
5 | from sklearn.linear_model import LogisticRegression
6 | from sklearn.pipeline import make_pipeline
7 | from sklearn.feature_extraction.text import CountVectorizer
8 | import xdeep.xlocal.perturbation.xdeep_text as xdeep_text
9 |
10 | # Please download dataset at
11 | # http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz
12 | # Then set the path accordingly.
13 |
14 | def load_polarity(path='./data/rt-polaritydata/rt-polaritydata'):
15 | data = []
16 | labels = []
17 | f_names = ['rt-polarity.neg', 'rt-polarity.pos']
18 | for (l, f) in enumerate(f_names):
19 | for line in open(os.path.join(path, f), 'rb'):
20 | try:
21 | line.decode('utf8')
22 | except UnicodeDecodeError:
23 | continue
24 | data.append(line.strip())
25 | labels.append(l)
26 | return data, labels
27 |
28 | def test_text_data():
29 | data, labels = load_polarity()
30 | train, test, train_labels, test_labels = model_selection.train_test_split(data, labels, test_size=.2, random_state=42)
31 | train, val, train_labels, val_labels = model_selection.train_test_split(train, train_labels, test_size=.1, random_state=42)
32 | train_labels = np.array(train_labels)
33 | test_labels = np.array(test_labels)
34 | val_labels = np.array(val_labels)
35 |
36 | vectorizer = CountVectorizer(min_df=1)
37 | vectorizer.fit(train)
38 | train_vectors = vectorizer.transform(train)
39 | test_vectors = vectorizer.transform(test)
40 | val_vectors = vectorizer.transform(val)
41 |
42 | x = LogisticRegression()
43 | x.fit(train_vectors, train_labels)
44 | c = make_pipeline(vectorizer, x)
45 |
46 | explainer = xdeep_text.TextExplainer(c.predict_proba, ['negative', 'positive'])
47 |
48 | text = 'This is a good movie .'
49 |
50 | explainer.explain('lime', text)
51 | explainer.show_explanation('lime')
52 |
53 | explainer.explain('cle', text)
54 | explainer.show_explanation('cle')
55 |
56 | try:
57 | nlp = spacy.load('en_core_web_sm')
58 | explainer.explain('anchor', text)
59 | explainer.show_explanation('anchor')
60 | except OSError:
61 | pass
62 |
63 | explainer.initialize_shap(x.predict_proba, vectorizer, train[0:10])
64 | explainer.explain('shap', text)
65 | explainer.show_explanation('shap')
66 |
--------------------------------------------------------------------------------
/xdeep/__init__.py:
--------------------------------------------------------------------------------
1 | name = "x-deep"
--------------------------------------------------------------------------------
/xdeep/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Misa Ogura. (2019, September 9).
4 | MisaOgura/flashtorch: 0.1.0 (Version v0.1.0).
5 | Zenodo. http://doi.org/10.5281/zenodo.3403214
6 |
7 | This module provides utility functions for image handling and tensor
8 | transformation.
9 |
10 | """
11 | from PIL import Image
12 | import matplotlib.pyplot as plt
13 |
14 | import torchvision.transforms as transforms
15 | import torchvision.transforms.functional as F
16 |
17 | from .imagenet import *
18 |
19 |
20 | def load_image(image_path):
21 | """Loads image as a PIL RGB image.
22 |
23 | Args:
24 | image_path (str): A path to the image
25 |
26 | Returns:
27 | An instance of PIL.Image.Image in RGB
28 |
29 | """
30 |
31 | return Image.open(image_path).convert('RGB')
32 |
33 |
34 | def apply_transforms(image, size=224):
35 | """Transforms a PIL image to torch.Tensor.
36 |
37 | Applies a series of tranformations on PIL image including a conversion
38 | to a tensor. The returned tensor has a shape of :math:`(N, C, H, W)` and
39 | is ready to be used as an input to neural networks.
40 |
41 | First the image is resized and center-cropped to `size` (default 224). The
42 | `means` and `stds` for normalisation are the standard ImageNet statistics,
43 | as the package is currently developed for visualizing pre-trained models.
44 |
45 | The plan is to expand this to handle custom size/mean/std.
46 |
47 | Args:
48 | image (PIL.Image.Image or numpy array)
49 | size (int, optional, default=224): Desired size (width/height) of the
50 | output tensor
51 |
52 | Shape:
53 | Input: :math:`(C, H, W)` for numpy array
54 | Output: :math:`(N, C, H, W)`
55 |
56 | Returns:
57 | torch.Tensor (torch.float32): Transformed image tensor
58 |
59 | Note:
60 | Symbols used to describe dimensions:
61 | - N: number of images in a batch
62 | - C: number of channels
63 | - H: height of the image
64 | - W: width of the image
65 |
66 | """
67 |
68 | if not isinstance(image, Image.Image):
69 | image = F.to_pil_image(image)
70 |
71 | means = [0.485, 0.456, 0.406]
72 | stds = [0.229, 0.224, 0.225]
73 |
74 | transform = transforms.Compose([
75 | transforms.Resize(size),
76 | transforms.CenterCrop(size),
77 | transforms.ToTensor(),
78 | transforms.Normalize(means, stds)
79 | ])
80 |
81 | tensor = transform(image).unsqueeze(0)
82 |
83 | tensor.requires_grad = True
84 |
85 | return tensor
86 |
87 |
88 | def denormalize(tensor):
89 | """Reverses the normalisation on a tensor.
90 |
91 | Performs a reverse operation on a tensor, so the pixel value range is
92 | between 0 and 1. Useful for when plotting a tensor into an image.
93 |
94 | Normalisation: (image - mean) / std
95 | Denormalisation: image * std + mean
96 |
97 | Args:
98 | tensor (torch.Tensor, dtype=torch.float32): Normalized image tensor
99 |
100 | Shape:
101 | Input: :math:`(N, C, H, W)`
102 | Output: :math:`(N, C, H, W)` (same shape as input)
103 |
104 | Return:
105 | torch.Tensor (torch.float32): Denormalised image tensor with pixel
106 | values between [0, 1]
107 |
108 | Note:
109 | Symbols used to describe dimensions:
110 | - N: number of images in a batch
111 | - C: number of channels
112 | - H: height of the image
113 | - W: width of the image
114 |
115 | """
116 |
117 | means = [0.485, 0.456, 0.406]
118 | stds = [0.229, 0.224, 0.225]
119 |
120 | denormalized = tensor.clone()
121 |
122 | for channel, mean, std in zip(denormalized[0], means, stds):
123 | channel.mul_(std).add_(mean)
124 |
125 | return denormalized
126 |
127 |
128 | def standardize_and_clip(tensor, min_value=0.0, max_value=1.0,
129 | saturation=0.1, brightness=0.5):
130 |
131 | """Standardizes and clips input tensor.
132 | Standardizes the input tensor (mean = 0.0, std = 1.0). The color saturation
133 | and brightness are adjusted, before tensor values are clipped to min/max
134 | (default: 0.0/1.0).
135 | Args:
136 | tensor (torch.Tensor):
137 | min_value (float, optional, default=0.0)
138 | max_value (float, optional, default=1.0)
139 | saturation (float, optional, default=0.1)
140 | brightness (float, optional, default=0.5)
141 | Shape:
142 | Input: :math:`(C, H, W)`
143 | Output: Same as the input
144 | Return:
145 | torch.Tensor (torch.float32): Normalised tensor with values between
146 | [min_value, max_value]
147 | """
148 |
149 | tensor = tensor.detach().cpu()
150 |
151 | mean = tensor.mean()
152 | std = tensor.std()
153 |
154 | if std == 0:
155 | std += 1e-7
156 |
157 | standardized = tensor.sub(mean).div(std).mul(saturation)
158 | clipped = standardized.add(brightness).clamp(min_value, max_value)
159 |
160 | return clipped
161 |
162 |
163 | def format_for_plotting(tensor):
164 | """Formats the shape of tensor for plotting.
165 |
166 | Tensors typically have a shape of :math:`(N, C, H, W)` or :math:`(C, H, W)`
167 | which is not suitable for plotting as images. This function formats an
168 | input tensor into :math:`(H, W, C)` for RGB images and :math:`(H, W)` for
169 | mono-channel data.
170 |
171 | Args:
172 | tensor (torch.Tensor, torch.float32): Image tensor
173 |
174 | Shape:
175 | Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
176 | Output: :math:`(H, W, C)` or :math:`(H, W)`, respectively
177 |
178 | Return:
179 | torch.Tensor (torch.float32): Formatted image tensor (detached)
180 |
181 | Note:
182 | Symbols used to describe dimensions:
183 | - N: number of images in a batch
184 | - C: number of channels
185 | - H: height of the image
186 | - W: width of the image
187 |
188 | """
189 |
190 | has_batch_dimension = len(tensor.shape) == 4
191 | formatted = tensor.clone()
192 |
193 | if has_batch_dimension:
194 | formatted = tensor.squeeze(0)
195 |
196 | if formatted.shape[0] == 1:
197 | return formatted.squeeze(0).detach()
198 | else:
199 | return formatted.permute(1, 2, 0).detach()
200 |
201 |
202 | def visualize(input_, gradients, save_path=None, cmap='viridis', alpha=0.7):
203 |
204 | """ Method to plot the explanation.
205 |
206 | # Arguments
207 | input_: Tensor. Original image.
208 | gradients: Tensor. Saliency map result.
209 | save_path: String. Defaults to None.
210 | cmap: Colormap for the saliency overlay. Defaults to 'viridis'.
211 | alpha: Overlay transparency. Defaults to 0.7.
212 |
213 | """
214 | input_ = format_for_plotting(denormalize(input_))
215 | gradients = format_for_plotting(standardize_and_clip(gradients))
216 |
217 | subplots = [
218 | ('Input image', [(input_, None, None)]),
219 | ('Saliency map', [(gradients, None, None)]),
220 | ('Overlay', [(input_, None, None), (gradients, cmap, alpha)])
221 | ]
222 |
223 | num_subplots = len(subplots)
224 |
225 | fig = plt.figure(figsize=(16, 3))
226 |
227 | for i, (title, images) in enumerate(subplots):
228 | ax = fig.add_subplot(1, num_subplots, i + 1)
229 | ax.set_axis_off()
230 |
231 | for image, cmap, alpha in images:
232 | ax.imshow(image, cmap=cmap, alpha=alpha)
233 |
234 | ax.set_title(title)
235 | if save_path is not None:
236 | plt.savefig(save_path)
237 |
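A minimal usage sketch of the utilities above, assuming illustrative input and output paths ('images/jay.jpg', 'results/overlay.jpg'); it chains load_image, apply_transforms, and visualize with the guided-backprop explainer used in tests/xlocal_gradient.

# Sketch (assumed paths): load an image, normalize it, compute guided-backprop
# gradients with a pretrained VGG16, and plot the saliency overlay.
import torchvision.models as models
from xdeep.utils import load_image, apply_transforms, visualize
from xdeep.xlocal.gradient.backprop.guided_backprop import Backprop

image = load_image('images/jay.jpg')       # PIL RGB image (illustrative path)
norm_image = apply_transforms(image)       # (1, 3, 224, 224) tensor, requires_grad=True
model = models.vgg16(pretrained=True)
explainer = Backprop(model, guided=True)
gradients = explainer.calculate_gradients(norm_image)
visualize(norm_image, gradients, save_path='results/overlay.jpg')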
--------------------------------------------------------------------------------
/xdeep/utils/imagenet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Misa Ogura. (2019, September 9).
4 | MisaOgura/flashtorch: 0.1.0 (Version v0.1.0).
5 | Zenodo. http://doi.org/10.5281/zenodo.3403214
6 | '''
7 |
8 | import json
9 |
10 | from collections.abc import Mapping
11 | from importlib_resources import path
12 |
13 | from . import resources
14 |
15 |
16 | class ImageNetIndex(Mapping):
17 | """Interface to retrieve ImageNet class indeces from class names.
18 |
19 | This class implements a dictionary-like object, aiming to provide an
20 | easy-to-use look-up table for finding a target class index from an ImageNet
21 | class name.
22 |
23 | Reference:
24 | - ImageNet class index: https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
25 | - Synsets: http://image-net.org/challenges/LSVRC/2015/browse-synsets
26 |
27 | Note:
28 | Class names in `imagenet_class_index.json` have been slightly modified
29 | from the source due to duplicated class names (e.g. crane). This helps
30 | make the use of this tool simpler.
31 | """
32 |
33 | def __init__(self):
34 | self._index = {}
35 |
36 | with path(resources, 'imagenet_class_index.json') as source_path:
37 | with open(str(source_path), 'r') as source:
38 | data = json.load(source)
39 |
40 | for index, (_, class_name) in data.items():
41 | class_name = class_name.lower().replace('_', ' ')
42 | self._index[class_name] = int(index)
43 |
44 | def __len__(self):
45 | return len(self._index)
46 |
47 | def __iter__(self):
48 | return iter(self._index)
49 |
50 | def __getitem__(self, phrase):
51 | if not isinstance(phrase, str):
52 | raise TypeError('Target class needs to be a string.')
53 |
54 | if phrase in self._index:
55 | return self._index[phrase]
56 |
57 | partial_matches = self._find_partial_matches(phrase)
58 |
59 | if not any(partial_matches):
60 | return None
61 | elif len(partial_matches) > 1:
62 | raise ValueError('Multiple potential matches found: {}' \
63 | .format(', '.join(map(str, partial_matches))))
64 |
65 | target_class = partial_matches.pop()
66 |
67 | return self._index[target_class]
68 |
69 | def __contains__(self, key):
70 | return any(key in name for name in self._index)
71 |
72 | def keys(self):
73 | return self._index.keys()
74 |
75 | def items(self):
76 | return self._index.items()
77 |
78 | def _find_partial_matches(self, phrase):
79 | words = phrase.lower().split(' ')
80 |
81 | # Find the intersection between search words and class names to
82 | # prioritise whole word matches
83 | # e.g. If words = {'dalmatian', 'dog'} then matches 'dalmatian'
84 |
85 | matches = set(words).intersection(set(self.keys()))
86 |
87 | if not any(matches):
88 | # Find substring matches between search words and class names to
89 | # accommodate fuzzy matches to some extent
90 | # e.g. If words = {'foxhound'} then matches 'english foxhound'
91 |
92 | matches = [key for word in words for key in self.keys() \
93 | if word in key]
94 |
95 | return matches
96 |
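A brief, hedged sketch of how the ImageNetIndex look-up above might be queried; it assumes the bundled imagenet_class_index.json resource is present and that 'jay' is a valid class name (the jay.jpg sample and logit 17 in tutorial/xglobal/test_global.py suggest it is).

# Sketch only: look up an ImageNet class index by (partial) class name.
from xdeep.utils.imagenet import ImageNetIndex

imagenet_index = ImageNetIndex()
target_class = imagenet_index['jay']   # whole-word match first, then fuzzy matching
print(target_class)                    # integer index, or None if nothing matches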
--------------------------------------------------------------------------------
/xdeep/utils/resources/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/xdeep/xglobal/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xglobal/__init__.py
--------------------------------------------------------------------------------
/xdeep/xglobal/global_explainers.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Thu Sep 19 18:37:22 2019
3 | @author: Haofan Wang - github.com/haofanwang
4 | """
5 |
6 | from xdeep.xglobal.methods import *
7 | from xdeep.xglobal.methods.activation import *
8 | from xdeep.xglobal.methods.inverted_representation import *
9 |
10 |
11 | class GlobalImageInterpreter(object):
12 | """ Class for global image explanations.
13 |
14 | GlobalImageInterpreter provides a unified interface for calling different visualization methods.
15 | It can be used in a few lines and exposes '__init__()' and 'explain()'. Users
16 | can specify the name of the visualization method and the target layer. If params are
17 | not specified, default params are used.
18 | """
19 |
20 | def __init__(self, model):
21 | """Init
22 |
23 | # Arguments
24 | model: torchvision.models. A pretrained model.
25 | result: torch.Tensor. Attribute that stores the generated saliency map.
26 |
27 | """
28 | self.model = model
29 | self.result = None
30 |
31 | def maximize_activation(self, model):
32 | """
33 | Return a GradientAscent instance for activation maximization.
34 |
35 | """
36 |
37 | return GradientAscent(model)
38 |
39 | def inverted_feature(self, model_dict):
40 | """
41 |         Return an inverted representation explainer (InvertedRepresentation).
42 | """
43 | return InvertedRepresentation(model_dict)
44 |
45 | def explain(self, method_name=None, target_layer=None, target_filter=None, input_=None, num_iter=10, save_path=None):
46 |
47 | """Function to call different visualization methods.
48 |
49 | # Arguments
50 |             method_name: str. The name of the interpretation method. Supported global explanation methods are
51 |                 'filter', 'layer', 'logit', 'deepdream', 'inverted'.
52 |             target_layer: str, torch.nn.Linear or torch.nn.Conv2d. The target layer, given either as a module or as a layer-name string.
53 |             target_filter: int or list. Index of the filter or filters to visualize.
54 |             input_: str or Tensor. Path of the input image or a normalized tensor. Defaults to None.
55 |             num_iter: int. Number of iterations. Defaults to 10.
56 |             save_path: str. Path to save the generated explanations. Defaults to None.
57 | """
58 |
59 | method_switch = {"filter": self.maximize_activation,
60 | "layer": self.maximize_activation,
61 | "logit": self.maximize_activation,
62 | "deepdream": self.maximize_activation,
63 | "inverted": self.inverted_feature}
64 |
65 | if method_name is None:
66 | raise TypeError("method_name has to be specified!")
67 |
68 | if method_name == "inverted":
69 | model_dict = dict(arch=self.model, layer_name=target_layer, input_size=(224, 224))
70 | explainer = method_switch[method_name](model_dict)
71 |
72 | self.result = explainer.visualize(input_=input_, img_size=(224, 224), alpha_reg_alpha=6,
73 | alpha_reg_lambda=1e-7, tv_reg_beta=2, tv_reg_lambda=1e-8,
74 | n=num_iter, save_path=save_path)
75 | else:
76 | if target_layer is None:
77 | raise TypeError("target_layer cannot be None!")
78 | elif isinstance(target_layer, str):
79 | target_layer = find_layer(self.model, target_layer)
80 |
81 | explainer = method_switch[method_name](self.model)
82 |
83 | self.result = explainer.visualize(layer=target_layer, filter_idxs=target_filter,
84 | input_=input_, lr=1, num_iter=num_iter, num_subplots=4,
85 | figsize=(4, 4), title='Visualization Result', return_output=True,
86 | save_path=save_path)
87 |
88 |
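# A minimal usage sketch of the interface above, assuming a pretrained torchvision
# VGG16; the layer name 'features_24', the filter indices and the save path are
# placeholder choices, not values mandated by the library.
import torchvision.models as models
from xdeep.xglobal.global_explainers import GlobalImageInterpreter

model = models.vgg16(pretrained=True)
interpreter = GlobalImageInterpreter(model)
# Visualize what a few filters of an intermediate conv layer respond to.
interpreter.explain(method_name='filter', target_layer='features_24',
                    target_filter=[0, 1, 2, 3], num_iter=20,
                    save_path='results/filter.jpg')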
--------------------------------------------------------------------------------
/xdeep/xglobal/methods/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | def find_layer(arch, target_layer_name):
3 | if target_layer_name is None:
4 | if 'resnet' in str(type(arch)):
5 | target_layer_name = 'layer4'
6 | elif 'alexnet' in str(type(arch)) or 'vgg' in str(type(arch)) or 'squeezenet' in str(
7 | type(arch)) or 'densenet' in str(type(arch)):
8 | target_layer_name = 'features'
9 | else:
10 | raise Exception('Invalid layer name! Please specify layer name.', target_layer_name)
11 |
12 | hierarchy = target_layer_name.split('_')
13 |
14 | if hierarchy[0] not in arch._modules.keys():
15 | raise Exception('Invalid layer name!', target_layer_name)
16 |
17 | target_layer = arch._modules[hierarchy[0]]
18 |
19 | if len(hierarchy) >= 2:
20 | if hierarchy[1] not in target_layer._modules.keys():
21 | raise Exception('Invalid layer name!', target_layer_name)
22 | target_layer = target_layer._modules[hierarchy[1]]
23 |
24 | if len(hierarchy) >= 3:
25 | if hierarchy[2] not in target_layer._modules.keys():
26 | raise Exception('Invalid layer name!', target_layer_name)
27 | target_layer = target_layer._modules[hierarchy[2]]
28 |
29 | if len(hierarchy) >= 4:
30 | if hierarchy[3] not in target_layer._modules.keys():
31 | raise Exception('Invalid layer name!', target_layer_name)
32 | target_layer = target_layer._modules[hierarchy[3]]
33 |
34 | return target_layer
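# A minimal standalone sketch of how find_layer resolves an underscore-separated
# name, assuming a torchvision VGG16 (where 'features_29' names the last ReLU of
# the convolutional trunk); the model choice is only an example.
import torchvision.models as models
from xdeep.xglobal.methods import find_layer

layer = find_layer(models.vgg16(pretrained=True), 'features_29')
print(layer)  # the nested module model.features[29]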
--------------------------------------------------------------------------------
/xdeep/xglobal/methods/inverted_representation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.optim as optim
5 | from torchvision.utils import save_image
6 |
7 | from xdeep.utils import (load_image, apply_transforms, denormalize)
8 |
9 |
10 | class InvertedRepresentation(object):
11 | def __init__(self, model_dict):
12 |
13 | """Init
14 |
15 | # Arguments:
16 |             model_dict: dict. A dict with keys 'arch' (a pretrained torchvision model),
17 |                 'layer_name' (str, name of the target layer) and 'input_size'
18 |                 (tuple, size of the input image, e.g. (224, 224)).
19 | """
20 |
21 | layer_name = model_dict['layer_name']
22 | self.model = model_dict['arch']
23 | self.model.eval()
24 |
25 | self.img_size = model_dict['input_size']
26 |
27 | self.activations = dict()
28 | self.gradients = dict()
29 |
30 | def forward_hook(module, input, output):
31 | self.activations['value'] = output
32 | return None
33 |
34 | def backward_hook(module, grad_input, grad_output):
35 | self.gradients['value'] = grad_output[0]
36 | return None
37 |
38 | target_layer = self.find_layer(self.model, layer_name)
39 | target_layer.register_forward_hook(forward_hook)
40 | target_layer.register_backward_hook(backward_hook)
41 |
42 | def find_layer(self, arch, target_layer_name):
43 | """
44 | Return target layer
45 | """
46 |
47 | if target_layer_name is None:
48 | if 'resnet' in str(type(arch)):
49 | target_layer_name = 'layer4'
50 | elif 'alexnet' in str(type(arch)) or 'vgg' in str(type(arch)) or 'squeezenet' in str(type(arch)) or 'densenet' in str(type(arch)):
51 | target_layer_name = 'features'
52 | else:
53 | raise Exception('Invalid layer name! Please specify layer name.', target_layer_name)
54 |
55 | hierarchy = target_layer_name.split('_')
56 |
57 | if hierarchy[0] not in arch._modules.keys():
58 | raise Exception('Invalid layer name!', target_layer_name)
59 |
60 | target_layer = arch._modules[hierarchy[0]]
61 |
62 | if len(hierarchy) >= 2:
63 | if hierarchy[1] not in target_layer._modules.keys():
64 | raise Exception('Invalid layer name!', target_layer_name)
65 | target_layer = target_layer._modules[hierarchy[1]]
66 |
67 | if len(hierarchy) >= 3:
68 | if hierarchy[2] not in target_layer._modules.keys():
69 | raise Exception('Invalid layer name!', target_layer_name)
70 | target_layer = target_layer._modules[hierarchy[2]]
71 |
72 | if len(hierarchy) >= 4:
73 | if hierarchy[3] not in target_layer._modules.keys():
74 | raise Exception('Invalid layer name!', target_layer_name)
75 | target_layer = target_layer._modules[hierarchy[3]]
76 |
77 | return target_layer
78 |
79 | def alpha_norm(self, input_, alpha):
80 | """
81 | Converts matrix to vector then calculates the alpha norm
82 | """
83 | return ((input_.view(-1))**alpha).sum()
84 |
85 | def total_variation_norm(self, input_, beta):
86 | """
87 | Total variation norm is the second norm in the paper
88 | represented as R_V(x)
89 | """
90 | to_check = input_[:, :-1, :-1]
91 | one_bottom = input_[:, 1:, :-1]
92 | one_right = input_[:, :-1, 1:]
93 | total_variation = (((to_check - one_bottom)**2 + (to_check - one_right)**2)**(beta/2)).sum()
94 | return total_variation
95 |
96 | def euclidian_loss(self, org_matrix, target_matrix):
97 | """
98 |         Euclidean loss is the main (data) loss term.
99 | """
100 | distance_matrix = target_matrix - org_matrix
101 | euclidian_distance = self.alpha_norm(distance_matrix, 2)
102 | normalized_euclidian_distance = euclidian_distance / self.alpha_norm(org_matrix, 2)
103 | return normalized_euclidian_distance
104 |
105 | def visualize(self, input_, img_size=(224, 224), alpha_reg_alpha=6, alpha_reg_lambda=1e-7,
106 | tv_reg_beta=2, tv_reg_lambda=1e-8, n=20, save_path=None):
107 |
108 | if isinstance(input_, str):
109 | input_ = load_image(input_)
110 | input_ = apply_transforms(input_)
111 |
112 | if save_path is None:
113 | if not os.path.exists('results/inverted'):
114 | os.makedirs('results/inverted')
115 | save_path = 'results/inverted'
116 |
117 | # generate random image to be optimized
118 | opt_img = 1e-1 * torch.randn(1, 3, img_size[0], img_size[1])
119 | opt_img.requires_grad = True
120 |
121 | # define optimizer
122 | optimizer = optim.SGD([opt_img], lr=1e4, momentum=0.9)
123 |
124 | # forward propagation to activate hook function
125 | image_output = self.model(input_)
126 | input_image_layer_output = self.activations['value']
127 |
128 | for i in range(n):
129 | optimizer.zero_grad()
130 | opt_img_output = self.model(opt_img)
131 | output = self.activations['value']
132 |
133 | euc_loss = 1e-1 * self.euclidian_loss(input_image_layer_output.detach(), output)
134 |
135 | reg_alpha = alpha_reg_lambda * self.alpha_norm(opt_img, alpha_reg_alpha)
136 | reg_total_variation = tv_reg_lambda * self.total_variation_norm(opt_img, tv_reg_beta)
137 |
138 | loss = euc_loss + reg_alpha + reg_total_variation
139 |
140 | loss.backward()
141 | optimizer.step()
142 |
143 | # save image
144 | if i % 1 == 0:
145 | print('Iteration:', str(i), 'Loss:', loss.data.numpy())
146 | recreated_img = denormalize(opt_img)
147 | img_path = os.path.join(save_path, 'Inverted_Iteration_' + str(i) + '.jpg')
148 | save_image(recreated_img, img_path)
149 |
150 | if i % 40 == 0:
151 | for param_group in optimizer.param_groups:
152 | param_group['lr'] *= 0.1
153 |
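# Note on the objective optimised above (a sketch of what visualize() computes):
# starting from random noise x, it minimises
#     loss = 0.1 * ||phi(x) - phi(x0)||^2 / ||phi(x0)||^2      (euclidian_loss)
#            + alpha_reg_lambda * sum(x ** alpha_reg_alpha)    (alpha_norm)
#            + tv_reg_lambda * TV_beta(x)                      (total_variation_norm)
# where phi(.) is the activation of the hooked layer and x0 is the input image.
# It is normally invoked through GlobalImageInterpreter.explain(method_name='inverted', ...).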
--------------------------------------------------------------------------------
/xdeep/xlocal/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/gradient/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/backprop/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/gradient/backprop/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/backprop/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BaseProp(object):
6 | """
7 | Base class for backpropagation.
8 | """
9 |
10 | def __init__(self, model):
11 | """Init
12 |
13 | # Arguments:
14 | model: torchvision.models. A pretrained model.
15 | handle: list. Handle list that register a hook function.
16 | relu_outputs: list. Forward output after relu.
17 |
18 | """
19 | self.model = model
20 | self.handle = []
21 | self.relu_outputs = []
22 |
23 | def _register_conv_hook(self):
24 |
25 | """
26 | Register hook function to save gradient w.r.t input image.
27 | """
28 |
29 | def _record_gradients(module, grad_in, grad_out):
30 | self.gradients = grad_in[0]
31 |
32 | for _, module in self.model.named_modules():
33 | if isinstance(module, nn.modules.conv.Conv2d) and module.in_channels == 3:
34 | backward_handle = module.register_backward_hook(_record_gradients)
35 | self.handle.append(backward_handle)
36 |
37 | def _register_relu_hooks(self):
38 |
39 | """
40 | Register hook function to save forward and backward relu result.
41 | """
42 |
43 | # Save forward propagation output of the ReLU layer
44 | def _record_output(module, input_, output):
45 | self.relu_outputs.append(output)
46 |
47 | def _clip_gradients(module, grad_in, grad_out):
48 | # keep positive forward propagation output
49 | relu_output = self.relu_outputs.pop()
50 | relu_output[relu_output > 0] = 1
51 |
52 | # keep positive backward propagation gradient
53 | positive_grad_out = torch.clamp(grad_out[0], min=0.0)
54 |
55 | # generate modified guided gradient
56 | modified_grad_out = positive_grad_out * relu_output
57 |
58 | return (modified_grad_out, )
59 |
60 | for _, module in self.model.named_modules():
61 | if isinstance(module, nn.ReLU):
62 | forward_handle = module.register_forward_hook(_record_output)
63 | backward_handle = module.register_backward_hook(_clip_gradients)
64 | self.handle.append(forward_handle)
65 | self.handle.append(backward_handle)
66 |
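# Note: the two ReLU hooks above implement the guided-backprop rule. A gradient is
# propagated back only where the forward ReLU output was positive AND the incoming
# gradient is positive, i.e. modified_grad = relu(grad_out) * (forward_output > 0).
# Worked example: forward_output = [0.0, 2.0, 3.0], grad_out = [0.5, -1.0, 0.7]
#                 -> modified_grad  = [0.0, 0.0, 0.7]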
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/backprop/guided_backprop.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 |
3 |
4 | class Backprop(BaseProp):
5 |
6 | """ Generates vanilla or guided backprop gradients of a target class output w.r.t. an input image.
7 |
8 | # Arguments:
9 | model: torchvision.models. A pretrained model.
10 | guided: bool. If True, perform guided backpropagation. Defaults to False.
11 |
12 | # Return:
13 | Backprop Class.
14 | """
15 |
16 | def __init__(self, model, guided=False):
17 | super().__init__(model)
18 | self.model.eval()
19 | self.guided = guided
20 | self.gradients = None
21 | self._register_conv_hook()
22 |
23 | def calculate_gradients(self,
24 | input_,
25 | target_class=None,
26 | take_max=False,
27 | use_gpu=False):
28 |
29 | """ Calculate gradient.
30 |
31 | # Arguments
32 | input_: torch.Tensor. Preprocessed image with shape (1, C, H, W).
33 |             target_class: int. Index of the target class. Defaults to None, in which case the predicted class is used.
34 |             take_max: bool. Take the maximum across colour channels. Defaults to False.
35 |             use_gpu: bool. Whether to use the GPU. Defaults to False.
36 |
37 | # Return:
38 |             Gradient (torch.Tensor) with shape (C, H, W); shape (1, H, W) if take_max is True.
39 | """
40 |
41 | if self.guided:
42 | self._register_relu_hooks()
43 |
44 | if torch.cuda.is_available() and use_gpu:
45 | self.model = self.model.to('cuda')
46 | input_ = input_.to('cuda')
47 |
48 |         # Create an empty tensor to save gradients
49 | self.gradients = torch.zeros(input_.shape)
50 |
51 | output = self.model(input_)
52 |
53 | self.model.zero_grad()
54 |
55 | if output.shape == torch.Size([1]):
56 | target = None
57 | else:
58 | pred_class = output.argmax().item()
59 |
60 | # Create a Tensor with zero elements, set the element at pred class index to be 1
61 | target = torch.zeros(output.shape)
62 |
63 | # If target class is None, calculate gradient of predicted class.
64 | if target_class is None:
65 | target[0][pred_class] = 1
66 | else:
67 | target[0][target_class] = 1
68 |
69 | if torch.cuda.is_available() and use_gpu:
70 | target = target.to('cuda')
71 |
72 | # Calculate gradients w.r.t. input image
73 | output.backward(gradient=target)
74 |
75 | gradients = self.gradients.detach().cpu()[0]
76 |
77 | if take_max:
78 | gradients = gradients.max(dim=0, keepdim=True)[0]
79 |
80 | for module in self.handle:
81 | module.remove()
82 |
83 | return gradients
84 |
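# A minimal standalone usage sketch, assuming a pretrained torchvision AlexNet;
# the random tensor stands in for a preprocessed image and 207 is an arbitrary
# ImageNet class index.
import torch
import torchvision.models as models
from xdeep.xlocal.gradient.backprop.guided_backprop import Backprop

explainer = Backprop(models.alexnet(pretrained=True), guided=True)
dummy_input = torch.randn(1, 3, 224, 224)  # placeholder for a normalised image
grads = explainer.calculate_gradients(dummy_input, target_class=207)
print(grads.shape)  # gradient w.r.t. the input image, shape (C, H, W)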
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/backprop/integrate_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def generate_integrated_grad(explainer, input_, target_class=None, n=25):
6 |
7 |     """ Generates integrated gradients for the given explainer. It can be used with both vanilla
8 |     and guided backprop.
9 |
10 | # Arguments:
11 | explainer: class. Backprop method.
12 | input_: torch.Tensor. Preprocessed image with shape (N, C, H, W).
13 | target_class: int. Index of target class. Default to None.
14 |         n: int. Number of integration steps. Defaults to 25.
15 |
16 | # Return:
17 | Integrated gradient (torch.Tensor) with shape (C, H, W).
18 | """
19 |
20 | # Generate xbar images
21 | step_list = np.arange(n + 1) / n
22 |
23 | # Remove zero
24 | step_list = step_list[1:]
25 |
26 |     # Create an empty tensor to accumulate the integrated gradients
27 | integrated_grads = torch.zeros(input_.size()).squeeze()
28 |
29 | for step in step_list:
30 | single_integrated_grad = explainer.calculate_gradients(input_*step, target_class)
31 | integrated_grads = integrated_grads + single_integrated_grad / n
32 |
33 | return integrated_grads
34 |
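# Note: the loop above approximates the average gradient along the straight path
# from a zero baseline to the input, i.e. (1/n) * sum_{k=1..n} dF(k/n * x)/dx.
# Usage is analogous to the other backprop helpers: pass a Backprop explainer and a
# preprocessed input tensor, e.g. generate_integrated_grad(Backprop(model), input_tensor, n=25).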
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/backprop/smooth_grad.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def generate_smooth_grad(explainer, input_, target_class=None, n=50, mean=0, sigma_multiplier=4):
5 |
6 | """ Generates smooth gradients of given explainer.
7 |
8 | # Arguments
9 | explainer: class. Backprop method.
10 | input_: torch.Tensor. Preprocessed image with shape (1, C, H, W).
11 | target_class: int. Index of target class. Defaults to None.
12 |         n: int. Number of noisy samples used to smooth the gradient. Defaults to 50.
13 |         mean: int. Mean of the normal distribution used to generate noise. Defaults to 0.
14 |         sigma_multiplier: int. Multiplier used to compute the std of the noise. Defaults to 4.
15 |
16 | # Return:
17 | Smooth gradient (torch.Tensor) with shape (C, H, W).
18 | """
19 |
20 | # Generate an empty image with shape (C, H, W) to save smooth gradient
21 | smooth_grad = torch.zeros(input_.size()).squeeze()
22 |
23 | sigma = sigma_multiplier / (torch.max(input_)-torch.min(input_)).item()
24 |
25 | x = 0
26 | while x < n:
27 | # Generate noise
28 | noise = input_.new(input_.size()).normal_(mean, sigma**2)
29 |
30 | # Add noise to the image
31 | noisy_img = input_ + noise
32 |
33 | # calculate gradient for noisy image
34 | grads = explainer.calculate_gradients(noisy_img, target_class)
35 |
36 | # accumulate gradient
37 | smooth_grad = smooth_grad + grads
38 |
39 | x += 1
40 |
41 | # average
42 | smooth_grad = smooth_grad / n
43 |
44 | return smooth_grad
45 |
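# Note: this is the SmoothGrad estimator: the gradient is averaged over n noisy
# copies of the input, smooth_grad = (1/n) * sum_i grad F(x + eps_i), where the
# noise scale is derived from sigma_multiplier / (max(x) - min(x)), computed with
# whichever Backprop explainer (vanilla or guided) is passed in.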
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/cam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/gradient/cam/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/cam/basecam.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class BaseCAM(object):
4 |
5 | """
6 | Base class for Class Activation Mapping.
7 | """
8 |
9 | def __init__(self, model_dict):
10 | """Init
11 |
12 | # Arguments
13 | model_dict: dict. A dict with format model_dict = dict(arch=self.model, layer_name=target_layer_name).
14 |
15 | """
16 |
17 | layer_name = model_dict['layer_name']
18 |
19 | self.model_arch = model_dict['arch']
20 | self.model_arch.eval()
21 | self.gradients = dict()
22 | self.activations = dict()
23 |
24 | # save gradient
25 | def backward_hook(module, grad_input, grad_output):
26 | self.gradients['value'] = grad_output[0]
27 | return None
28 |
29 | # save activation map
30 | def forward_hook(module, input, output):
31 | self.activations['value'] = output
32 | return None
33 |
34 | target_layer = self.find_layer(self.model_arch, layer_name)
35 | target_layer.register_forward_hook(forward_hook)
36 | target_layer.register_backward_hook(backward_hook)
37 |
38 | def find_layer(self, arch, target_layer_name):
39 |
40 | if target_layer_name is None:
41 | if 'resnet' in str(type(arch)):
42 | target_layer_name = 'layer4'
43 | elif 'alexnet' in str(type(arch)) or 'vgg' in str(type(arch)) or 'squeezenet' in str(type(arch)) or 'densenet' in str(type(arch)):
44 | target_layer_name = 'features'
45 | else:
46 | raise Exception('Invalid layer name! Please specify layer name.', target_layer_name)
47 |
48 | hierarchy = target_layer_name.split('_')
49 |
50 | if hierarchy[0] not in arch._modules.keys():
51 | raise Exception('Invalid layer name!', target_layer_name)
52 |
53 | target_layer = arch._modules[hierarchy[0]]
54 |
55 | if len(hierarchy) >= 2:
56 | if hierarchy[1] not in target_layer._modules.keys():
57 | raise Exception('Invalid layer name!', target_layer_name)
58 | target_layer = target_layer._modules[hierarchy[1]]
59 |
60 | if len(hierarchy) >= 3:
61 | if hierarchy[2] not in target_layer._modules.keys():
62 | raise Exception('Invalid layer name!', target_layer_name)
63 | target_layer = target_layer._modules[hierarchy[2]]
64 |
65 | if len(hierarchy) >= 4:
66 | if hierarchy[3] not in target_layer._modules.keys():
67 | raise Exception('Invalid layer name!', target_layer_name)
68 | target_layer = target_layer._modules[hierarchy[3]]
69 |
70 | return target_layer
71 |
72 | def forward(self, input_, class_idx=None, retain_graph=False):
73 | return None
74 |
75 | def __call__(self, input_, class_idx=None, retain_graph=False):
76 | return self.forward(input_, class_idx, retain_graph)
77 |
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/cam/gradcam.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from .basecam import BaseCAM
3 |
4 |
5 | class GradCAM(BaseCAM):
6 | """
7 | GradCAM, inherit from BaseCAM
8 | """
9 |
10 | def __init__(self, model_dict):
11 | super().__init__(model_dict)
12 |
13 | def forward(self, input_, class_idx=None, retain_graph=False):
14 | """Generates GradCAM result.
15 |
16 | # Arguments
17 | input_: torch.Tensor. Preprocessed image with shape (1, C, H, W).
18 |             class_idx: int. Index of the target class. Defaults to the index of the predicted class.
19 |
20 | # Return
21 | Result of GradCAM (torch.Tensor) with shape (1, H, W).
22 | """
23 |
24 | b, c, h, w = input_.size()
25 | logit = self.model_arch(input_)
26 |
27 | if class_idx is None:
28 | score = logit[:, logit.max(1)[-1]].squeeze()
29 | else:
30 | score = logit[:, class_idx].squeeze()
31 |
32 | self.model_arch.zero_grad()
33 | score.backward(retain_graph=retain_graph)
34 |
35 | gradients = self.gradients['value']
36 | activations = self.activations['value']
37 | b, k, u, v = gradients.size()
38 |
39 | alpha = gradients.view(b, k, -1).mean(2)
40 | weights = alpha.view(b, k, 1, 1)
41 |
42 | saliency_map = (weights * activations).sum(1, keepdim=True)
43 | saliency_map = F.relu(saliency_map)
44 | saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
45 | saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
46 | saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data
47 |
48 | return saliency_map
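# Note on what forward() computes (Grad-CAM): for the target-class score Y^c and
# the activation maps A^k of the hooked layer,
#     alpha_k = (1/Z) * sum_ij dY^c / dA^k_ij      (gradients global-average-pooled)
#     L_GradCAM = ReLU( sum_k alpha_k * A^k )
# which is then upsampled to the input size and min-max normalised.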
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/cam/gradcampp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .gradcam import GradCAM
4 |
5 |
6 | class GradCAMpp(GradCAM):
7 | """
8 | GradCAM++, inherit from BaseCAM
9 | """
10 |
11 | def __init__(self, model_dict):
12 | super(GradCAMpp, self).__init__(model_dict)
13 |
14 | def forward(self, input_image, class_idx=None, retain_graph=False):
15 |
16 | """Generates GradCAM++ result.
17 |
18 | # Arguments
19 | input_image: torch.Tensor. Preprocessed image with shape (1, C, H, W).
20 |             class_idx: int. Index of the target class. Defaults to the index of the predicted class.
21 |
22 | # Return
23 | Result of GradCAM++ (torch.Tensor) with shape (1, H, W).
24 | """
25 |
26 | b, c, h, w = input_image.size()
27 |
28 | logit = self.model_arch(input_image)
29 | if class_idx is None:
30 | score = logit[:, logit.max(1)[-1]].squeeze()
31 | else:
32 | score = logit[:, class_idx].squeeze()
33 |
34 | self.model_arch.zero_grad()
35 | score.backward(retain_graph=retain_graph)
36 | gradients = self.gradients['value']
37 | activations = self.activations['value']
38 | b, k, u, v = gradients.size()
39 |
40 | alpha_num = gradients.pow(2)
41 |
42 | global_sum = activations.view(b, k, u * v).sum(-1, keepdim=True).view(b, k, 1, 1)
43 | alpha_denom = gradients.pow(2).mul(2) + global_sum.mul(gradients.pow(3))
44 |
45 | alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom))
46 |
47 | alpha = alpha_num.div(alpha_denom + 1e-7)
48 | positive_gradients = F.relu(score.exp() * gradients)
49 | weights = (alpha * positive_gradients).view(b, k, u * v).sum(-1).view(b, k, 1, 1)
50 |
51 | saliency_map = (weights * activations).sum(1, keepdim=True)
52 | saliency_map = F.relu(saliency_map)
53 | saliency_map = F.interpolate(saliency_map, size=(224, 224), mode='bilinear', align_corners=False)
54 | saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
55 | saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data
56 |
57 | return saliency_map
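# Note on what forward() computes (Grad-CAM++, using the exp-score shortcut).
# Writing g = dY^c / dA^k_ij:
#     alpha_ij = g^2 / (2 * g^2 + (sum_ab A^k_ab) * g^3)
#     w_k      = sum_ij alpha_ij * relu(exp(Y^c) * g)
# and the map is ReLU(sum_k w_k * A^k), upsampled to 224x224 and min-max normalised.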
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/cam/scorecam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .basecam import *
4 |
5 |
6 | class ScoreCAM(BaseCAM):
7 | """
8 | ScoreCAM, inherit from BaseCAM
9 | """
10 |
11 | def __init__(self, model_dict):
12 | super().__init__(model_dict)
13 |
14 | def forward(self, input, class_idx=None, retain_graph=False):
15 |
16 | """Generates ScoreCAM result.
17 |
18 | # Arguments
19 |             input: torch.Tensor. Preprocessed image with shape (1, C, H, W).
20 |             class_idx: int. Index of the target class. Defaults to the index of the predicted class.
21 |
22 | # Return
23 |             Result of ScoreCAM (torch.Tensor) with shape (1, H, W).
24 | """
25 |
26 | b, c, h, w = input.size()
27 |
28 | logit = self.model_arch(input)
29 |
30 | if class_idx is None:
31 | predicted_class = logit.max(1)[-1]
32 | score = logit[:, predicted_class].squeeze()
33 | else:
34 | predicted_class = torch.LongTensor([class_idx])
35 | score = logit[:, predicted_class].squeeze()
36 |
37 | self.model_arch.zero_grad()
38 | score.backward(retain_graph=retain_graph)
39 | gradients = self.gradients['value']
40 | activations = self.activations['value']
41 | b, k, u, v = gradients.size()
42 | score_saliency_map = torch.zeros((1, 1, h, w))
43 |
44 | with torch.no_grad():
45 | for i in range(k):
46 | saliency_map = torch.unsqueeze(activations[:, i, :, :], 1)
47 | saliency_map = F.relu(saliency_map)
48 | saliency_map = F.interpolate(saliency_map, size=(h, w), mode='bilinear', align_corners=False)
49 |
50 |                 if saliency_map.max() == saliency_map.min():
51 | continue
52 |
53 | norm_saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())
54 | output = self.model_arch(input * norm_saliency_map)
55 | output = (output - output.mean()) / output.std()
56 | score = output[0][predicted_class]
57 | score = F.softmax(score)
58 | score_saliency_map += score * saliency_map
59 |
60 | score_saliency_map = F.relu(score_saliency_map)
61 | score_saliency_map_min, score_saliency_map_max = score_saliency_map.min(), score_saliency_map.max()
62 |
63 | if score_saliency_map_min == score_saliency_map_max:
64 | return None
65 |
66 | score_saliency_map = (score_saliency_map - score_saliency_map_min).div(score_saliency_map_max - score_saliency_map_min).data
67 | return score_saliency_map
68 |
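# Note on what forward() computes (Score-CAM): each activation map of the hooked
# layer is upsampled to the input size and min-max normalised, the input is masked
# by that map, and the model's response to the masked input for the target class is
# used as the weight of that map in the final saliency sum.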
--------------------------------------------------------------------------------
/xdeep/xlocal/gradient/explainers.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Thu Aug 02 16:21:22 2019
3 | @author: Haofan Wang - github.com/haofanwang
4 | """
5 |
6 |
7 | from xdeep.utils import *
8 |
9 | from xdeep.xlocal.gradient.backprop.guided_backprop import Backprop
10 | from xdeep.xlocal.gradient.backprop.integrate_grad import generate_integrated_grad
11 | from xdeep.xlocal.gradient.backprop.smooth_grad import generate_smooth_grad
12 |
13 | from xdeep.xlocal.gradient.cam.gradcam import GradCAM
14 | from xdeep.xlocal.gradient.cam.gradcampp import GradCAMpp
15 | from xdeep.xlocal.gradient.cam.scorecam import ScoreCAM
16 |
17 |
18 | class ImageInterpreter(object):
19 |
20 | """ Class for image explanation
21 |
22 |     ImageInterpreter provides a unified interface for calling different visualization methods.
23 |     It can be used in just a few lines and exposes '__init__()' and 'explain()'. Users can specify
24 |     the name of the visualization method and the target layer; if parameters are not specified,
25 |     default parameters are used.
26 | """
27 |
28 | def __init__(self, model):
29 | """Init
30 |
31 | # Arguments
32 | model: torchvision.models. A pretrained model.
33 |             result: torch.Tensor. Attribute that stores the generated saliency map.
34 |
35 | """
36 | self.model = model
37 | self.result = None
38 |
39 | def vanilla_backprop(self, model):
40 | return Backprop(model, guided=False)
41 |
42 | def guided_backprop(self, model):
43 | return Backprop(model, guided=True)
44 |
45 | def smooth_grad(self, model, input_, guided=False):
46 | explainer = Backprop(model, guided=guided)
47 | return generate_smooth_grad(explainer, input_)
48 |
49 | def smooth_guided_grad(self, model, input_, guided=True):
50 | explainer = Backprop(model, guided=guided)
51 | return generate_smooth_grad(explainer, input_)
52 |
53 | def integrate_grad(self, model, input_, guided=False):
54 | explainer = Backprop(model, guided=guided)
55 | return generate_integrated_grad(explainer, input_)
56 |
57 | def integrate_guided_grad(self, model, input_, guided=True):
58 | explainer = Backprop(model, guided=guided)
59 | return generate_integrated_grad(explainer, input_)
60 |
61 | def gradcam(self, model_dict):
62 | return GradCAM(model_dict)
63 |
64 | def gradcampp(self, model_dict):
65 | return GradCAMpp(model_dict)
66 |
67 | def scorecam(self, model_dict):
68 | return ScoreCAM(model_dict)
69 |
70 | def explain(self, image, method_name, viz=True, target_layer_name=None, target_class=None, save_path=None):
71 |
72 | """Function to call different local visualization methods.
73 |
74 | # Arguments
75 |         image: str or PIL.Image. The input image to ImageInterpreter. Users can either provide the path of the input
76 |                image or a PIL.Image, e.g. './test.jpg' or
77 |                Image.open('./test.jpg').convert('RGB').
78 |         method_name: str. The name of the interpretation method. Currently supported: 'vanilla_backprop', 'guided_backprop',
79 |                'smooth_grad', 'smooth_guided_grad', 'integrate_grad', 'integrate_guided_grad', 'gradcam',
80 |                'gradcampp', 'scorecam'.
81 |         viz: bool. Whether to visualize the result. Defaults to True.
82 |         target_layer_name: str. The layer at which gradients and activation maps are hooked. Defaults to the last
83 |                feature block of the network (e.g. 'layer4' for ResNet or 'features' for VGG-style models). Users can
84 |                also provide their own target layer, such as 'features_29' or 'layer4', depending on the architecture.
85 |         target_class: int. The index of the target class. Defaults to the index of the predicted class.
86 |         save_path: str. Path to save the saliency map. Defaults to None.
87 | """
88 |
89 | method_switch = {"vanilla_backprop": self.vanilla_backprop,
90 | "guided_backprop": self.guided_backprop,
91 | "smooth_grad": self.smooth_grad,
92 | "smooth_guided_grad": self.smooth_guided_grad,
93 | "integrate_grad": self.integrate_grad,
94 | "integrate_guided_grad": self.integrate_guided_grad,
95 | "gradcam": self.gradcam,
96 | "gradcampp": self.gradcampp,
97 | "scorecam": self.scorecam}
98 |
99 | if type(image) is str:
100 | image = load_image(image)
101 |
102 | norm_image = apply_transforms(image)
103 |
104 | if method_name not in method_switch.keys():
105 | raise Exception("Non-support method!", method_name)
106 |
107 | if 'backprop' in method_name:
108 | explainer = method_switch[method_name](self.model)
109 | self.result = explainer.calculate_gradients(norm_image, target_class=target_class)
110 | elif 'cam' in method_name:
111 | model_dict = dict(arch=self.model, layer_name=target_layer_name, input_size=(224, 224))
112 | explainer = method_switch[method_name](model_dict)
113 | self.result = explainer(norm_image, class_idx=target_class)
114 | else:
115 | self.result = method_switch[method_name](self.model, norm_image)
116 |
117 | if viz is True:
118 | visualize(norm_image, self.result, save_path=save_path)
119 |
120 |
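# A minimal end-to-end usage sketch, assuming a pretrained torchvision VGG16; the
# image path is a placeholder and 'features_29' is one valid layer name for that
# architecture.
import torchvision.models as models
from xdeep.xlocal.gradient.explainers import ImageInterpreter

model = models.vgg16(pretrained=True)
interpreter = ImageInterpreter(model)
interpreter.explain('path/to/image.jpg', method_name='gradcam',
                    target_layer_name='features_29', viz=True,
                    save_path='gradcam_result.jpg')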
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/.ipynb_checkpoints/Image_Tutorial-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import warnings\n",
10 | "warnings.filterwarnings('ignore')\n",
11 | "import tensorflow as tf\n",
12 | "slim = tf.contrib.slim\n",
13 | "import sys\n",
14 | "sys.path.append('/Users/zhangzijian/models/research/slim')\n",
15 | "import matplotlib.pyplot as plt\n",
16 | "from nets import inception\n",
17 | "from preprocessing import inception_preprocessing\n",
18 | "from datasets import imagenet\n",
19 | "import time"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 4,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import xdeep_image"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": []
37 | }
38 | ],
39 | "metadata": {
40 | "kernelspec": {
41 | "display_name": "Python 3",
42 | "language": "python",
43 | "name": "python3"
44 | },
45 | "language_info": {
46 | "codemirror_mode": {
47 | "name": "ipython",
48 | "version": 3
49 | },
50 | "file_extension": ".py",
51 | "mimetype": "text/x-python",
52 | "name": "python",
53 | "nbconvert_exporter": "python",
54 | "pygments_lexer": "ipython3",
55 | "version": "3.6.6"
56 | }
57 | },
58 | "nbformat": 4,
59 | "nbformat_minor": 2
60 | }
61 |
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/cle/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/cle/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/exceptions.py:
--------------------------------------------------------------------------------
1 | class XDeepError(Exception):
2 | """Raise for XDeep errors"""
3 |
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/explainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Base explainer for xdeep.
3 | """
4 | from .exceptions import XDeepError
5 |
6 |
7 | class Explainer(object):
8 | """A base explainer."""
9 |
10 | def __init__(self, predict_proba, class_names):
11 | """Init function.
12 |
13 | # Arguments
14 | predict_proba: Function. Classifier prediction probability function.
15 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
16 | """
17 | self.labels = list()
18 | self.explainer = None
19 | self.explanation = None
20 | self.instance = None
21 | self.original_pred = None
22 | self.class_names = class_names
23 | self.predict_proba = predict_proba
24 |
25 | def set_parameters(self, **kwargs):
26 | """Parameter setter.
27 |
28 | # Arguments
29 |             **kwargs: Other parameters that depend on which method you use. For more detail, please check the xdeep documentation.
30 | """
31 | pass
32 |
33 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
34 | """Explain function.
35 |
36 | # Arguments
37 | instance: One instance to be explained.
38 | top_labels: Integer. Number of labels you care about.
39 |             labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
40 |             **kwargs: Other parameters that depend on which method you use. For more detail, please check the xdeep documentation.
41 | """
42 | self.instance = instance
43 | try:
44 | proba = self.predict_proba([instance])[0]
45 | except:
46 | proba = self.predict_proba(instance)[0]
47 | self.original_pred = proba
48 | if top_labels is not None:
49 | self.labels = proba.argsort()[-top_labels:]
50 | else:
51 | self.labels = labels
52 |
53 | def show_explanation(self, **kwargs):
54 | """Visualization of explanation.
55 |
56 | # Arguments
57 |             **kwargs: Other parameters that depend on which method you use. For more detail, please check the xdeep documentation.
58 | """
59 |         if self.explanation is None:
60 | raise XDeepError("Please call explain function first.")
61 |
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/image_explainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/image_explainers/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/image_explainers/anchor_image.py:
--------------------------------------------------------------------------------
1 | # The implementation of anchor refers to the original authors' code on GitHub: https://github.com/marcotcr/anchor.
2 | # The copyright of the anchor algorithm is reserved by (c) 2018, Marco Tulio Correia Ribeiro.
3 | import copy
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 | from skimage.segmentation import mark_boundaries
8 | from anchor.anchor_image import AnchorImage
9 | from ..explainer import Explainer
10 |
11 |
12 | class XDeepAnchorImageExplainer(Explainer):
13 |
14 | def __init__(self, predict_proba, class_names):
15 | """Init function.
16 |
17 | # Arguments
18 | predict_proba: Function. Classifier prediction probability function.
19 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
20 | """
21 | Explainer.__init__(self, predict_proba, class_names)
22 | # Initialize explainer
23 | self.set_parameters()
24 |
25 | def set_parameters(self, **kwargs):
26 |         """Parameter setter for anchor_image.
27 |
28 | # Arguments
29 | **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor
30 | """
31 | self.explainer = AnchorImage(**kwargs)
32 |
33 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
34 | """Generate explanation for a prediction using certain method.
35 | Anchor does not use top_labels and labels.
36 |
37 | # Arguments
38 | instance: Array. An image to be explained.
39 | **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor.
40 | """
41 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
42 | try:
43 | self.labels = self.predict_proba([instance]).argsort()[0][-1:]
44 | except:
45 | self.labels = self.predict_proba(instance).argsort()[0][-1:]
46 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, **kwargs)
47 |
48 | def __show_image_no_axis(self, image, boundaries=None, save=None):
49 | """Show image with no axis
50 |
51 | # Arguments
52 | image: Array. The image you want to show.
53 |             boundaries: Array. Segmentation boundaries to overlay on the image (optional).
54 | save: Str. Save path.
55 | """
56 | fig = plt.figure()
57 | ax = plt.Axes(fig, [0., 0., 1., 1.])
58 | ax.set_axis_off()
59 | fig.add_axes(ax)
60 | if boundaries is not None:
61 | ax.imshow(mark_boundaries(image, boundaries))
62 | else:
63 | ax.imshow(image)
64 | if save is not None:
65 | plt.savefig(save)
66 | plt.show()
67 |
68 | def show_explanation(self, deprocess=None):
69 | """Visualization of explanation of anchor_image.
70 |
71 | # Arguments
72 | deprocess: Function. A function to deprocess the image.
73 | """
74 | Explainer.show_explanation(self)
75 | exp = self.explanation
76 | segments = exp[0]
77 | exp = exp[1]
78 | explainer = self.explainer
79 | predict_fn = self.predict_proba
80 | image = self.instance
81 |
82 | print()
83 | print("Anchor Explanation")
84 | print()
85 | if deprocess is not None:
86 | image = deprocess(copy.deepcopy(image))
87 | temp = copy.deepcopy(image)
88 | temp_img = copy.deepcopy(temp)
89 | temp[:] = np.average(temp)
90 | for x in exp:
91 | temp[segments == x[0]] = temp_img[segments == x[0]]
92 | print('Prediction ', predict_fn(np.expand_dims(image, 0))[0].argmax())
93 | assert hasattr(exp, '__len__')
94 | if len(exp) == 0:
95 | print("Fail to find Anchor")
96 | else:
97 | print('Confidence ', exp[-1][2])
98 | self.__show_image_no_axis(temp)
99 | if len(exp[-1][3]) != 0:
100 | print('Counter Examples:')
101 | for e in exp[-1][3]:
102 | data = e[:-1]
103 | temp = explainer.dummys[e[-1]].copy()
104 | for x in data.nonzero()[0]:
105 | temp[segments == x] = image[segments == x]
106 | self.__show_image_no_axis(temp)
107 | print('Prediction = ', predict_fn(np.expand_dims(temp, 0))[0].argmax())
108 |
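# A minimal usage sketch, assuming `predict_proba` is your classifier's probability
# function for a batch of images and `class_names` is your list of labels; `image`
# is a single image array.
#
#     explainer = XDeepAnchorImageExplainer(predict_proba, class_names)
#     explainer.explain(image)
#     explainer.show_explanation()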
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/image_explainers/cle_image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib.colors import LinearSegmentedColormap
4 |
5 | from ..cle.cle_image import CLEImageExplainer
6 | from ..explainer import Explainer
7 | from ..exceptions import XDeepError
8 |
9 | plt.rcParams.update({'font.size': 12})
10 |
11 |
12 | class XDeepCLEImageExplainer(Explainer):
13 |
14 | def __init__(self, predict_proba, class_names):
15 | """Init function.
16 |
17 | # Arguments
18 | predict_proba: Function. Classifier prediction probability function.
19 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
20 | """
21 | Explainer.__init__(self, predict_proba, class_names)
22 | # Initialize explainer
23 | self.set_parameters()
24 |
25 | def set_parameters(self, **kwargs):
26 | """Parameter setter for cle_image.
27 |
28 | # Arguments
29 |             **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
30 | """
31 | self.explainer = CLEImageExplainer(**kwargs)
32 |
33 | def explain(self, instance, top_labels=None, labels=(1,), care_segments=None, spans=(2,), include_original_feature=True, **kwargs):
34 | """Generate explanation for a prediction using certain method.
35 |
36 | # Arguments
37 | instance: Array. An image to be explained.
38 | top_labels: Integer. Number of labels you care about.
39 |             labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
40 |             **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
41 | """
42 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
43 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, labels=self.labels, top_labels=top_labels,
44 | care_segments=care_segments, spans=spans,
45 | include_original_feature=include_original_feature, **kwargs)
46 |
47 | def show_explanation(self, num_features=5, deprocess=None, positive_only=True):
48 |         """Visualization of explanation of cle_image.
49 |
50 | # Arguments
51 |             num_features: Integer. Number of features you care about.
52 |             deprocess: Function. A function to deprocess the image.
53 |             positive_only: Boolean. Whether to show only features with positive weight.
54 | """
55 | if not isinstance(num_features, int) or num_features < 1:
56 | raise XDeepError("Number of features should be positive integer.")
57 | Explainer.show_explanation(self)
58 | exp = self.explanation
59 | labels = self.labels
60 |
61 | print()
62 | print("CLE Explanation")
63 | print()
64 |
65 | colors = list()
66 | for l in np.linspace(1, 0, 100):
67 | colors.append((245 / 255, 39 / 255, 87 / 255, l))
68 | for l in np.linspace(0, 1, 100):
69 | colors.append((24 / 255, 196 / 255, 93 / 255, l))
70 | colormap = LinearSegmentedColormap.from_list("CLE", colors)
71 |
72 | # Plot
73 | fig, axes = plt.subplots(nrows=len(labels), ncols=num_features+1,
74 | figsize=((num_features+1)*3,len(labels)*3))
75 |
76 | assert hasattr(labels, '__len__')
77 | for n in range(len(labels)):
78 | # Inverse
79 | label = labels[-n-1]
80 | result = exp.intercept[label]
81 | local_exp = exp.local_exp[label]
82 | for item in local_exp:
83 | result += item[1]
84 | print("Explanation for label {}:".format(self.class_names[label]))
85 | print("Local Prediction: {:.3f}".format(result))
86 | print("Original Prediction: {:.3f}".format(self.original_pred[label]))
87 | print()
88 | img, mask, fs, ws, ms = exp.get_image_and_mask(
89 | label, positive_only=positive_only,
90 | num_features=num_features, hide_rest=False
91 | )
92 | if deprocess is not None:
93 | img = deprocess(img)
94 |
95 | mask = np.zeros_like(mask, dtype=np.float64)
96 | for index in range(len(ws)):
97 | mask[ms[index] == 1] += ws[index]
98 |
99 | max_val = max(np.max(ws), np.max(mask), abs(np.min(ws)), abs(np.min(mask)))
100 | min_val = -max_val
101 |
102 | if len(labels) == 1:
103 | axs = axes
104 | else:
105 | axs = axes[n]
106 | axs[0].axis('off')
107 | axs[0].set_title("{}".format(self.class_names[label]))
108 | axs[0].imshow(img, alpha=0.9)
109 | axs[0].imshow(mask, cmap=colormap, vmin=min_val, vmax=max_val)
110 |
111 | for idx in range(1, num_features+1):
112 | m = ms[idx-1] * ws[idx-1]
113 | title = "seg."
114 | for item in fs[idx-1]:
115 | title += str(item)+"&"
116 | title = title[:-1]
117 | axs[idx].set_title(title)
118 | axs[idx].imshow(img, alpha=0.5)
119 | axs[idx].imshow(m, cmap=colormap, vmin=min_val, vmax=max_val)
120 | axs[idx].axis('off')
121 | plt.show()
122 |
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/image_explainers/lime_image.py:
--------------------------------------------------------------------------------
1 | # The implementation of LIME refers to the original authors' code on GitHub: https://github.com/marcotcr/lime.
2 | # The copyright of the LIME algorithm is reserved by (c) 2016, Marco Tulio Correia Ribeiro.
3 | import matplotlib.pyplot as plt
4 | from skimage.segmentation import mark_boundaries
5 |
6 | from lime.lime_image import LimeImageExplainer
7 |
8 | from ..explainer import Explainer
9 |
10 | plt.rcParams.update({'font.size': 12})
11 |
12 |
13 | class XDeepLimeImageExplainer(Explainer):
14 |
15 | def __init__(self, predict_proba, class_names):
16 | """Init function.
17 |
18 | # Arguments
19 | predict_proba: Function. Classifier prediction probability function.
20 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
21 | """
22 | Explainer.__init__(self, predict_proba, class_names)
23 | # Initialize explainer
24 | self.set_parameters()
25 |
26 | def set_parameters(self, **kwargs):
27 |         """Parameter setter for lime_image.
28 |
29 | # Arguments
30 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
31 | """
32 | self.explainer = LimeImageExplainer(**kwargs)
33 |
34 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
35 | """Generate explanation for a prediction using certain method.
36 |
37 | # Arguments
38 | instance: Array. An image to be explained.
39 | top_labels: Integer. Number of labels you care about.
40 |             labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
41 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
42 | """
43 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
44 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba,
45 | top_labels=top_labels, labels=self.labels, **kwargs)
46 |
47 | def show_explanation(self, deprocess=None, positive_only=True, num_features=5, hide_rest=False):
48 | """Visualization of explanation of lime_image.
49 |
50 | # Arguments
51 | deprocess: Function. A function to deprocess the image.
52 |             positive_only: Boolean. Whether to show only features with positive weight.
53 |             num_features: Integer. Number of features you care about.
54 |             hide_rest: Boolean. Whether to hide the rest of the image.
55 | """
56 | Explainer.show_explanation(self)
57 | exp = self.explanation
58 | labels = self.labels
59 |
60 | print()
61 | print("LIME Explanation")
62 | print()
63 |
64 | fig, axs = plt.subplots(nrows=1, ncols=len(labels), figsize=(len(labels)*3, 3))
65 |
66 | assert hasattr(labels, '__len__')
67 | for idx in range(len(labels)):
68 | # Inverse
69 | label = labels[-idx-1]
70 | result = exp.intercept[label]
71 | local_exp = exp.local_exp[label]
72 | for item in local_exp:
73 | result += item[1]
74 | print("Explanation for label {}:".format(self.class_names[label]))
75 | print("Local Prediction: {:.3f}".format(result))
76 | print("Original Prediction: {:.3f}".format(self.original_pred[label]))
77 | print()
78 | temp, mask = exp.get_image_and_mask(
79 | label, positive_only=positive_only,
80 | num_features=num_features, hide_rest=hide_rest
81 | )
82 | if deprocess is not None:
83 | temp = deprocess(temp)
84 | if len(labels) == 1:
85 | axs.imshow(mark_boundaries(temp, mask), alpha=0.9)
86 | axs.axis('off')
87 | axs.set_title("{}".format(self.class_names[label]))
88 | else:
89 | axs[idx].imshow(mark_boundaries(temp, mask), alpha=0.9)
90 | axs[idx].axis('off')
91 | axs[idx].set_title("{}".format(self.class_names[label]))
92 | plt.show()
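# A minimal usage sketch, assuming `predict_proba` is your classifier's probability
# function over image batches and `class_names` is your list of labels; num_samples
# is simply forwarded to lime's explain_instance.
#
#     explainer = XDeepLimeImageExplainer(predict_proba, class_names)
#     explainer.explain(image, top_labels=2, num_samples=1000)
#     explainer.show_explanation(num_features=5)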
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/image_explainers/shap_image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from PIL import Image
4 | from matplotlib.colors import LinearSegmentedColormap
5 |
6 | import shap
7 |
8 | from ..explainer import Explainer
9 |
10 |
11 | class XDeepShapImageExplainer(Explainer):
12 |
13 | def __init__(self, predict_proba, class_names, n_segment, segment):
14 | """Init function.
15 |
16 | # Arguments
17 | predict_proba: Function. Classifier prediction probability function.
18 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
19 | n_segment: Integer. Number of segments in the image.
20 | segment: Array. An array with 2 dimensions, segments_slic of the image.
21 | """
22 | Explainer.__init__(self, predict_proba, class_names)
23 | self.n_segment = n_segment
24 | self.segment = segment
25 |
26 | @staticmethod
27 | def _mask_image(zs, segmentation, image, background=None):
28 | """Mask image function for shap."""
29 | if background is None:
30 | background = image.mean((0, 1))
31 | out = np.zeros((zs.shape[0], image.shape[0], image.shape[1], image.shape[2]))
32 | for i in range(zs.shape[0]):
33 | out[i, :, :, :] = image
34 | for j in range(zs.shape[1]):
35 | if zs[i, j] == 0:
36 | out[i][segmentation == j, :] = background
37 | return out
38 |
39 | def set_parameters(self, n_segment, segment):
40 | """Parameter setter for shap_image. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
41 |
42 | # Arguments
43 |             n_segment: Integer. Number of segments in the image.
44 | segment: Array. An array with 2 dimensions, segments_slic of the image.
45 | """
46 | self.n_segment = n_segment
47 | self.segment = segment
48 |
49 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
50 | """Generate explanation for a prediction using certain method.
51 |
52 | # Arguments
53 | instance: Array. An image to be explained.
54 | top_labels: Integer. Number of labels you care about.
55 |             labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
56 |             **kwargs: Parameters setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
57 | """
58 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
59 |
60 | def f(z):
61 | masked = self._mask_image(z, self.segment, self.instance, 255)
62 | return self.predict_proba(masked)
63 |
64 | self.explainer = shap.KernelExplainer(f, np.zeros((1, self.n_segment)), **kwargs)
65 | self.explanation = self.explainer.shap_values(np.ones((1, self.n_segment)), **kwargs)
66 |
67 | def __fill_segmentation(self, values, segmentation):
68 | """Fill segmentation
69 |
70 | # Arguments
71 | values: Array, weights.
72 | segmentation: Array. An array with 2 dimensions, segments_slic of the image.
73 |
74 | # Return
75 |             out: Array with the same shape as segmentation, filled with the corresponding weights.
76 | """
77 | out = np.zeros(segmentation.shape)
78 | for i in range(len(values)):
79 | out[segmentation == i] = values[i]
80 | return out
81 |
82 | def show_explanation(self, deprocess=None):
83 | """Visualization of explanation of shap.
84 |
85 | # Arguments
86 | deprocess: Function. A function to deprocess the image.
87 | """
88 | Explainer.show_explanation(self)
89 | shap_values = self.explanation
90 | labels = self.labels
91 | segments_slic = self.segment
92 | img = self.instance
93 |
94 | if deprocess is not None:
95 | img = deprocess(img)
96 | img = Image.fromarray(np.uint8(img*255))
97 | colors = []
98 | for l in np.linspace(1, 0, 100):
99 | colors.append((245 / 255, 39 / 255, 87 / 255, l))
100 | for l in np.linspace(0, 1, 100):
101 | colors.append((24 / 255, 196 / 255, 93 / 255, l))
102 | colormap = LinearSegmentedColormap.from_list("shap", colors)
103 | # Plot our explanations
104 | fig, axes = plt.subplots(nrows=1, ncols=len(labels)+1, figsize=(12, 4))
105 | inds = labels
106 | axes[0].imshow(img)
107 | axes[0].axis('off')
108 | max_val = np.max([np.max(np.abs(shap_values[i][:, :-1])) for i in range(len(shap_values))])
109 | assert hasattr(labels, '__len__')
110 | for i in range(len(labels)):
111 | m = self.__fill_segmentation(shap_values[inds[i]][0], segments_slic)
112 | axes[i + 1].set_title(self.class_names[labels[i]])
113 | axes[i + 1].imshow(img, alpha=0.5)
114 | axes[i + 1].imshow(m, cmap=colormap, vmin=-max_val, vmax=max_val)
115 | axes[i + 1].axis('off')
116 | plt.show()
117 |
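# A minimal usage sketch, assuming `predict_proba`, `class_names` and an image
# array `img` are yours; the segmentation comes from scikit-image and the segment
# count is an arbitrary choice (slic may return slightly fewer segments).
#
#     from skimage.segmentation import slic
#     segments = slic(img, n_segments=50, compactness=30)
#     explainer = XDeepShapImageExplainer(predict_proba, class_names,
#                                         n_segment=50, segment=segments)
#     explainer.explain(img, top_labels=2)
#     explainer.show_explanation()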
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/tabular_explainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/tabular_explainers/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/tabular_explainers/anchor_tabular.py:
--------------------------------------------------------------------------------
1 | # The implementation of anchor refers to the original authors' code on GitHub: https://github.com/marcotcr/anchor.
2 | # The copyright of the anchor algorithm is reserved by (c) 2018, Marco Tulio Correia Ribeiro.
3 | import numpy as np
4 |
5 | from anchor.anchor_tabular import AnchorTabularExplainer
6 |
7 | from ..explainer import Explainer
8 | from ..exceptions import XDeepError
9 |
10 |
11 | class XDeepAnchorTabularExplainer(Explainer):
12 |
13 | def __init__(self, class_names, feature_names, data, categorical_names=None):
14 | """Init function.
15 |
16 | # Arguments
17 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
18 |             feature_names: List. A list of names (strings) corresponding to the columns in the training data.
19 | data: Array. Full data including train data, validation_data and test data.
20 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x.
21 | """
22 | Explainer.__init__(self, None, class_names)
23 | self.feature_names = feature_names
24 | self.data = data
25 | self.categorical_names = categorical_names
26 | # Initialize explainer
27 | self.set_parameters()
28 |
29 | def set_parameters(self, **kwargs):
30 | """Parameter setter for anchor_tabular.
31 |
32 | # Arguments
33 |             **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor.
34 | """
35 | class_names = kwargs.pop("class_names", self.class_names)
36 | feature_names = kwargs.pop("feature_names", self.feature_names)
37 | data = kwargs.pop("data", self.data)
38 | categorical_names = kwargs.pop("categorical_names", self.categorical_names)
39 | self.explainer = AnchorTabularExplainer(class_names, feature_names, data=data,
40 | categorical_names=categorical_names)
41 |
42 | def get_anchor_encoder(self, train_data, train_labels, validation_data, validation_labels, discretizer='quartile'):
43 | """Get encoder for tabular data.
44 |
45 | If you want to use 'anchor' to explain a tabular classifier, you need to use this encoder to encode your data and train another model on the encoded data.
46 |
47 | # Arguments
48 | train_data: Array. Train data.
49 | train_labels: Array. Train labels.
50 | validation_data: Array. Validation set.
51 | validation_labels: Array. Validation labels.
52 | discretizer: Str. Discretizer for data. Please choose between 'quartile' and 'decile'.
53 |
54 | # Return
55 | An encoder object with a 'transform' function.
56 | """
57 | self.explainer.fit(train_data, train_labels, validation_data, validation_labels,
58 | discretizer=discretizer)
59 | return self.explainer.encoder
60 |
61 | def set_anchor_predict_proba(self, predict_proba):
62 | """Predict function setter.
63 |
64 | # Arguments
65 | predict_proba: Function. A classifier prediction probability function.
66 | """
67 | self.predict_proba = predict_proba
68 |
69 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
70 | """Generate explanation for a prediction using certain method.
71 | Anchor does not use top_labels and labels.
72 |
73 | # Arguments
74 | instance: Array. One row of tabular data to be explained.
75 | **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor.
76 | """
77 | if self.predict_proba is None:
78 | raise XDeepError("Please call set_anchor_predict_proba to pass in your new predict function.")
79 |
80 | def predict_label(x):
81 | return np.argmax(self.predict_proba(x), axis=1)
82 | self.instance = instance
83 | try:
84 | self.labels = self.predict_proba(self.explainer.encoder.transform([instance])).argsort()[0][-1:]
85 | except Exception:
86 | self.labels = self.predict_proba(self.explainer.encoder.transform(instance)).argsort()[0][-1:]
87 | self.explanation = self.explainer.explain_instance(instance, predict_label, **kwargs)
88 |
89 | def show_explanation(self, show_in_note_book=True, verbose=True):
90 | """Visualization of explanation of anchor_tabular.
91 |
92 | # Arguments
93 | show_in_note_book: Boolean. Whether to show in a Jupyter notebook.
94 | verbose: Boolean. Whether to print out examples and counterexamples.
95 | """
96 | Explainer.show_explanation(self)
97 | exp = self.explanation
98 |
99 | print()
100 | print("Anchor Explanation")
101 | print()
102 | if verbose:
103 | print()
104 | print('Examples where the anchor applies and the model predicts the same as the instance:')
105 | print()
106 | print(exp.examples(only_same_prediction=True))
107 | print()
108 | print('Examples where the anchor applies and the model predicts differently from the instance:')
109 | print()
110 | print(exp.examples(only_different_prediction=True))
111 |
112 | label = self.labels[0]
113 | print('Prediction: {}'.format(label))
114 | print('Anchor: %s' % (' AND '.join(exp.names())))
115 | print('Precision: %.2f' % exp.precision())
116 | print('Coverage: %.2f' % exp.coverage())
117 | if show_in_note_book:
118 | pass
119 |
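To make the encoder workflow described in the docstrings above concrete, here is a minimal usage sketch. It is not part of the source file: the dataset, the RandomForestClassifier, and the class and feature names are placeholder assumptions, the import path simply mirrors this repository's layout, and the exact behavior of the encoder depends on the anchor package version xdeep targets.

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from xdeep.xlocal.perturbation.tabular_explainers.anchor_tabular import XDeepAnchorTabularExplainer

# Placeholder tabular dataset: 500 rows, 4 integer-valued features, binary target.
rng = np.random.RandomState(0)
data = rng.randint(0, 4, size=(500, 4)).astype(float)
target = (data[:, 0] > 1).astype(int)
train_x, val_x = data[:400], data[400:]
train_y, val_y = target[:400], target[400:]

explainer = XDeepAnchorTabularExplainer(
    class_names=['negative', 'positive'],
    feature_names=['f0', 'f1', 'f2', 'f3'],
    data=data,
)

# Anchor works on its own discretized encoding, so fetch the encoder, train a
# new model on the encoded data, and hand that model's predict_proba back.
encoder = explainer.get_anchor_encoder(train_x, train_y, val_x, val_y)
model = RandomForestClassifier(n_estimators=50).fit(encoder.transform(train_x), train_y)
explainer.set_anchor_predict_proba(model.predict_proba)

explainer.explain(val_x[0])
explainer.show_explanation(show_in_note_book=False)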
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/tabular_explainers/cle_tabular.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from ..cle.cle_tabular import CLETabularExplainer
5 | from ..explainer import Explainer
6 |
7 |
8 | class XDeepCLETabularExplainer(Explainer):
9 |
10 | def __init__(self, predict_proba, class_names, feature_names, train,
11 | categorical_features=None, categorical_names=None):
12 | """Init function.
13 |
14 | # Arguments
15 | predict_proba: Function. A classifier prediction probability function.
16 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
17 | feature_names: List. A list of names (strings) corresponding to the columns in the training data.
18 | train: Array. Train data.
19 | categorical_features: List. A list of indices (ints) corresponding to the categorical columns. Everything else will be considered continuous. Values in these columns MUST be integers.
20 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x.
21 | """
22 | Explainer.__init__(self, predict_proba, class_names)
23 | self.train = train
24 | self.feature_names = feature_names
25 | self.categorical_features = categorical_features
26 | self.categorical_names = categorical_names
27 | # Initialize explainer
28 | self.set_parameters()
29 |
30 | def set_parameters(self, **kwargs):
31 | """Parameter setter for cle_tabular.
32 |
33 | # Arguments
34 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
35 | """
36 | train = kwargs.pop("train", self.train)
37 | class_names = kwargs.pop("class_names", self.class_names)
38 | feature_names = kwargs.pop("feature_names", self.feature_names)
39 | categorical_features = kwargs.pop("categorical_features", self.categorical_features)
40 | categorical_names = kwargs.pop("categorical_names", self.categorical_names)
41 | self.explainer = CLETabularExplainer(train, class_names=class_names,
42 | feature_names=feature_names, categorical_features=categorical_features,
43 | categorical_names=categorical_names, **kwargs)
44 |
45 | def explain(self, instance, top_labels=None, labels=(1,), care_cols=None, spans=(2,), include_original_feature=True, **kwargs):
46 | """Generate explanation for a prediction using certain method.
47 |
48 | # Arguments
49 | instance: Array. One row of tabular data to be explained.
50 | top_labels: Integer. Number of labels you care about.
51 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
52 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
53 | """
54 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
55 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, labels=self.labels,
56 | top_labels=top_labels,
57 | care_cols=care_cols, spans=spans,
58 | include_original_feature=include_original_feature, **kwargs)
59 |
60 | def show_explanation(self, span=3, plot=True):
61 | """Visualization of explanation of cle_text.
62 |
63 | # Arguments
64 | span: Boolean. Each row shows how many features.
65 | plot: Boolean. Whether plot a figure.
66 | """
67 | Explainer.show_explanation(self)
68 | exp = self.explanation
69 | labels = self.labels
70 |
71 | print()
72 | print("CLE Explanation")
73 | print("Instance: {}".format(self.instance))
74 | print()
75 |
76 | words = exp.domain_mapper.exp_feature_names
77 | D = len(words)
78 |
79 | # parts is a list which contains all the combinations of features.
80 | parts = self.explainer.all_combinations
81 |
82 | assert hasattr(labels, '__len__')
83 | for i in range(len(labels)):
84 | label = labels[i]
85 | result = exp.intercept[label]
86 | local_exp = exp.local_exp[label]
87 | for item in local_exp:
88 | result += item[1]
89 | print("Explanation for label {}:".format(self.class_names[label]))
90 | print("Local Prediction: {:.3f}".format(result))
91 | print("Original Prediction: {:.3f}".format(self.original_pred[label]))
92 | print()
93 | local_exp = exp.local_exp[label]
94 | for idx in range(len(local_exp)):
95 | if self.explainer.include_original_feature:
96 | if local_exp[idx][0] >= D:
97 | index = local_exp[idx][0] - D
98 | key = ""
99 | for item in parts[index]:
100 | key += words[item] + " AND "
101 | key = key[:-5]
102 | local_exp[idx] = (key, local_exp[idx][1])
103 | else:
104 | local_exp[idx] = (words[local_exp[idx][0]], local_exp[idx][1])
105 | else:
106 | index = local_exp[idx][0]
107 | key = ""
108 | for item in parts[index]:
109 | key += words[item] + " AND "
110 | key = key[:-5]
111 | local_exp[idx] = (key, local_exp[idx][1])
112 |
113 | for idx in range(len(local_exp)):
114 | print(" {:30} : {:.3f} |".format(local_exp[idx][0], local_exp[idx][1]), end="")
115 | if idx % span == span - 1:
116 | print()
117 | print()
118 | if len(local_exp) % span != 0:
119 | print()
120 |
121 | if plot:
122 | fig, axs = plt.subplots(nrows=1, ncols=len(labels))
123 | vals = [x[1] for x in local_exp]
124 | names = [x[0] for x in local_exp]
125 | vals.reverse()
126 | names.reverse()
127 | pos = np.arange(len(vals)) + .5
128 | colors = ['green' if x > 0 else 'red' for x in vals]
129 | if len(labels) == 1:
130 | ax = axs
131 | else:
132 | ax = axs[i]
133 | ax.barh(pos, vals, align='center', color=colors)
134 | ax.set_yticks(pos)
135 | ax.set_yticklabels(names)
136 | title = 'Local explanation for class %s' % self.class_names[label]
137 | ax.set_title(title)
138 | if plot:
139 | plt.show()
140 |
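A minimal usage sketch of the class above. The data, model, and feature names are invented for illustration; spans=(2,) is forwarded to CLE's explain_instance and, judging from the all_combinations bookkeeping above, controls the size of the feature groups that are scored together.

import numpy as np
from sklearn.linear_model import LogisticRegression

from xdeep.xlocal.perturbation.tabular_explainers.cle_tabular import XDeepCLETabularExplainer

# Placeholder training data and model.
rng = np.random.RandomState(0)
train = rng.rand(300, 4)
target = (train[:, 0] + train[:, 1] > 1).astype(int)
model = LogisticRegression().fit(train, target)

explainer = XDeepCLETabularExplainer(
    predict_proba=model.predict_proba,
    class_names=['low', 'high'],
    feature_names=['f0', 'f1', 'f2', 'f3'],
    train=train,
)

# spans=(2,) also scores pairs of features in addition to the single features.
explainer.explain(train[0], top_labels=1, spans=(2,))
explainer.show_explanation(span=3, plot=False)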
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/tabular_explainers/lime_tabular.py:
--------------------------------------------------------------------------------
1 | # The implementation of LIME refers to the original authors' code on GitHub: https://github.com/marcotcr/lime.
2 | # The copyright of the LIME algorithm is reserved for (c) 2016, Marco Tulio Correia Ribeiro.
3 | import matplotlib.pyplot as plt
4 |
5 | from lime.lime_tabular import LimeTabularExplainer
6 |
7 | from ..explainer import Explainer
8 |
9 |
10 | class XDeepLimeTabularExplainer(Explainer):
11 |
12 | def __init__(self, predict_proba, class_names, feature_names, train,
13 | categorical_features=None, categorical_names=None):
14 | """Init function.
15 |
16 | # Arguments
17 | predict_proba: Function. A classifier prediction probability function.
18 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
19 | feature_names: List. A list of names (strings) corresponding to the columns in the training data.
20 | train: Array. Train data.
21 | categorical_features: List. A list of indices (ints) corresponding to the categorical columns. Everything else will be considered continuous. Values in these columns MUST be integers.
22 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x.
23 | """
24 | Explainer.__init__(self, predict_proba, class_names)
25 | self.train = train
26 | self.feature_names = feature_names
27 | self.categorical_features = categorical_features
28 | self.categorical_names = categorical_names
29 | # Initialize explainer
30 | self.set_parameters()
31 |
32 | def set_parameters(self, **kwargs):
33 | """Parameter setter for lime_tabular.
34 |
35 | # Arguments
36 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
37 | """
38 | train = kwargs.pop("train", self.train)
39 | class_names = kwargs.pop("class_names", self.class_names)
40 | feature_names = kwargs.pop("feature_names", self.feature_names)
41 | categorical_features = kwargs.pop("categorical_features", self.categorical_features)
42 | categorical_names = kwargs.pop("categorical_names", self.categorical_names)
43 | self.explainer = LimeTabularExplainer(train, class_names=class_names,
44 | feature_names=feature_names, categorical_features=categorical_features,
45 | categorical_names=categorical_names, **kwargs)
46 |
47 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
48 | """Generate explanation for a prediction using certain method.
49 |
50 | # Arguments
51 | instance: Array. One row of tabular data to be explained.
52 | top_labels: Integer. Number of labels you care about.
53 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
54 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
55 | """
56 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
57 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba,
58 | top_labels=top_labels, labels=self.labels, **kwargs)
59 |
60 | def show_explanation(self, span=3):
61 | """Visualization of explanation of lime_text.
62 |
63 | # Arguments
64 | span: Integer. Each row shows how many features.
65 | """
66 | Explainer.show_explanation(self)
67 | exp = self.explanation
68 | labels = self.labels
69 |
70 | print()
71 | print("LIME Explanation")
72 | print("Instance: {}".format(self.instance))
73 | print()
74 |
75 | assert hasattr(labels, '__len__')
76 | for label in labels:
77 | result = exp.intercept[label]
78 | local_exp = exp.local_exp[label]
79 | for item in local_exp:
80 | result += item[1]
81 | print("Explanation for label {}:".format(self.class_names[label]))
82 | print("Local Prediction: {:.3f}".format(result))
83 | print("Original Prediction: {:.3f}".format(self.original_pred[label]))
84 | print()
85 | exp_list = exp.as_list(label=label)
86 | for idx in range(len(exp_list)):
87 | print(" {:20} : {:.3f} |".format(exp_list[idx][0], exp_list[idx][1]), end="")
88 | if idx % span == span - 1:
89 | print()
90 | print()
91 | exp.as_pyplot_figure(label=label)
92 | plt.show()
93 |
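For comparison, the same call sequence with the LIME wrapper. Everything except the XDeepLimeTabularExplainer API is a placeholder; num_features is forwarded to lime's explain_instance.

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from xdeep.xlocal.perturbation.tabular_explainers.lime_tabular import XDeepLimeTabularExplainer

# Placeholder training data and model.
rng = np.random.RandomState(0)
train = rng.rand(300, 4)
target = (train[:, 0] > 0.5).astype(int)
model = RandomForestClassifier(n_estimators=30).fit(train, target)

explainer = XDeepLimeTabularExplainer(
    predict_proba=model.predict_proba,
    class_names=['low', 'high'],
    feature_names=['f0', 'f1', 'f2', 'f3'],
    train=train,
)

explainer.explain(train[0], top_labels=1, num_features=4)
explainer.show_explanation(span=3)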
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/tabular_explainers/shap_tabular.py:
--------------------------------------------------------------------------------
1 | from IPython.display import display
2 | from scipy.sparse import csr_matrix
3 |
4 | import shap
5 |
6 | from ..explainer import Explainer
7 |
8 |
9 | class XDeepShapTabularExplainer(Explainer):
10 |
11 | def __init__(self, predict_proba, class_names, train):
12 | """Init function.
13 |
14 | # Arguments
15 | predict_proba: Function. Classifier prediction probability function.
16 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
17 | train: Array. Training data, used as the background dataset for the kernel explainer.
18 | """
19 | Explainer.__init__(self, predict_proba, class_names)
20 | self.train = train
21 | # Initialize explainer
22 | self.set_parameters()
23 |
24 | def set_parameters(self, **kwargs):
25 | """Parameter setter for shap.
26 |
27 | # Arguments
28 | **kwargs: Parameters setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
29 | """
30 | predict_proba = kwargs.pop("predict_proba", self.predict_proba)
31 | train = kwargs.pop("train", self.train)
32 | self.explainer = shap.KernelExplainer(predict_proba, train, **kwargs)
33 |
34 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
35 | """Generate explanation for a prediction using certain method.
36 |
37 | # Arguments
38 | instance: Array. One row of tabular data to be explained.
39 | top_labels: Integer. Number of labels you care about.
40 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
41 | **kwargs: Parameters setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
42 | """
43 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
44 | self.explanation = self.explainer.shap_values(instance, **kwargs)
45 |
46 | def show_explanation(self, show_in_note_book=True):
47 | """Visualization of explanation of shap.
48 |
49 | # Arguments
50 | show_in_note_book: Boolean. Whether to show in a Jupyter notebook.
51 | """
52 | Explainer.show_explanation(self)
53 | shap_values = self.explanation
54 | expected_value = self.explainer.expected_value
55 | class_names = self.class_names
56 | labels = self.labels
57 | instance = self.instance
58 |
59 | shap.initjs()
60 | print()
61 | print("Shap Explanation")
62 | print()
63 |
64 | assert hasattr(labels, '__len__')
65 | if len(instance.shape) == 1 or instance.shape[0] == 1:
66 | for item in labels:
67 | print("Shap value for label {}:".format(class_names[item]))
68 | print(shap_values[item])
69 |
70 | if show_in_note_book:
71 | for item in labels:
72 | if isinstance(instance, csr_matrix):
73 | display(shap.force_plot(expected_value[item], shap_values[item], instance.A))
74 | else:
75 | display(shap.force_plot(expected_value[item], shap_values[item], instance))
76 | else:
77 | if show_in_note_book:
78 | shap.summary_plot(shap_values, instance)
79 |
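A minimal usage sketch. The data and model are placeholders, only a small slice of the training data is used as the SHAP background set because KernelExplainer scales poorly with background size, and nsamples is forwarded to shap_values; exact output shapes depend on the installed shap version.

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from xdeep.xlocal.perturbation.tabular_explainers.shap_tabular import XDeepShapTabularExplainer

# Placeholder training data and model.
rng = np.random.RandomState(0)
train = rng.rand(200, 4)
target = (train[:, 0] > 0.5).astype(int)
model = RandomForestClassifier(n_estimators=30).fit(train, target)

# Use a small background sample to keep the kernel explainer fast.
explainer = XDeepShapTabularExplainer(model.predict_proba, ['low', 'high'], train[:50])
explainer.explain(train[0], top_labels=1, nsamples=100)
explainer.show_explanation(show_in_note_book=False)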
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/text_explainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datamllab/xdeep/c2f0323b0ec55d684ee24dbe35a6046fe0074663/xdeep/xlocal/perturbation/text_explainers/__init__.py
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/text_explainers/anchor_text.py:
--------------------------------------------------------------------------------
1 | # The implementation of anchor refers to the original authors' code on GitHub: https://github.com/marcotcr/anchor.
2 | # The copyright of the anchor algorithm is reserved for (c) 2018, Marco Tulio Correia Ribeiro.
3 | import spacy
4 | import numpy as np
5 |
6 | from anchor.anchor_text import AnchorText
7 |
8 | from ..explainer import Explainer
9 |
10 |
11 | class XDeepAnchorTextExplainer(Explainer):
12 |
13 | def __init__(self, predict_proba, class_names):
14 | """Init function.
15 |
16 | # Arguments
17 | predict_proba: Function. Classifier prediction probability function.
18 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
19 | """
20 | Explainer.__init__(self, predict_proba, class_names)
21 | # Initialize explainer
22 | self.set_parameters()
23 |
24 | def set_parameters(self, nlp=None, **kwargs):
25 | """Parameter setter for anchor_text.
26 |
27 | # Arguments
28 | nlp: Object. A spacy model object.
29 | **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor
30 | """
31 | if nlp is None:
32 | print("Use default nlp = spacy.load('en_core_web_sm') to initialize anchor")
33 | try:
34 | nlp = spacy.load('en_core_web_sm')
35 | except OSError:
36 | print("Cannot load en_core_web_sm")
37 | print("Please run the command 'python -m spacy download en' as admin, then rerun your code.")
38 |
39 | class_names = kwargs.pop("class_names", self.class_names)
40 | if nlp is not None:
41 | self.explainer = AnchorText(nlp, class_names=class_names, **kwargs)
42 |
43 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
44 | """Generate explanation for a prediction using certain method.
45 | Anchor does not use top_labels and labels.
46 |
47 | # Arguments
48 | instance: Str. A raw text string to be explained.
49 | **kwargs: Parameters setter. For more detail, please check https://github.com/marcotcr/anchor.
50 | """
51 | def predict_label(x):
52 | return np.argmax(self.predict_proba(x), axis=1)
53 |
54 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
55 | try:
56 | self.labels = self.predict_proba([instance]).argsort()[0][-1:]
57 | except Exception:
58 | self.labels = self.predict_proba(instance).argsort()[0][-1:]
59 | if self.explainer is not None:
60 | self.explanation = self.explainer.explain_instance(instance, predict_label, **kwargs)
61 |
62 | def show_explanation(self, show_in_note_book=True, verbose=True):
63 | """Visualization of explanation of anchor_text.
64 |
65 | # Arguments
66 | show_in_note_book: Boolean. Whether to show in a Jupyter notebook.
67 | verbose: Boolean. Whether to print out examples and counterexamples.
68 | """
69 | Explainer.show_explanation(self)
70 | exp = self.explanation
71 | instance = self.instance
72 |
73 | print()
74 | print("Anchor Explanation")
75 | print("Instance: {}".format(instance))
76 | print()
77 | if verbose:
78 | print()
79 | print('Examples where the anchor applies and the model predicts the same as the instance:')
80 | print()
81 | print('\n'.join([x[0] for x in exp.examples(only_same_prediction=True)]))
82 | print()
83 | print('Examples where the anchor applies and the model predicts differently from the instance:')
84 | print()
85 | print('\n'.join([x[0] for x in exp.examples(only_different_prediction=True)]))
86 |
87 | label = self.labels[0]
88 | print('Prediction: {}'.format(label))
89 | print('Anchor: %s' % (' AND '.join(exp.names())))
90 | print('Precision: %.2f' % exp.precision())
91 | print('Coverage: %.2f' % exp.coverage())
92 | if show_in_note_book:
93 | exp.show_in_notebook()
94 |
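A minimal usage sketch, assuming a scikit-learn text pipeline and that spacy's 'en_core_web_sm' model is installed (otherwise the explainer above falls back to the printed hint). The toy corpus is a placeholder.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from xdeep.xlocal.perturbation.text_explainers.anchor_text import XDeepAnchorTextExplainer

# Placeholder sentiment corpus and pipeline; predict_proba accepts raw strings.
texts = ['good movie', 'great plot', 'bad acting', 'terrible film'] * 25
target = [1, 1, 0, 0] * 25
pipe = make_pipeline(CountVectorizer(), LogisticRegression()).fit(texts, target)

explainer = XDeepAnchorTextExplainer(pipe.predict_proba, ['negative', 'positive'])
explainer.explain('a good movie with a great plot')
explainer.show_explanation(show_in_note_book=False)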
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/text_explainers/cle_text.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | from ..cle.cle_text import CLETextExplainer
5 |
6 | from ..explainer import Explainer
7 |
8 |
9 | class XDeepCLETextExplainer(Explainer):
10 |
11 | def __init__(self, predict_proba, class_names):
12 | """Init function.
13 |
14 | # Arguments
15 | predict_proba: Function. Classifier prediction probability function.
16 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
17 | """
18 | Explainer.__init__(self, predict_proba, class_names)
19 | # Initialize explainer
20 | self.set_parameters()
21 |
22 | def set_parameters(self, **kwargs):
23 | """Parameter setter for cle_text.
24 |
25 | # Arguments
26 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
27 | """
28 | class_names = kwargs.pop("class_names", self.class_names)
29 | self.explainer = CLETextExplainer(class_names=class_names, **kwargs)
30 |
31 | def explain(self, instance, top_labels=None, labels=(1,), care_words=None, spans=(2,), include_original_feature=True, **kwargs):
32 | """Generate explanation for a prediction using certain method.
33 |
34 | # Arguments
35 | instance: Str. A raw text string to be explained.
36 | top_labels: Integer. Number of labels you care about.
37 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
38 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
39 | """
40 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
41 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba, labels=self.labels, top_labels=top_labels,
42 | care_words=care_words, spans=spans, include_original_feature=include_original_feature,
43 | **kwargs)
44 |
45 | def show_explanation(self, span=3, plot=True):
46 | """Visualization of explanation of cle_text.
47 |
48 | # Arguments
49 | span: Integer. How many features to show per row.
50 | plot: Boolean. Whether to plot a figure.
51 | """
52 | Explainer.show_explanation(self)
53 | exp = self.explanation
54 | labels = self.labels
55 |
56 | print()
57 | print("CLE Explanation")
58 | print("Instance: {}".format(self.instance))
59 | print()
60 |
61 | words = exp.domain_mapper.indexed_string.inverse_vocab
62 | D = len(words)
63 |
64 | # parts is a list which contains all the combinations of features.
65 | parts = self.explainer.all_combinations
66 |
67 | assert hasattr(labels, '__len__')
68 | for i in range(len(labels)):
69 | label = labels[i]
70 | result = exp.intercept[label]
71 | local_exp = exp.local_exp[label]
72 | for item in local_exp:
73 | result += item[1]
74 | print("Explanation for label {}:".format(self.class_names[label]))
75 | print("Local Prediction: {:.3f}".format(result))
76 | print("Original Prediction: {:.3f}".format(self.original_pred[label]))
77 | print()
78 | for idx in range(len(local_exp)):
79 | if self.explainer.include_original_feature:
80 | if local_exp[idx][0] >= D:
81 | index = local_exp[idx][0] - D
82 | key = ""
83 | for item in parts[index]:
84 | key += words[item]+" AND "
85 | key = key[:-5]
86 | local_exp[idx] = (key, local_exp[idx][1])
87 | else:
88 | local_exp[idx] = (words[local_exp[idx][0]], local_exp[idx][1])
89 | else:
90 | index = local_exp[idx][0]
91 | key = ""
92 | for item in parts[index]:
93 | key += words[item] + " AND "
94 | key = key[:-5]
95 | local_exp[idx] = (key, local_exp[idx][1])
96 |
97 | for idx in range(len(local_exp)):
98 | print(" {:20} : {:.3f} |".format(local_exp[idx][0], local_exp[idx][1]), end="")
99 | if idx % span == span - 1:
100 | print()
101 | print()
102 | if len(local_exp) % span != 0:
103 | print()
104 |
105 | if plot:
106 | fig, axs = plt.subplots(nrows=1, ncols=len(labels))
107 | vals = [x[1] for x in local_exp]
108 | names = [x[0] for x in local_exp]
109 | vals.reverse()
110 | names.reverse()
111 | pos = np.arange(len(vals)) + .5
112 | colors = ['green' if x > 0 else 'red' for x in vals]
113 | if len(labels) == 1:
114 | ax = axs
115 | else:
116 | ax = axs[i]
117 |
118 | ax.barh(pos, vals, align='center', color=colors)
119 | ax.set_yticks(pos)
120 | ax.set_yticklabels(names)
121 | title = 'Local explanation for class %s' % self.class_names[label]
122 | ax.set_title(title)
123 | if plot:
124 | plt.show()
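A minimal usage sketch with an assumed scikit-learn pipeline; spans=(2,) asks CLE to also score pairs of words, which is what the all_combinations bookkeeping above supports. Corpus, pipeline, and class names are placeholders.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from xdeep.xlocal.perturbation.text_explainers.cle_text import XDeepCLETextExplainer

# Placeholder sentiment corpus and pipeline; predict_proba accepts raw strings.
texts = ['good movie', 'great plot', 'bad acting', 'terrible film'] * 25
target = [1, 1, 0, 0] * 25
pipe = make_pipeline(CountVectorizer(), LogisticRegression()).fit(texts, target)

explainer = XDeepCLETextExplainer(pipe.predict_proba, ['negative', 'positive'])
explainer.explain('a good movie with a great plot', top_labels=1, spans=(2,))
explainer.show_explanation(span=3, plot=False)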
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/text_explainers/lime_text.py:
--------------------------------------------------------------------------------
1 | # The implementation of LIME refers to the original authors' code on GitHub: https://github.com/marcotcr/lime.
2 | # The copyright of the LIME algorithm is reserved for (c) 2016, Marco Tulio Correia Ribeiro.
3 |
4 | import matplotlib.pyplot as plt
5 |
6 | from lime.lime_text import LimeTextExplainer
7 | from ..explainer import Explainer
8 |
9 |
10 | class XDeepLimeTextExplainer(Explainer):
11 |
12 | def __init__(self, predict_proba, class_names):
13 | """Init function.
14 |
15 | # Arguments
16 | predict_proba: Function. Classifier prediction probability function.
17 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
18 | """
19 | Explainer.__init__(self, predict_proba, class_names)
20 | # Initialize explainer
21 | self.set_parameters()
22 |
23 | def set_parameters(self, **kwargs):
24 | """Parameter setter for lime_text.
25 |
26 | # Arguments
27 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
28 | """
29 | class_names = kwargs.pop("class_names", self.class_names)
30 | self.explainer = LimeTextExplainer(class_names=class_names, **kwargs)
31 |
32 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
33 | """Generate explanation for a prediction using certain method.
34 |
35 | # Arguments
36 | instance: Str. A raw text string to be explained.
37 | top_labels: Integer. Number of labels you care about.
38 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
39 | **kwargs: Parameters setter. For more detail, please check https://lime-ml.readthedocs.io/en/latest/index.html.
40 | """
41 | Explainer.explain(self, instance, top_labels=top_labels, labels=labels)
42 | self.explanation = self.explainer.explain_instance(instance, self.predict_proba,
43 | top_labels=top_labels, labels=self.labels, **kwargs)
44 |
45 | def show_explanation(self, span=3):
46 | """Visualization of explanation of lime_text.
47 |
48 | # Arguments
49 | span: Integer. How many features to show per row.
50 | """
51 | Explainer.show_explanation(self)
52 | exp = self.explanation
53 | labels = self.labels
54 |
55 | print()
56 | print("LIME Explanation")
57 | print("Instance: {}".format(self.instance))
58 | print()
59 |
60 | assert hasattr(labels, '__len__')
61 | for label in labels:
62 | result = exp.intercept[label]
63 | local_exp = exp.local_exp[label]
64 | for item in local_exp:
65 | result += item[1]
66 | print("Explanation for label {}:".format(self.class_names[label]))
67 | print("Local Prediction: {:.3f}".format(result))
68 | print("Original Prediction: {:.3f}".format(self.original_pred[label]))
69 | print()
70 | exp_list = exp.as_list(label=label)
71 | for idx in range(len(exp_list)):
72 | print(" {:20} : {:.3f} |".format(exp_list[idx][0], exp_list[idx][1]), end="")
73 | if idx % span == span - 1:
74 | print()
75 | print()
76 | exp.as_pyplot_figure(label=label)
77 | plt.show()
78 |
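The LIME text wrapper follows the same pattern. The pipeline and corpus below are placeholders, and num_features is forwarded to lime's explain_instance.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from xdeep.xlocal.perturbation.text_explainers.lime_text import XDeepLimeTextExplainer

# Placeholder sentiment corpus and pipeline; predict_proba accepts raw strings.
texts = ['good movie', 'great plot', 'bad acting', 'terrible film'] * 25
target = [1, 1, 0, 0] * 25
pipe = make_pipeline(CountVectorizer(), LogisticRegression()).fit(texts, target)

explainer = XDeepLimeTextExplainer(pipe.predict_proba, ['negative', 'positive'])
explainer.explain('a good movie with a great plot', top_labels=1, num_features=5)
explainer.show_explanation(span=3)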
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/text_explainers/shap_text.py:
--------------------------------------------------------------------------------
1 | from scipy.sparse import csr_matrix
2 | from IPython.display import display
3 |
4 | import shap
5 |
6 | from ..explainer import Explainer
7 |
8 |
9 | class XDeepShapTextExplainer(Explainer):
10 |
11 | def __init__(self, predict_vectorized, class_names, vectorizer, train):
12 | """Init function. The data pass to 'shap' cannot be str but vector.As a result, you need to pass in the vectorizer.
13 |
14 | # Arguments
15 | predict_vectorized: Function. Classifier prediction probability function.
16 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
17 | vectorizer: Vectorizer. A vectorizer which has 'transform' function that transforms list of str to vector.
18 | train: Array. Train data, in this case a list of str.
19 | """
20 | Explainer.__init__(self, predict_vectorized, class_names)
21 | self.vectorizer = vectorizer
22 | self.train = train
23 | # Initialize explainer
24 | self.set_parameters()
25 |
26 | def set_parameters(self, **kwargs):
27 | """Parameter setter for shap. The data pass to 'shap' cannot be str but vector.As a result, you need to pass in the vectorizer.
28 |
29 | # Arguments
30 | **kwargs: Shap kernel explainer parameter setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
31 | """
32 | predict_vectorized = kwargs.pop('predict_vectorized', self.predict_proba)
33 | vectorizer = kwargs.pop('vectorizer', self.vectorizer)
34 | vector = kwargs.pop('train', self.train)
35 | if not isinstance(vector, csr_matrix):
36 | vector = vectorizer.transform(vector)
37 | self.explainer = shap.KernelExplainer(predict_vectorized, vector, **kwargs)
38 |
39 | def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
40 | """Generate explanation for a prediction using certain method.
41 |
42 | # Arguments
43 | instance: Str. A raw text string to be explained.
44 | top_labels: Integer. Number of labels you care about.
45 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
46 | **kwargs: Parameters setter. For more detail, please check https://github.com/slundberg/shap/blob/master/shap/explainers/kernel.py.
47 | """
48 | if not isinstance(instance, csr_matrix):
49 | vector = self.vectorizer.transform([instance])
50 | else:
51 | vector = instance
52 | self.instance = vector
53 | proba = self.predict_proba(self.instance)
54 | self.original_pred = proba[0]
55 | if top_labels is not None:
56 | self.labels = proba.argsort()[0][-top_labels:]
57 | else:
58 | self.labels = labels
59 | self.explanation = self.explainer.shap_values(vector, **kwargs)
60 |
61 | def show_explanation(self, show_in_note_book=True):
62 | """Visualization of explanation of shap.
63 |
64 | # Arguments
65 | show_in_note_book: Boolean. Whether to show in a Jupyter notebook.
66 | """
67 | Explainer.show_explanation(self)
68 | shap_values = self.explanation
69 | expected_value = self.explainer.expected_value
70 | class_names = self.class_names
71 | labels = self.labels
72 | instance = self.instance
73 |
74 | shap.initjs()
75 | print()
76 | print("Shap Explanation")
77 | print()
78 |
79 | if len(instance.shape) == 1 or instance.shape[0] == 1:
80 | for label in labels:
81 | print("Shap value for label {}:".format(class_names[label]))
82 | print(shap_values[label])
83 |
84 | if show_in_note_book:
85 | for label in labels:
86 | if isinstance(instance, csr_matrix):
87 | display(shap.force_plot(expected_value[label], shap_values[label], instance.A))
88 | else:
89 | display(shap.force_plot(expected_value[label], shap_values[label], instance))
90 | else:
91 | if show_in_note_book:
92 | shap.summary_plot(shap_values, instance)
93 |
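A minimal usage sketch. Unlike the other text explainers, the classifier here must accept vectors, so the vectorizer is trained separately. The corpus, model, and nsamples value are placeholders, and depending on the installed shap version the sparse vectors may need to be densified.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

from xdeep.xlocal.perturbation.text_explainers.shap_text import XDeepShapTextExplainer

# Placeholder corpus; the classifier is trained on vectorized text, so its
# predict_proba accepts vectors, which is exactly what this explainer expects.
texts = ['good movie', 'great plot', 'bad acting', 'terrible film'] * 25
target = [1, 1, 0, 0] * 25
vectorizer = CountVectorizer().fit(texts)
clf = LogisticRegression().fit(vectorizer.transform(texts), target)

explainer = XDeepShapTextExplainer(clf.predict_proba, ['negative', 'positive'], vectorizer, texts[:20])
explainer.explain('a good movie with a great plot', top_labels=1, nsamples=100)
explainer.show_explanation(show_in_note_book=False)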
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/xdeep_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Base wrapper for xdeep.
3 | """
4 | from .exceptions import XDeepError
5 |
6 |
7 | class Wrapper(object):
8 | """A wrapper to wrap all methods."""
9 |
10 | def __init__(self, predict_proba, class_names):
11 | """Init function.
12 |
13 | # Arguments
14 | predict_proba: Function. Classifier prediction probability function, which takes a list of d strings and outputs a (d, k) numpy array with prediction probabilities, where k is the number of classes.
15 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
16 | """
17 | self.methods = list()
18 | self.explainers = dict()
19 |
20 | self.predict_proba = predict_proba
21 | self.class_names = class_names
22 |
23 | def __initialization(self):
24 | """Initialize default explainers."""
25 | pass
26 |
27 | def __check_avaliability(self, method, exp_required=False):
28 | """Check if the method is valid.
29 |
30 | # Arguments
31 | method: Str. The method you want to use.
32 | exp_required: Boolean. Whether to check for the existence of an explanation.
33 | """
34 | if method not in self.methods:
35 | raise XDeepError("Please input correct explain_method from {}.".format(self.methods))
36 | if method not in self.explainers:
37 | raise XDeepError("Sorry. You need to initialize {}_explainer first.".format(method))
38 | if exp_required:
39 | exp = self.explainers[method].explanation
40 | if exp is None:
41 | raise XDeepError("Sorry. You haven't got {} explanation yet.".format(method))
42 |
43 | def set_parameters(self, method, **kwargs):
44 | """Parameter setter.
45 |
46 | # Arguments
47 | method: Str. The method you want to use.
48 | **kwargs: Other parameters that depend on which method you use. For more detail, please check the xdeep documentation.
49 | """
50 | self.__check_avaliability(method)
51 | self.explainers[method].set_parameters(**kwargs)
52 |
53 | def explain(self, method, instance, top_labels=None, labels=(1,), **kwargs):
54 | """Generate explanation for a prediction using certain method.
55 |
56 | # Arguments
57 | method: Str. The method you want to use.
58 | instance: Instance to be explained.
59 | top_labels: Integer. Number of labels you care about.
60 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
61 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method.
62 | """
63 | # Check validation of parameters
64 | self.__check_avaliability(method)
65 | self.explainers[method].explain(instance, top_labels=top_labels, labels=labels, **kwargs)
66 |
67 | def get_explanation(self, method):
68 | """Explanation getter.
69 |
70 | # Arguments
71 | method: Str. The method you want to use.
72 |
73 | # Return
74 | An Explanation object with the corresponding method.
75 | """
76 | # Check validation of parameters
77 | self.__check_avaliability(method, exp_required=True)
78 | exp = self.explainers[method].explanation
79 | return exp
80 |
81 | def show_explanation(self, method, **kwargs):
82 | """Visualization of explanation of the corresponding method.
83 |
84 | # Arguments
85 | method: Str. The method you want to use.
86 | **kwargs: parameters setter. For more detail, please check xdeep documentation.
87 | """
88 | # Check validation of parameters
89 | self.__check_avaliability(method, exp_required=True)
90 | self.explainers[method].show_explanation(**kwargs)
91 |
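To illustrate the dispatch contract Wrapper defines (fill self.methods and self.explainers, then let the base class validate and route calls), here is a minimal, hypothetical subclass. DummyExplainer and MyWrapper are invented names, not part of xdeep.

from xdeep.xlocal.perturbation.xdeep_base import Wrapper


class DummyExplainer:
    """Stand-in for a concrete explainer exposing the interface Wrapper dispatches to."""

    def __init__(self):
        self.explanation = None

    def set_parameters(self, **kwargs):
        pass

    def explain(self, instance, top_labels=None, labels=(1,), **kwargs):
        self.explanation = {'instance': instance, 'labels': labels}

    def show_explanation(self, **kwargs):
        print(self.explanation)


class MyWrapper(Wrapper):
    def __init__(self, predict_proba, class_names):
        Wrapper.__init__(self, predict_proba, class_names)
        self.methods = ['dummy']
        self.explainers['dummy'] = DummyExplainer()


wrapper = MyWrapper(predict_proba=None, class_names=['a', 'b'])
wrapper.explain('dummy', instance=[1, 2, 3])   # validated, then routed to DummyExplainer.explain
wrapper.show_explanation('dummy')
print(wrapper.get_explanation('dummy'))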
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/xdeep_image.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for explaining image classifiers.
3 | """
4 |
5 | from .xdeep_base import Wrapper
6 | from .image_explainers.lime_image import XDeepLimeImageExplainer
7 | from .image_explainers.cle_image import XDeepCLEImageExplainer
8 | from .image_explainers.anchor_image import XDeepAnchorImageExplainer
9 | from .image_explainers.shap_image import XDeepShapImageExplainer
10 |
11 |
12 | class ImageExplainer(Wrapper):
13 | """Integrated explainer which explains text classifiers."""
14 |
15 | def __init__(self, predict_proba, class_names):
16 | """Init function.
17 |
18 | # Arguments
19 | predict_proba: Function. Classifier prediction probability function.
20 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
21 | """
22 | Wrapper.__init__(self, predict_proba, class_names=class_names)
23 | self.methods = ['lime', 'cle', 'anchor', 'shap']
24 | self.__initialization()
25 |
26 | def __initialization(self):
27 | """Initializer. Use default parameters to initialize 'lime' and 'anchor' explainer."""
28 | print("Initialize default 'lime', 'cle', 'anchor' explainers. " +
29 | "Please explicitly initialize 'shap' explainer before use it.")
30 | self.explainers['lime'] = XDeepLimeImageExplainer(self.predict_proba, self.class_names)
31 | self.explainers['cle'] = XDeepCLEImageExplainer(self.predict_proba, self.class_names)
32 | self.explainers['anchor'] = XDeepAnchorImageExplainer(self.predict_proba, self.class_names)
33 |
34 | def initialize_shap(self, n_segment, segment):
35 | """Init function.
36 |
37 | # Arguments
38 | n_segment: Integer. Number of segments in the image.
39 | segment: Array. A 2-D array containing the segments_slic of the image.
40 | """
41 | self.explainers['shap'] = XDeepShapImageExplainer(self.predict_proba, self.class_names, n_segment, segment)
42 |
43 | def explain(self, method, instance, top_labels=2, labels=(1,), **kwargs):
44 | """Generate explanation for a prediction using certain method.
45 |
46 | # Arguments
47 | method: Str. The method you want to use.
48 | instance: Instance to be explained.
49 | top_labels: Integer. Number of labels you care about.
50 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
51 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method.
52 | """
53 | Wrapper.explain(self, method, instance, top_labels=top_labels, labels=labels, **kwargs)
54 |
55 | def show_explanation(self, method, **kwargs):
56 | """Visualization of explanation of the corresponding method.
57 |
58 | # Arguments
59 | method: Str. The method you want to use.
60 | **kwargs: parameters setter. For more detail, please check xdeep documentation.
61 | """
62 | Wrapper.show_explanation(self, method, **kwargs)
63 |
64 | def get_explanation(self, method):
65 | """Explanation getter.
66 |
67 | # Arguments
68 | method: Str. The method you want to use.
69 |
70 | # Return
71 | An Explanation object with the corresponding method.
72 | """
73 | return Wrapper.get_explanation(self, method)
74 |
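A minimal usage sketch of the image wrapper. The toy "brightness" classifier, the random image, and the skimage-based segmentation are placeholders; only the ImageExplainer calls reflect the API above, and the keyword arguments ultimately accepted by each method depend on the underlying lime/cle/anchor/shap implementations.

import numpy as np
from skimage.segmentation import slic

from xdeep.xlocal.perturbation.xdeep_image import ImageExplainer


def predict_proba(images):
    # Toy classifier: the probability of 'bright' is the mean pixel intensity.
    images = np.asarray(images, dtype=float)
    score = images.mean(axis=(1, 2, 3))
    return np.stack([1.0 - score, score], axis=1)


image = np.random.rand(64, 64, 3)
explainer = ImageExplainer(predict_proba, class_names=['dark', 'bright'])

# 'lime', 'cle' and 'anchor' are ready after construction; 'shap' must be
# initialized with a segmentation of the image to be explained.
segments = slic(image, n_segments=40, compactness=10)
explainer.initialize_shap(n_segment=len(np.unique(segments)), segment=segments)

explainer.explain('lime', image, top_labels=1)
explanation = explainer.get_explanation('lime')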
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/xdeep_tabular.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for explaining tabular classifiers.
3 | """
4 |
5 | from .xdeep_base import Wrapper
6 | from .exceptions import XDeepError
7 | from .tabular_explainers.lime_tabular import XDeepLimeTabularExplainer
8 | from .tabular_explainers.cle_tabular import XDeepCLETabularExplainer
9 | from .tabular_explainers.anchor_tabular import XDeepAnchorTabularExplainer
10 | from .tabular_explainers.shap_tabular import XDeepShapTabularExplainer
11 |
12 |
13 | class TabularExplainer(Wrapper):
14 | """Integrated explainer which explains text classifiers."""
15 |
16 | def __init__(self, predict_proba, class_names, feature_names, train,
17 | categorical_features=None, categorical_names=None):
18 | """Init function.
19 |
20 | # Arguments
21 | predict_proba: Function. Classifier prediction probability function.
22 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
23 | feature_names: List. A list of names (strings) corresponding to the columns in the training data.
24 | train: Array. Train data, in this case a list of tabular data.
25 | categorical_features: List. A list of indices (ints) corresponding to the categorical columns. Everything else will be considered continuous. Values in these columns MUST be integers.
26 | categorical_names: Dict. A dict which maps from int to list of names, where categorical_names[x][y] represents the name of the yth value of column x.
27 | """
28 | Wrapper.__init__(self, predict_proba, class_names=class_names)
29 | self.methods = ['lime', 'cle', 'anchor', 'shap']
30 | self.feature_names = feature_names
31 | self.train = train
32 | self.categorical_features = categorical_features
33 | self.categorical_names = categorical_names
34 | self.__initialization()
35 |
36 | def __initialization(self):
37 | """Initializer. Use default parameters to initialize 'lime' and 'shap' explainer."""
38 | print("Initialize default 'lime', 'cle', 'shap' explainers. "
39 | "Please explicitly initialize 'anchor' explainer before use it.")
40 | self.explainers['lime'] = XDeepLimeTabularExplainer(self.predict_proba, self.class_names, self.feature_names,
41 | self.train, categorical_features=self.categorical_features,
42 | categorical_names=self.categorical_names)
43 | self.explainers['cle'] = XDeepCLETabularExplainer(self.predict_proba, self.class_names, self.feature_names,
44 | self.train, categorical_features=self.categorical_features,
45 | categorical_names=self.categorical_names)
46 | self.explainers['shap'] = XDeepShapTabularExplainer(self.predict_proba, self.class_names, self.train)
47 |
48 | def initialize_anchor(self, data):
49 | """Explicit shap initializer.
50 |
51 | # Arguments
52 | data: Array. Full data, including training, validation, and test data.
53 | """
54 | print("If you want to use 'anchor' to explain tabular classifier. "
55 | "You need to get this encoder to encode your data, and train another model.")
56 | self.explainers['anchor'] = XDeepAnchorTabularExplainer(self.class_names, self.feature_names, data,
57 | categorical_names=self.categorical_names)
58 |
59 | def get_anchor_encoder(self, train_data, train_labels, validation_data, validation_labels, discretizer='quartile'):
60 | """Get encoder for tabular data.
61 |
62 | If you want to use 'anchor' to explain a tabular classifier, you need to use this encoder to encode your data and train another model on the encoded data.
63 |
64 | # Arguments
65 | train_data: Array. Train data.
66 | train_labels: Array. Train labels.
67 | validation_data: Array. Validation set.
68 | validation_labels: Array. Validation labels.
69 | discretizer: Str. Discretizer for data. Please choose between 'quartile' and 'decile'.
70 |
71 | # Return
72 | An encoder object with a 'transform' function.
73 | """
74 | if 'anchor' not in self.explainers:
75 | raise XDeepError("Please initialize anchor explainer first.")
76 | return self.explainers['anchor'].get_anchor_encoder(train_data, train_labels,
77 | validation_data, validation_labels, discretizer=discretizer)
78 |
79 | def set_anchor_predict_proba(self, predict_proba):
80 | """Anchor predict function setter.
81 | Because you will get an encoder and train a new model, you need to update the predict function.
82 |
83 | # Arguments
84 | predict_proba: Function. A new classifier prediction probability function.
85 | """
86 | self.explainers['anchor'].set_anchor_predict_proba(predict_proba)
87 |
88 | def explain(self, method, instance, top_labels=None, labels=(1,), **kwargs):
89 | """Generate explanation for a prediction using certain method.
90 |
91 | # Arguments
92 | method: Str. The method you want to use.
93 | instance: Instance to be explained.
94 | top_labels: Integer. Number of labels you care about.
95 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
96 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method.
97 | """
98 | Wrapper.explain(self, method, instance, top_labels=top_labels, labels=labels, **kwargs)
99 |
100 |
101 | def show_explanation(self, method, **kwargs):
102 | """Visualization of explanation of the corresponding method.
103 |
104 | # Arguments
105 | method: Str. The method you want to use.
106 | **kwargs: parameters setter. For more detail, please check xdeep documentation.
107 | """
108 | Wrapper.show_explanation(self, method, **kwargs)
109 |
110 | def get_explanation(self, method):
111 | """Explanation getter.
112 |
113 | # Arguments
114 | method: Str. The method you want to use.
115 |
116 | # Return
117 | An Explanation object with the corresponding method.
118 | """
119 | return Wrapper.get_explanation(self, method)
120 |
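An end-to-end sketch for the tabular wrapper, including the extra steps the anchor method needs. The dataset, models, and names are placeholders invented for illustration, and the anchor encoder behavior depends on the installed anchor package version.

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from xdeep.xlocal.perturbation.xdeep_tabular import TabularExplainer

# Placeholder dataset and model.
rng = np.random.RandomState(0)
data = rng.randint(0, 4, size=(500, 4)).astype(float)
target = (data[:, 0] > 1).astype(int)
train_x, val_x = data[:400], data[400:]
train_y, val_y = target[:400], target[400:]
model = RandomForestClassifier(n_estimators=30).fit(train_x, train_y)

explainer = TabularExplainer(model.predict_proba, ['negative', 'positive'],
                             ['f0', 'f1', 'f2', 'f3'], train_x)

# 'lime', 'cle' and 'shap' are ready after construction.
explainer.explain('lime', val_x[0], top_labels=1)
explainer.show_explanation('lime')

# 'anchor' needs its own encoder and a second model trained on the encoded data.
explainer.initialize_anchor(data)
encoder = explainer.get_anchor_encoder(train_x, train_y, val_x, val_y)
anchor_model = RandomForestClassifier(n_estimators=30).fit(encoder.transform(train_x), train_y)
explainer.set_anchor_predict_proba(anchor_model.predict_proba)
explainer.explain('anchor', val_x[0])
explainer.show_explanation('anchor', show_in_note_book=False)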
--------------------------------------------------------------------------------
/xdeep/xlocal/perturbation/xdeep_text.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for explaining text classifiers.
3 | """
4 |
5 | from .xdeep_base import Wrapper
6 | from .text_explainers.lime_text import XDeepLimeTextExplainer
7 | from .text_explainers.cle_text import XDeepCLETextExplainer
8 | from .text_explainers.anchor_text import XDeepAnchorTextExplainer
9 | from .text_explainers.shap_text import XDeepShapTextExplainer
10 |
11 |
12 | class TextExplainer(Wrapper):
13 | """Integrated explainer which explains text classifiers."""
14 |
15 | def __init__(self, predict_proba, class_names):
16 | """Init function.
17 |
18 | # Arguments
19 | predict_proba: Function. Classifier prediction probability function.
20 | class_names: List. A list of class names, ordered according to whatever the classifier is using.
21 | """
22 | Wrapper.__init__(self, predict_proba, class_names=class_names)
23 | self.methods = ['lime', 'cle', 'anchor', 'shap']
24 | self.__initialization()
25 |
26 | def __initialization(self):
27 | """Initializer. Use default parameters to initialize 'lime' and 'anchor' explainer."""
28 | print("Initialize default 'lime', 'cle', 'anchor' explainers. " +
29 | "Please explicitly initialize 'shap' explainer before use it.")
30 | self.explainers['lime'] = XDeepLimeTextExplainer(self.predict_proba, self.class_names)
31 | self.explainers['cle'] = XDeepCLETextExplainer(self.predict_proba, self.class_names)
32 | self.explainers['anchor'] = XDeepAnchorTextExplainer(self.predict_proba, self.class_names)
33 |
34 | def initialize_shap(self, predict_vectorized, vectorizer, train):
35 | """Explicit shap initializer.
36 |
37 | # Arguments
38 | predict_vectorized: Function. Classifier prediction probability function which need vector passed in.
39 | vectorizer: Vectorizer. A vectorizer which has 'transform' function that transforms list of str to vector.
40 | train: Array. Train data, in this case a list of str.
41 | """
42 | self.explainers['shap'] = XDeepShapTextExplainer(predict_vectorized, self.class_names, vectorizer, train)
43 |
44 | def explain(self, method, instance, top_labels=None, labels=(1,), **kwargs):
45 | """Generate explanation for a prediction using certain method.
46 |
47 | # Arguments
48 | method: Str. The method you want to use.
49 | instance: Instance to be explained.
50 | top_labels: Integer. Number of labels you care about.
51 | labels: Tuple. Labels you care about; if top_labels is not None, this will be replaced by the predicted top labels.
52 | **kwargs: Parameters setter. For more detail, please check 'explain_instance' in corresponding method.
53 | """
54 | Wrapper.explain(self, method, instance, top_labels=top_labels, labels=labels, **kwargs)
55 |
56 | def show_explanation(self, method, **kwargs):
57 | """Visualization of explanation of the corresponding method.
58 |
59 | # Arguments
60 | method: Str. The method you want to use.
61 | **kwargs: parameters setter. For more detail, please check xdeep documentation.
62 | """
63 | Wrapper.show_explanation(self, method, **kwargs)
64 |
65 | def get_explanation(self, method):
66 | """Explanation getter.
67 |
68 | # Arguments
69 | method: Str. The method you want to use.
70 |
71 | # Return
72 | An Explanation object with the corresponding method.
73 | """
74 | return Wrapper.get_explanation(self, method)
75 |
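An end-to-end sketch for the text wrapper. The corpus and pipeline are placeholders; the shap step shows why initialize_shap takes a vector-based predict function plus the vectorizer, as documented above.

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from xdeep.xlocal.perturbation.xdeep_text import TextExplainer

# Placeholder corpus and pipeline; the pipeline's predict_proba accepts raw strings.
texts = ['good movie', 'great plot', 'bad acting', 'terrible film'] * 25
target = [1, 1, 0, 0] * 25
pipe = make_pipeline(CountVectorizer(), LogisticRegression()).fit(texts, target)

explainer = TextExplainer(pipe.predict_proba, ['negative', 'positive'])
explainer.explain('lime', 'a good movie with a great plot', top_labels=1)
explainer.show_explanation('lime')

# 'shap' works on vectors, so it is initialized with the vectorizer and a
# vector-based predict function taken from inside the pipeline.
vectorizer = pipe.named_steps['countvectorizer']
clf = pipe.named_steps['logisticregression']
explainer.initialize_shap(clf.predict_proba, vectorizer, texts[:20])
explainer.explain('shap', 'a good movie with a great plot', top_labels=1, nsamples=100)
explainer.show_explanation('shap', show_in_note_book=False)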
--------------------------------------------------------------------------------