├── .gitignore ├── .travis.yml ├── LICENSE.md ├── README.md ├── deepexplain ├── __init__.py ├── tensorflow │ ├── __init__.py │ ├── methods.py │ └── utils.py └── tests │ ├── __init__.py │ └── test_tensorflow.py ├── docs └── comparison.png ├── examples ├── data │ ├── .gitkeep │ ├── images │ │ ├── .gitkeep │ │ ├── 0c7ac4a8c9dfa802.png │ │ ├── 1c2e9fe8b0b2fdf2.png │ │ ├── 4fc263d35a3ad3ee.png │ │ └── 5b3a8c63e41802e7.png │ └── models │ │ └── .gitkeep ├── inception_tensorflow.ipynb ├── mint_cnn_keras.ipynb ├── mnist_tensorflow.ipynb ├── multiple_input_keras.ipynb └── utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # PyBuilder 58 | target/ 59 | 60 | # PyCharm 61 | .idea 62 | 63 | examples/\.ipynb_checkpoints/ 64 | 65 | *.index 66 | 67 | *.meta 68 | 69 | *.ckpt 70 | 71 | *.data-00000-of-00001 72 | 73 | *.h5 74 | 75 | *.hdf5 76 | 77 | *.npy 78 | 79 | *.html 80 | 81 | *.p 82 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | matrix: 3 | include: 4 | - python: 2.7 5 | env: KERAS_BACKEND=tensorflow 6 | - python: 3.6 7 | env: KERAS_BACKEND=tensorflow 8 | 9 | install: 10 | # install TensorFlow (CPU version). 11 | - pip install tensorflow --upgrade --upgrade-strategy only-if-needed 12 | - pip install keras --upgrade --upgrade-strategy only-if-needed 13 | - pip install -e . 14 | 15 | # command to run tests 16 | script: 17 | - nosetests 18 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 ETH Zurich (Marco Ancona) 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | DeepExplain: attribution methods for Deep Learning 2 | [![Build Status](https://travis-ci.org/marcoancona/DeepExplain.svg?branch=master)](https://travis-ci.org/marcoancona/DeepExplain) 3 | === 4 | DeepExplain provides a unified framework for state-of-the-art gradient *and* perturbation-based attribution methods. 5 | It can be used by researchers and practitioners for better undertanding the recommended existing models, as well for benchmarking other attribution methods. 6 | 7 | It supports **Tensorflow** as well as **Keras** with Tensorflow backend. **Only Tensorflow V1 is supported. For V2, there is an open pull-request, that works if eager execution is disabled.** 8 | 9 | Implements the following methods: 10 | 11 | **Gradient-based attribution methods** 12 | - [**Saliency maps**](https://arxiv.org/abs/1312.6034) 13 | - [**Gradient * Input**](https://arxiv.org/abs/1605.01713) 14 | - [**Integrated Gradients**](https://arxiv.org/abs/1703.01365) 15 | - [**DeepLIFT**](https://arxiv.org/abs/1704.02685), in its first variant with Rescale rule (*) 16 | - [**ε-LRP**](http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140) (*) 17 | 18 | Methods marked with (*) are implemented as modified chain-rule, as better explained in [Towards better understanding of gradient-based attribution methods for Deep Neural Networks](https://openreview.net/forum?id=Sy21R9JAW), Ancona *et al*, ICLR 2018. As such, the result might be slightly different from the original implementation. 19 | 20 | **Pertubration-based attribution methods** 21 | - [**Occlusion**](https://arxiv.org/abs/1311.2901), as an extension 22 | of the [grey-box method by Zeiler *et al*](https://arxiv.org/abs/1311.2901). 23 | - [**Shapley Value sampling**](https://www.sciencedirect.com/science/article/pii/S0305054808000804) 24 | 25 | ## What are attributions? 26 | Consider a network and a specific input to this network (eg. an image, if the network is trained for image classification). The input is multi-dimensional, made of several features. In the case of images, each pixel can be considered a feature. The goal of an attribution method is to determine a real value `R(x_i)` for each input feature, with respect to a target neuron of interest (for example, the activation of the neuron corresponsing to the correct class). 27 | 28 | When the attributions of all input features are arranged together to have the same shape of the input sample we talk about *attribution maps* (as in the picture below), where red and blue colors indicate respectively features that contribute positively to the activation of the target output and features having a suppressing effect on it. 29 | ![Attribution methods comparison on InceptionV3](https://github.com/marcoancona/DeepExplain/blob/master/docs/comparison.png) 30 | 31 | This can help to better understand the network behavior, which features mostly contribute to the output and possible reasons for missclassification. 32 | 33 | 34 | DeepExplain Quickstart 35 | === 36 | ## Installation 37 | ```unix 38 | pip install -e git+https://github.com/marcoancona/DeepExplain.git#egg=deepexplain 39 | ``` 40 | 41 | Notice that DeepExplain assumes you already have installed `Tensorflow > 1.0` and (optionally) `Keras > 2.0`. 42 | 43 | ## Usage 44 | 45 | Working examples for Tensorflow and Keras can be found in the `example` folder of the repository. DeepExplain 46 | consists of a single method: `explain(method_name, target_tensor, input_tensor, samples, ...args)`. 47 | 48 | 49 | Parameter name | Short name | Type | Description 50 | ---------------|------|------|------------ 51 | `method_name` | | string, required | Name of the method to run (see [Which method to use?](#which-method-to-use)). 52 | `target_tensor` | `T` | Tensor, required | Tensorflow Tensor object representing the output of the model for which attributions are seeked (see [Which tensor to target?](#which-neuron-to-target)). 53 | `input_tensor` | `X` | Tensor, required | Symbolic input to the network. 54 | `input_data` | `xs` | numpy array, required | Batch of input samples to be fed to `X` and for which attributions are seeked. Notice that the first dimension must always be the batch size. 55 | `target_weights` | `ys` | numpy array, optional | Batch of weights to be applied to `T` if this has more than one output. Usually necessary on classification problems where there are multiple output units and we need to target a specific one to generate explanations for. In this case, `ys` can be provided with the one-hot encoding of the desired unit. 56 | `batch_size` | |int, optional| By default, DeepExplain will try to evaluate the model using all data in `xs` at the same time. If `xs` contains many samples, it might be necessary to split the processing in batches. In this case, providing a `batch_size` greater than zero will automatically split the evaluation into chunks of the given size. 57 | `...args` | | various, optional | Method-specific parameters (see below). 58 | 59 | The method `explain` must be called within a DeepExplain context: 60 | 61 | ```python 62 | # Pseudo-code 63 | from deepexplain.tensorflow import DeepExplain 64 | 65 | # Option 1. Create and train your model within a DeepExplain context 66 | 67 | with DeepExplain(session=...) as de: # < enter DeepExplain context 68 | model = init_model() # < construct the model 69 | model.fit() # < train the model 70 | 71 | attributions = de.explain(...) # < compute attributions 72 | 73 | # Option 2. First create and train your model, then apply DeepExplain. 74 | # IMPORTANT: in order to work correctly, the graph to analyze 75 | # must always be (re)constructed within the context! 76 | 77 | model = init_model() # < construct the model 78 | model.fit() # < train the model 79 | 80 | with DeepExplain(session=...) as de: # < enter DeepExplain context 81 | new_model = init_model() # < assumes init_model() returns a *new* model with the weights of `model` 82 | attributions = de.explain(...) # < compute attributions 83 | ``` 84 | 85 | When initializing the context, make sure to pass the `session` parameter: 86 | 87 | ```python 88 | # With Tensorflow 89 | import tensorflow as tf 90 | # ...build model 91 | sess = tf.Session() 92 | # ... use session to train your model if necessary 93 | with DeepExplain(session=sess) as de: 94 | ... 95 | 96 | # With Keras 97 | import keras 98 | from keras import backend as K 99 | 100 | model = Sequential() # functional API is also supported 101 | # ... build model and train 102 | 103 | with DeepExplain(session=K.get_session()) as de: 104 | ... 105 | ``` 106 | 107 | See concrete examples [here](https://github.com/marcoancona/DeepExplain/tree/master/examples). 108 | 109 | ## Which method to use? 110 | DeepExplain supports several methods. The main partition is between *gradient-based methods* and *perturbation-based methods*. The former are faster, given that they estimate attributions with a few forward and backward iterations through the network. The latter perturb the input and measure the change in output with respect to the original input. This requires to sequentially test each feature (or group of features) and therefore takes more time, but tends to produce smoother results. 111 | 112 | Cooperative game theory suggests [**Shapley Values**](https://en.wikipedia.org/wiki/Shapley_value) as a unique way to distribute attribution to features such that some important theoretical properties are satisfied. Unfortunately, computing Shapley Values exactly is prohibitively expensive, therefore DeepExplain provides a sampling-based approximation. By changing the `samples` parameters, one can adjust the trade-off between performance and error. Notice that this method will still be significantly slower than other methods in this library. 113 | 114 | Some methods allow tunable parameters. See the table below. 115 | 116 | Method | `method_name` | Optional parameters | Notes 117 | ---------------|:------|:------------|----- 118 | Saliency | `saliency` | | [*Gradient*] Only positive attributions. 119 | Gradient * Input | `grad*input` | | [*Gradient*] Fast. May be affected by noisy gradients and saturation of the nonlinerities. 120 | Integrated Gradients | `intgrad` |`steps`, `baseline` | [*Gradient*] Similar to Gradient * Input, but performs `steps` iterations (default: 100) though the network, varying the input from `baseline` (default: zero) to the actual provided sample. When provided, `baseline` must be a numpy array with the size of the input (but no batch dimension since the same baseline will be used for all inputs in the batch). 121 | epsilon-LRP | `elrp` | `epsilon` | [*Gradient*]Computes Layer-wise Relevance Propagation. Only recommanded with ReLU or Tanh nonlinearities. Value for `epsilon` must be greater than zero (default: .0001). 122 | DeepLIFT (Rescale) | `deeplift` | `baseline` | [*Gradient*] In most cases a faster approximation of Integrated Gradients. Do not apply to networks with multiplicative units (ie. LSTM or GRU). When provided, `baseline` must be a numpy array with the size of the input, without the batch dimension (default: zero). 123 | Occlusion | `occlusion` | `window_shape`, `step` | [*Perturbation*] Computes rolling window view of the input array and replace each window with zero values, measuring the effect of the perturbation on the target output. The optional parameters `window_shape` and `step` behave like in [skimage](http://scikit-image.org/docs/dev/api/skimage.util.html#skimage.util.view_as_windows). By default, each feature is tested independently (`window_shape=1` and `step=1`), however this might be extremely slow for large inputs (such as ImageNet images). When the input presents some local coherence (eg. images), you might prefer larger values for `window_shape`. In this case the attributions of the features in each window will be summed up. Notice that the result might vary significantly for different window sizes. 124 | Shapley Value sampling | `shapley_sampling` | `samples`, `sampling_dims` | [*Perturbation*] Computes approximate Shapley Values by sampling `samples` times each input feature. Notice that this method can be significantly slower than all the others as it runs the network `samples*n` times, where `n` is the number of input features in your input. The parameter `sampling_dims` (a list of integers) can be used to select which dimensions should be sampled. For example, if the inputs are RGB images, `sampling_dims=[1,2]` would sample pixels considering the three color channels atomic. Instead `sampling_dims=[1,2,3]` (default) will samples over the channels as well. 125 | 126 | ## Which neuron to target? 127 | In general, any tensor that represents the activation of any hidden or output neuron can be user as `target_tensor`. If your network performs a classification task (ie. one output neuron for each possible class) you might want to target the neuron corresponding to the *correct class* for a given sample, such that the attribution map might help you undertand the reasons for this neuron to (not) activate. However you can also target the activation of another class, for example a class that is often missclassified, to have insight about features that activate this class. 128 | 129 | **Important**: Tensors in Tensorflow and Keras usually include the activations of *all* neurons of a layer. If you pass such a tensor to `explain` you will get the *sum* attribution map for all neurons the Tensor refers to. If you want to target a specific neuron you need either to slice the component you are interested in or multiply it for a binary mask that only select the target neuron. 130 | 131 | ```python 132 | # Example on MNIST (classification, with 10 output classes) 133 | X = Placeholder(...) # input tensor 134 | T = model(X) # output layer, 2-dimensional Tensor of shape (1, 10), where first dimension is the batch size 135 | ys = [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]] # numpy array of shape (1, 10) with one-hot encoding of labels 136 | 137 | # We need to target only one of the 10 output units in `T` 138 | # Option 1 (recommanded): use the `ys` parameter 139 | de.explain('method_name', T, X, xs, ys=ys) 140 | 141 | # Option 2: manually mask the target. This will not work with batch processing. 142 | T *= ys # < masked target tensor: only the second component of `logits` will be used to compute attributions 143 | de.explain('method_name', T, X, xs) 144 | 145 | ``` 146 | 147 | **Softmax**: if the network last activation is a Softmax, it is recommanded to target the activations *before* this normalization. 148 | 149 | ### Performance: Explainer API 150 | If you need to run `explain()` multiple times (for example, new data to process with the same model comes in over time) it is recommanded that you use the Explainer API. This provides a way to *compile* the graph operations needed to generate the explanations and *evaluate* this graph in two different steps. 151 | 152 | Within a DeepExplain context (`de`), call `de.get_explainer()`. This method takes the same arguments of `explain()` except `xs`, `ys` and `batch_size`. It returns an explainer object (`explainer`) which provides a `run()` method. Call `explainer.run(xs, [ys], [batch_size])` to generate the explanations. Calling `run()` multiple times will not add new operations to the computational graph. 153 | 154 | 155 | ```python 156 | # Normal API: 157 | 158 | for i in range(100): 159 | # The following line will become slower and slower as new operations are added to the computational graph at each iteration 160 | attributions = de.explain('saliency', T, X, xs[i], ys=ys[i], batch_size=3) 161 | 162 | # Use the Explainer API instead: 163 | 164 | # First create an explainer 165 | explainer = de.get_explainer('saliency', T, X) 166 | for i in range(100): 167 | # Then generate explanations for some data without slowing things down 168 | attributions = explainer.run(xs[i], ys=ys[i], batch_size=3) 169 | ``` 170 | 171 | 172 | ### NLP / Embedding lookups 173 | The most common cause of `ValueError("None values not supported.")` is `run()` being called with a `tensor_input` and `target_tensor` that are disconnected in the backpropagation. This is common when an embedding lookup layer is used, since the lookup operation does not propagate the gradient. To generate attributions for NLP models, the input of DeepExplain should be the result of the embedding lookup instead of the original model input. Then, attributions for each word are found by summing up along the appropriate dimension of the resulting attribution matrix. 174 | 175 | Tensorflow pseudocode: 176 | ```python 177 | input_x = graph.get_operation_by_name("input_x").outputs[0] 178 | # Get a reference to the embedding tensor 179 | embedding = graph.get_operation_by_name("embedding").outputs[0] 180 | pre_softmax = graph.get_operation_by_name("output/scores").outputs[0] 181 | 182 | # Evaluate the embedding tensor on the model input (in other words, perform the lookup) 183 | embedding_out = sess.run(embedding, {input_x: x_test}) 184 | # Run DeepExplain with the embedding as input 185 | attributions = de.explain('elrp', pre_softmax * y_test_logits, embedding, embedding_out) 186 | ``` 187 | 188 | ### Multiple inputs 189 | Models with multiple inputs are supported for gradient-based methods. Instead, the `Occlusion` method will raise an exception if 190 | called on a model with multiple inputs (how perturbation should be generated for multiple inputs is actually not well defined). 191 | 192 | For a minimal (toy) example see the [example folder](https://github.com/marcoancona/DeepExplain/tree/master/examples). 193 | 194 | ## Contributing 195 | DeepExplain is still in active development. If you experience problems, feel free to open an issue. Contributions to extend the functinalities of this framework and/or to add support for other methods are welcome. 196 | 197 | ## License 198 | MIT 199 | -------------------------------------------------------------------------------- /deepexplain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/deepexplain/__init__.py -------------------------------------------------------------------------------- /deepexplain/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | from .methods import DeepExplain -------------------------------------------------------------------------------- /deepexplain/tensorflow/methods.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import numpy as np 7 | from skimage.util import view_as_windows 8 | import warnings, logging 9 | import tensorflow as tf 10 | from tensorflow.python.framework import ops 11 | from tensorflow.python.ops import nn_grad, math_grad 12 | from collections import OrderedDict 13 | from .utils import make_batches, slice_arrays, to_list, unpack_singleton, placeholder_from_data 14 | 15 | SUPPORTED_ACTIVATIONS = [ 16 | 'Relu', 'Elu', 'Sigmoid', 'Tanh', 'Softplus' 17 | ] 18 | 19 | UNSUPPORTED_ACTIVATIONS = [ 20 | 'CRelu', 'Relu6', 'Softsign' 21 | ] 22 | 23 | _ENABLED_METHOD_CLASS = None 24 | _GRAD_OVERRIDE_CHECKFLAG = 0 25 | 26 | 27 | # ----------------------------------------------------------------------------- 28 | # UTILITY FUNCTIONS 29 | # ----------------------------------------------------------------------------- 30 | 31 | 32 | def activation(type): 33 | """ 34 | Returns Tensorflow's activation op, given its type 35 | :param type: string 36 | :return: op 37 | """ 38 | if type not in SUPPORTED_ACTIVATIONS: 39 | warnings.warn('Activation function (%s) not supported' % type) 40 | f = getattr(tf.nn, type.lower()) 41 | return f 42 | 43 | 44 | def original_grad(op, grad): 45 | """ 46 | Return original Tensorflow gradient for an op 47 | :param op: op 48 | :param grad: Tensor 49 | :return: Tensor 50 | """ 51 | if op.type not in SUPPORTED_ACTIVATIONS: 52 | warnings.warn('Activation function (%s) not supported' % op.type) 53 | opname = '_%sGrad' % op.type 54 | if hasattr(nn_grad, opname): 55 | f = getattr(nn_grad, opname) 56 | else: 57 | f = getattr(math_grad, opname) 58 | return f(op, grad) 59 | 60 | 61 | # ----------------------------------------------------------------------------- 62 | # ATTRIBUTION METHODS BASE CLASSES 63 | # ----------------------------------------------------------------------------- 64 | 65 | 66 | class AttributionMethod(object): 67 | """ 68 | Attribution method base class 69 | """ 70 | def __init__(self, T, X, session, keras_learning_phase=None): 71 | self.T = T # target Tensor 72 | self.X = X # input Tensor 73 | self.Y_shape = [None,] + T.get_shape().as_list()[1:] 74 | # Most often T contains multiple output units. In this case, it is often necessary to select 75 | # a single unit to compute contributions for. This can be achieved passing 'ys' as weight for the output Tensor. 76 | self.Y = tf.placeholder(tf.float32, self.Y_shape) 77 | # placeholder_from_data(ys) if ys is not None else 1.0 # Tensor that represents weights for T 78 | self.T = self.T * self.Y 79 | self.symbolic_attribution = None 80 | self.session = session 81 | self.keras_learning_phase = keras_learning_phase 82 | self.has_multiple_inputs = type(self.X) is list or type(self.X) is tuple 83 | logging.info('Model with multiple inputs: %s' % self.has_multiple_inputs) 84 | 85 | # Set baseline 86 | # TODO: now this sets a baseline also for those methods that does not require it 87 | self._set_check_baseline() 88 | 89 | # References 90 | self._init_references() 91 | 92 | # Create symbolic explanation once during construction (affects only gradient-based methods) 93 | self.explain_symbolic() 94 | 95 | def explain_symbolic(self): 96 | return None 97 | 98 | def run(self, xs, ys=None, batch_size=None): 99 | pass 100 | 101 | def _init_references(self): 102 | pass 103 | 104 | def _check_input_compatibility(self, xs, ys=None, batch_size=None): 105 | if ys is not None: 106 | if not self.has_multiple_inputs and len(xs) != len(ys): 107 | raise RuntimeError('When provided, ys must have the same batch size as xs (xs has batch size {} and ys {})'.format(len(xs), len(ys))) 108 | elif self.has_multiple_inputs and np.all([len(i) != len(ys) for i in xs]): 109 | raise RuntimeError('When provided, ys must have the same batch size as all elements of xs') 110 | if batch_size is not None and batch_size > 0: 111 | if self.T.shape[0].value is not None and self.T.shape[0].value is not batch_size: 112 | raise RuntimeError('When using batch evaluation, the first dimension of the target tensor ' 113 | 'must be compatible with the batch size. Found %s instead' % self.T.shape[0].value) 114 | if isinstance(self.X, list): 115 | for x in self.X: 116 | if x.shape[0].value is not None and x.shape[0].value is not batch_size: 117 | raise RuntimeError('When using batch evaluation, the first dimension of the input tensor ' 118 | 'must be compatible with the batch size. Found %s instead' % x.shape[ 119 | 0].value) 120 | else: 121 | if self.X.shape[0].value is not None and self.X.shape[0].value is not batch_size: 122 | raise RuntimeError('When using batch evaluation, the first dimension of the input tensor ' 123 | 'must be compatible with the batch size. Found %s instead' % self.X.shape[0].value) 124 | 125 | def _session_run_batch(self, T, xs, ys=None): 126 | feed_dict = {} 127 | if self.has_multiple_inputs: 128 | for k, v in zip(self.X, xs): 129 | feed_dict[k] = v 130 | else: 131 | feed_dict[self.X] = xs 132 | 133 | # If ys is not passed, produce a vector of ones that will be broadcasted to all batch samples 134 | feed_dict[self.Y] = ys if ys is not None else np.ones([1,] + self.Y_shape[1:]) 135 | 136 | if self.keras_learning_phase is not None: 137 | feed_dict[self.keras_learning_phase] = 0 138 | return self.session.run(T, feed_dict) 139 | 140 | def _session_run(self, T, xs, ys=None, batch_size=None): 141 | num_samples = len(xs) 142 | if self.has_multiple_inputs is True: 143 | num_samples = len(xs[0]) 144 | if len(xs) != len(self.X): 145 | raise RuntimeError('List of input tensors and input data have different lengths (%s and %s)' 146 | % (str(len(xs)), str(len(self.X)))) 147 | if batch_size is not None: 148 | for xi in xs: 149 | if len(xi) != num_samples: 150 | raise RuntimeError('Evaluation in batches requires all inputs to have ' 151 | 'the same number of samples') 152 | 153 | if batch_size is None or batch_size <= 0 or num_samples <= batch_size: 154 | return self._session_run_batch(T, xs, ys) 155 | else: 156 | outs = [] 157 | batches = make_batches(num_samples, batch_size) 158 | for batch_index, (batch_start, batch_end) in enumerate(batches): 159 | # Get a batch from data 160 | xs_batch = slice_arrays(xs, batch_start, batch_end) 161 | # If the target tensor has one entry for each sample, we need to batch it as well 162 | ys_batch = None 163 | if ys is not None: 164 | ys_batch = slice_arrays(ys, batch_start, batch_end) 165 | batch_outs = self._session_run_batch(T, xs_batch, ys_batch) 166 | batch_outs = to_list(batch_outs) 167 | if batch_index == 0: 168 | # Pre-allocate the results arrays. 169 | for batch_out in batch_outs: 170 | shape = (num_samples,) + batch_out.shape[1:] 171 | outs.append(np.zeros(shape, dtype=batch_out.dtype)) 172 | for i, batch_out in enumerate(batch_outs): 173 | outs[i][batch_start:batch_end] = batch_out 174 | return unpack_singleton(outs) 175 | 176 | def _set_check_baseline(self): 177 | # Do nothing for those methods that have no baseline required 178 | if not hasattr(self, "baseline"): 179 | return 180 | 181 | if self.baseline is None: 182 | if self.has_multiple_inputs: 183 | self.baseline = [np.zeros([1,] + xi.get_shape().as_list()[1:]) for xi in self.X] 184 | else: 185 | self.baseline = np.zeros([1,] + self.X.get_shape().as_list()[1:]) 186 | 187 | else: 188 | if self.has_multiple_inputs: 189 | for i, xi in enumerate(self.X): 190 | if list(self.baseline[i].shape) == xi.get_shape().as_list()[1:]: 191 | self.baseline[i] = np.expand_dims(self.baseline[i], 0) 192 | else: 193 | raise RuntimeError('Baseline shape %s does not match expected shape %s' 194 | % (self.baseline[i].shape, self.X.get_shape().as_list()[1:])) 195 | else: 196 | if list(self.baseline.shape) == self.X.get_shape().as_list()[1:]: 197 | self.baseline = np.expand_dims(self.baseline, 0) 198 | else: 199 | raise RuntimeError('Baseline shape %s does not match expected shape %s' 200 | % (self.baseline.shape, self.X.get_shape().as_list()[1:])) 201 | 202 | 203 | class GradientBasedMethod(AttributionMethod): 204 | """ 205 | Base class for gradient-based attribution methods 206 | """ 207 | def get_symbolic_attribution(self): 208 | return tf.gradients(self.T, self.X) 209 | 210 | def explain_symbolic(self): 211 | if self.symbolic_attribution is None: 212 | self.symbolic_attribution = self.get_symbolic_attribution() 213 | return self.symbolic_attribution 214 | 215 | def run(self, xs, ys=None, batch_size=None): 216 | self._check_input_compatibility(xs, ys, batch_size) 217 | results = self._session_run(self.explain_symbolic(), xs, ys, batch_size) 218 | return results[0] if not self.has_multiple_inputs else results 219 | 220 | @classmethod 221 | def nonlinearity_grad_override(cls, op, grad): 222 | return original_grad(op, grad) 223 | 224 | 225 | class PerturbationBasedMethod(AttributionMethod): 226 | """ 227 | Base class for perturbation-based attribution methods 228 | """ 229 | def __init__(self, T, X, session, keras_learning_phase): 230 | super(PerturbationBasedMethod, self).__init__(T, X, session, keras_learning_phase) 231 | self.base_activation = None 232 | 233 | 234 | 235 | # ----------------------------------------------------------------------------- 236 | # ATTRIBUTION METHODS 237 | # ----------------------------------------------------------------------------- 238 | """ 239 | Returns zero attributions. For testing only. 240 | """ 241 | 242 | 243 | class DummyZero(GradientBasedMethod): 244 | 245 | def get_symbolic_attribution(self,): 246 | return tf.gradients(self.T, self.X) 247 | 248 | @classmethod 249 | def nonlinearity_grad_override(cls, op, grad): 250 | input = op.inputs[0] 251 | return tf.zeros_like(input) 252 | 253 | """ 254 | Saliency maps 255 | https://arxiv.org/abs/1312.6034 256 | """ 257 | 258 | 259 | class Saliency(GradientBasedMethod): 260 | 261 | def get_symbolic_attribution(self): 262 | return [tf.abs(g) for g in tf.gradients(self.T, self.X)] 263 | 264 | 265 | """ 266 | Gradient * Input 267 | https://arxiv.org/pdf/1704.02685.pdf - https://arxiv.org/abs/1611.07270 268 | """ 269 | 270 | 271 | class GradientXInput(GradientBasedMethod): 272 | 273 | def get_symbolic_attribution(self): 274 | return [g * x for g, x in zip( 275 | tf.gradients(self.T, self.X), 276 | self.X if self.has_multiple_inputs else [self.X])] 277 | 278 | 279 | """ 280 | Integrated Gradients 281 | https://arxiv.org/pdf/1703.01365.pdf 282 | """ 283 | 284 | 285 | class IntegratedGradients(GradientBasedMethod): 286 | 287 | def __init__(self, T, X, session, keras_learning_phase, steps=100, baseline=None): 288 | self.steps = steps 289 | self.baseline = baseline 290 | super(IntegratedGradients, self).__init__(T, X, session, keras_learning_phase) 291 | 292 | def run(self, xs, ys=None, batch_size=None): 293 | self._check_input_compatibility(xs, ys, batch_size) 294 | 295 | gradient = None 296 | for alpha in list(np.linspace(1. / self.steps, 1.0, self.steps)): 297 | xs_mod = [b + (x - b) * alpha for x, b in zip(xs, self.baseline)] if self.has_multiple_inputs \ 298 | else self.baseline + (xs - self.baseline) * alpha 299 | _attr = self._session_run(self.explain_symbolic(), xs_mod, ys, batch_size) 300 | if gradient is None: gradient = _attr 301 | else: gradient = [g + a for g, a in zip(gradient, _attr)] 302 | 303 | results = [g * (x - b) / self.steps for g, x, b in zip( 304 | gradient, 305 | xs if self.has_multiple_inputs else [xs], 306 | self.baseline if self.has_multiple_inputs else [self.baseline])] 307 | 308 | return results[0] if not self.has_multiple_inputs else results 309 | 310 | 311 | """ 312 | Layer-wise Relevance Propagation with epsilon rule 313 | http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140 314 | """ 315 | 316 | 317 | class EpsilonLRP(GradientBasedMethod): 318 | eps = None 319 | 320 | def __init__(self, T, X, session, keras_learning_phase, epsilon=1e-4): 321 | assert epsilon > 0.0, 'LRP epsilon must be greater than zero' 322 | global eps 323 | eps = epsilon 324 | super(EpsilonLRP, self).__init__(T, X, session, keras_learning_phase) 325 | 326 | def get_symbolic_attribution(self): 327 | return [g * x for g, x in zip( 328 | tf.gradients(self.T, self.X), 329 | self.X if self.has_multiple_inputs else [self.X])] 330 | 331 | @classmethod 332 | def nonlinearity_grad_override(cls, op, grad): 333 | output = op.outputs[0] 334 | input = op.inputs[0] 335 | return grad * output / (input + eps * 336 | tf.where(input >= 0, tf.ones_like(input), -1 * tf.ones_like(input))) 337 | 338 | """ 339 | DeepLIFT 340 | This reformulation only considers the "Rescale" rule 341 | https://arxiv.org/abs/1704.02685 342 | """ 343 | 344 | 345 | class DeepLIFTRescale(GradientBasedMethod): 346 | 347 | _deeplift_ref = {} 348 | 349 | def __init__(self, T, X, session, keras_learning_phase, baseline=None): 350 | self.baseline = baseline 351 | super(DeepLIFTRescale, self).__init__(T, X, session, keras_learning_phase) 352 | 353 | def get_symbolic_attribution(self): 354 | return [g * (x - b) for g, x, b in zip( 355 | tf.gradients(self.T, self.X), 356 | self.X if self.has_multiple_inputs else [self.X], 357 | self.baseline if self.has_multiple_inputs else [self.baseline])] 358 | 359 | @classmethod 360 | def nonlinearity_grad_override(cls, op, grad): 361 | output = op.outputs[0] 362 | input = op.inputs[0] 363 | ref_input = cls._deeplift_ref[op.name] 364 | ref_output = activation(op.type)(ref_input) 365 | delta_out = output - ref_output 366 | delta_in = input - ref_input 367 | instant_grad = activation(op.type)(0.5 * (ref_input + input)) 368 | return tf.where(tf.abs(delta_in) > 1e-5, grad * delta_out / delta_in, 369 | original_grad(instant_grad.op, grad)) 370 | 371 | def _init_references(self): 372 | # print ('DeepLIFT: computing references...') 373 | sys.stdout.flush() 374 | self._deeplift_ref.clear() 375 | ops = [] 376 | g = tf.get_default_graph() 377 | for op in g.get_operations(): 378 | if len(op.inputs) > 0 and not op.name.startswith('gradients'): 379 | if op.type in SUPPORTED_ACTIVATIONS: 380 | ops.append(op) 381 | YR = self._session_run([o.inputs[0] for o in ops], self.baseline) 382 | for (r, op) in zip(YR, ops): 383 | self._deeplift_ref[op.name] = r 384 | # print('DeepLIFT: references ready') 385 | sys.stdout.flush() 386 | 387 | 388 | """ 389 | Occlusion method 390 | Generalization of the grey-box method presented in https://arxiv.org/pdf/1311.2901.pdf 391 | This method performs a systematic perturbation of contiguous hyperpatches in the input, 392 | replacing each patch with a user-defined value (by default 0). 393 | 394 | window_shape : integer or tuple of length xs_ndim 395 | Defines the shape of the elementary n-dimensional orthotope the rolling window view. 396 | If an integer is given, the shape will be a hypercube of sidelength given by its value. 397 | 398 | step : integer or tuple of length xs_ndim 399 | Indicates step size at which extraction shall be performed. 400 | If integer is given, then the step is uniform in all dimensions. 401 | """ 402 | 403 | 404 | class Occlusion(PerturbationBasedMethod): 405 | 406 | def __init__(self, T, X, session, keras_learning_phase, window_shape=None, step=None): 407 | super(Occlusion, self).__init__(T, X, session, keras_learning_phase) 408 | if self.has_multiple_inputs: 409 | raise RuntimeError('Multiple inputs not yet supported for perturbation methods') 410 | 411 | input_shape = X[0].get_shape().as_list() 412 | if window_shape is not None: 413 | assert len(window_shape) == len(input_shape), \ 414 | 'window_shape must have length of input (%d)' % len(input_shape) 415 | self.window_shape = tuple(window_shape) 416 | else: 417 | self.window_shape = (1,) * len(input_shape) 418 | 419 | if step is not None: 420 | assert isinstance(step, int) or len(step) == len(input_shape), \ 421 | 'step must be integer or tuple with the length of input (%d)' % len(input_shape) 422 | self.step = step 423 | else: 424 | self.step = 1 425 | self.replace_value = 0.0 426 | logging.info('Input shape: %s; window_shape %s; step %s' % (input_shape, self.window_shape, self.step)) 427 | 428 | def run(self, xs, ys=None, batch_size=None): 429 | self._check_input_compatibility(xs, ys, batch_size) 430 | input_shape = xs.shape[1:] 431 | batch_size = xs.shape[0] 432 | total_dim = np.asscalar(np.prod(input_shape)) 433 | 434 | # Create mask 435 | index_matrix = np.arange(total_dim).reshape(input_shape) 436 | idx_patches = view_as_windows(index_matrix, self.window_shape, self.step).reshape((-1,) + self.window_shape) 437 | heatmap = np.zeros_like(xs, dtype=np.float32).reshape((-1), total_dim) 438 | w = np.zeros_like(heatmap) 439 | 440 | # Compute original output 441 | eval0 = self._session_run(self.T, xs, ys, batch_size) 442 | 443 | # Start perturbation loop 444 | for i, p in enumerate(idx_patches): 445 | mask = np.ones(input_shape).flatten() 446 | mask[p.flatten()] = self.replace_value 447 | masked_xs = mask.reshape((1,) + input_shape) * xs 448 | delta = eval0 - self._session_run(self.T, masked_xs, ys, batch_size) 449 | delta_aggregated = np.sum(delta.reshape((batch_size, -1)), -1, keepdims=True) 450 | heatmap[:, p.flatten()] += delta_aggregated 451 | w[:, p.flatten()] += p.size 452 | 453 | attribution = np.reshape(heatmap / w, xs.shape) 454 | if np.isnan(attribution).any(): 455 | warnings.warn('Attributions generated by Occlusion method contain nans, ' 456 | 'probably because window_shape and step do not allow to cover the all input.') 457 | return attribution 458 | 459 | 460 | """ 461 | Shapley Value sampling 462 | Computes approximate Shapley Values using "Polynomial calculation of the Shapley value based on sampling", 463 | Castro et al, 2009 (https://www.sciencedirect.com/science/article/pii/S0305054808000804) 464 | 465 | samples : integer (default 5) 466 | Defined the number of samples for each input feature. 467 | Notice that evaluating a model samples * n_input_feature times might take a while. 468 | 469 | sampling_dims : list of dimension indexes to run sampling on (feature dimensions). 470 | By default, all dimensions except the batch dimension will be sampled. 471 | For example, with a 4-D tensor that contains color images, single color channels are sampled. 472 | To sample pixels, instead, use sampling_dims=[1,2] 473 | """ 474 | 475 | 476 | class ShapleySampling(PerturbationBasedMethod): 477 | 478 | def __init__(self, T, X, session, keras_learning_phase, samples=5, sampling_dims=None): 479 | super(ShapleySampling, self).__init__(T, X, session, keras_learning_phase) 480 | if self.has_multiple_inputs: 481 | raise RuntimeError('Multiple inputs not yet supported for perturbation methods') 482 | dims = len(X.shape) 483 | if sampling_dims is not None: 484 | if not 0 < len(sampling_dims) <= (dims - 1): 485 | raise RuntimeError('sampling_dims must be a list containing 1 to %d elements' % (dims-1)) 486 | if 0 in sampling_dims: 487 | raise RuntimeError('Cannot sample batch dimension: remove 0 from sampling_dims') 488 | if any([x < 1 or x > dims-1 for x in sampling_dims]): 489 | raise RuntimeError('Invalid value in sampling_dims') 490 | else: 491 | sampling_dims = list(range(1, dims)) 492 | 493 | self.samples = samples 494 | self.sampling_dims = sampling_dims 495 | 496 | def run(self, xs, ys=None, batch_size=None): 497 | xs_shape = list(xs.shape) 498 | batch_size = xs.shape[0] 499 | n_features = int(np.asscalar(np.prod([xs.shape[i] for i in self.sampling_dims]))) 500 | result = np.zeros((xs_shape[0], n_features)) 501 | 502 | run_shape = list(xs_shape) # a copy 503 | run_shape = np.delete(run_shape, self.sampling_dims).tolist() 504 | run_shape.insert(1, -1) 505 | 506 | reconstruction_shape = [xs_shape[0]] 507 | for j in self.sampling_dims: 508 | reconstruction_shape.append(xs_shape[j]) 509 | 510 | for r in range(self.samples): 511 | p = np.random.permutation(n_features) 512 | x = xs.copy().reshape(run_shape) 513 | y = None 514 | for i in p: 515 | if y is None: 516 | y = self._session_run(self.T, x.reshape(xs_shape), ys, batch_size) 517 | x[:, i] = 0 518 | y0 = self._session_run(self.T, x.reshape(xs_shape), ys, batch_size) 519 | delta = y - y0 520 | delta_aggregated = np.sum(delta.reshape((batch_size, -1)), -1, keepdims=False) 521 | result[:, i] += delta_aggregated 522 | y = y0 523 | 524 | shapley = result / self.samples 525 | return shapley.reshape(reconstruction_shape) 526 | 527 | 528 | # ----------------------------------------------------------------------------- 529 | # END ATTRIBUTION METHODS 530 | # ----------------------------------------------------------------------------- 531 | 532 | 533 | attribution_methods = OrderedDict({ 534 | 'zero': (DummyZero, 0), 535 | 'saliency': (Saliency, 1), 536 | 'grad*input': (GradientXInput, 2), 537 | 'intgrad': (IntegratedGradients, 3), 538 | 'elrp': (EpsilonLRP, 4), 539 | 'deeplift': (DeepLIFTRescale, 5), 540 | 'occlusion': (Occlusion, 6), 541 | 'shapley_sampling': (ShapleySampling, 7) 542 | }) 543 | 544 | 545 | 546 | @ops.RegisterGradient("DeepExplainGrad") 547 | def deepexplain_grad(op, grad): 548 | global _ENABLED_METHOD_CLASS, _GRAD_OVERRIDE_CHECKFLAG 549 | _GRAD_OVERRIDE_CHECKFLAG = 1 550 | if _ENABLED_METHOD_CLASS is not None \ 551 | and issubclass(_ENABLED_METHOD_CLASS, GradientBasedMethod): 552 | return _ENABLED_METHOD_CLASS.nonlinearity_grad_override(op, grad) 553 | else: 554 | return original_grad(op, grad) 555 | 556 | 557 | class DeepExplain(object): 558 | 559 | def __init__(self, graph=None, session=tf.get_default_session()): 560 | self.method = None 561 | self.batch_size = None 562 | self.session = session 563 | self.graph = session.graph if graph is None else graph 564 | self.graph_context = self.graph.as_default() 565 | self.override_context = self.graph.gradient_override_map(self.get_override_map()) 566 | self.keras_phase_placeholder = None 567 | self.context_on = False 568 | if self.session is None: 569 | raise RuntimeError('DeepExplain: could not retrieve a session. Use DeepExplain(session=your_session).') 570 | 571 | def __enter__(self): 572 | # Override gradient of all ops created in context 573 | self.graph_context.__enter__() 574 | self.override_context.__enter__() 575 | self.context_on = True 576 | return self 577 | 578 | def __exit__(self, type, value, traceback): 579 | self.graph_context.__exit__(type, value, traceback) 580 | self.override_context.__exit__(type, value, traceback) 581 | self.context_on = False 582 | 583 | def get_explainer(self, method, T, X, **kwargs): 584 | if not self.context_on: 585 | raise RuntimeError('Explain can be called only within a DeepExplain context.') 586 | global _ENABLED_METHOD_CLASS, _GRAD_OVERRIDE_CHECKFLAG 587 | self.method = method 588 | if self.method in attribution_methods: 589 | method_class, method_flag = attribution_methods[self.method] 590 | else: 591 | raise RuntimeError('Method must be in %s' % list(attribution_methods.keys())) 592 | if isinstance(X, list): 593 | for x in X: 594 | if 'tensor' not in str(type(x)).lower(): 595 | raise RuntimeError('If a list, X must contain only Tensorflow Tensor objects') 596 | else: 597 | if 'tensor' not in str(type(X)).lower(): 598 | raise RuntimeError('X must be a Tensorflow Tensor object or a list of them') 599 | 600 | if 'tensor' not in str(type(T)).lower(): 601 | raise RuntimeError('T must be a Tensorflow Tensor object') 602 | 603 | logging.info('DeepExplain: running "%s" explanation method (%d)' % (self.method, method_flag)) 604 | self._check_ops() 605 | _GRAD_OVERRIDE_CHECKFLAG = 0 606 | 607 | _ENABLED_METHOD_CLASS = method_class 608 | method = _ENABLED_METHOD_CLASS(T, X, 609 | self.session, 610 | keras_learning_phase=self.keras_phase_placeholder, 611 | **kwargs) 612 | 613 | if issubclass(_ENABLED_METHOD_CLASS, GradientBasedMethod) and _GRAD_OVERRIDE_CHECKFLAG == 0: 614 | warnings.warn('DeepExplain detected you are trying to use an attribution method that requires ' 615 | 'gradient override but the original gradient was used instead. You might have forgot to ' 616 | '(re)create your graph within the DeepExlain context. Results are not reliable!') 617 | _ENABLED_METHOD_CLASS = None 618 | _GRAD_OVERRIDE_CHECKFLAG = 0 619 | self.keras_phase_placeholder = None 620 | return method 621 | 622 | def explain(self, method, T, X, xs, ys=None, batch_size=None, **kwargs): 623 | explainer = self.get_explainer(method, T, X, **kwargs) 624 | return explainer.run(xs, ys, batch_size) 625 | 626 | @staticmethod 627 | def get_override_map(): 628 | return dict((a, 'DeepExplainGrad') for a in SUPPORTED_ACTIVATIONS) 629 | 630 | def _check_ops(self): 631 | """ 632 | Heuristically check if any op is in the list of unsupported activation functions. 633 | This does not cover all cases where explanation methods would fail, and must be improved in the future. 634 | Also, check if the placeholder named 'keras_learning_phase' exists in the graph. This is used by Keras 635 | and needs to be passed in feed_dict. 636 | :return: 637 | """ 638 | g = tf.get_default_graph() 639 | for op in g.get_operations(): 640 | if len(op.inputs) > 0 and not op.name.startswith('gradients'): 641 | if op.type in UNSUPPORTED_ACTIVATIONS: 642 | warnings.warn('Detected unsupported activation (%s). ' 643 | 'This might lead to unexpected or wrong results.' % op.type) 644 | elif 'keras_learning_phase' in op.name: 645 | self.keras_phase_placeholder = op.outputs[0] 646 | 647 | 648 | 649 | 650 | 651 | -------------------------------------------------------------------------------- /deepexplain/tensorflow/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | # Some of the following functions for batch processing have been borrowed and adapter from Keras 5 | # https://github.com/keras-team/keras/blob/master/keras/utils/generic_utils.py 6 | # https://github.com/keras-team/keras/blob/master/keras/engine/training_utils.py 7 | 8 | 9 | def make_batches(size, batch_size): 10 | """Returns a list of batch indices (tuples of indices). 11 | # Arguments 12 | size: Integer, total size of the data to slice into batches. 13 | batch_size: Integer, batch size. 14 | # Returns 15 | A list of tuples of array indices. 16 | """ 17 | num_batches = (size + batch_size - 1) // batch_size # round up 18 | return [(i * batch_size, min(size, (i + 1) * batch_size)) 19 | for i in range(num_batches)] 20 | 21 | 22 | def to_list(x, allow_tuple=False): 23 | """Normalizes a list/tensor into a list. 24 | If a tensor is passed, we return 25 | a list of size 1 containing the tensor. 26 | # Arguments 27 | x: target object to be normalized. 28 | allow_tuple: If False and x is a tuple, 29 | it will be converted into a list 30 | with a single element (the tuple). 31 | Else converts the tuple to a list. 32 | # Returns 33 | A list. 34 | """ 35 | if isinstance(x, list): 36 | return x 37 | if allow_tuple and isinstance(x, tuple): 38 | return list(x) 39 | return [x] 40 | 41 | 42 | def unpack_singleton(x): 43 | """Gets the equivalent np-array if the iterable has only one value. 44 | Otherwise return the iterable. 45 | # Argument 46 | x: A list or tuple. 47 | # Returns 48 | The same iterable or the iterable converted to a np-array. 49 | """ 50 | if len(x) == 1: 51 | return np.array(x) 52 | return x 53 | 54 | 55 | def slice_arrays(arrays, start=None, stop=None): 56 | """Slices an array or list of arrays. 57 | """ 58 | if arrays is None: 59 | return [None] 60 | elif isinstance(arrays, list): 61 | return [None if x is None else x[start:stop] for x in arrays] 62 | else: 63 | return arrays[start:stop] 64 | 65 | 66 | def placeholder_from_data(numpy_array): 67 | if numpy_array is None: 68 | return None 69 | return tf.placeholder('float', [None,] + list(numpy_array.shape[1:])) -------------------------------------------------------------------------------- /deepexplain/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/deepexplain/tests/__init__.py -------------------------------------------------------------------------------- /deepexplain/tests/test_tensorflow.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import pkg_resources 3 | import logging, warnings 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | from deepexplain.tensorflow import DeepExplain 8 | from deepexplain.tensorflow.methods import original_grad # test only 9 | 10 | activations = {'Relu': tf.nn.relu, 11 | 'Sigmoid': tf.nn.sigmoid, 12 | 'Softplus': tf.nn.softplus, 13 | 'Tanh': tf.nn.tanh} 14 | 15 | 16 | def simple_model(activation, session): 17 | X = tf.placeholder("float", [None, 2]) 18 | w1 = tf.Variable(initial_value=[[1.0, -1.0], [-1.0, 1.0]]) 19 | b1 = tf.Variable(initial_value=[1.5, -1.0]) 20 | w2 = tf.Variable(initial_value=[[1.1, 1.4], [-0.5, 1.0]]) 21 | b2 = tf.Variable(initial_value=[0.0, 2.0]) 22 | 23 | layer1 = activation(tf.matmul(X, w1) + b1) 24 | out = tf.matmul(layer1, w2) + b2 25 | session.run(tf.global_variables_initializer()) 26 | return X, out 27 | 28 | 29 | def simpler_model(session): 30 | """ 31 | Implements ReLU( ReLU(x1 - 1) - ReLU(x2) ) 32 | : 33 | """ 34 | X = tf.placeholder("float", [None, 2]) 35 | w1 = tf.Variable(initial_value=[[1.0, 0.0], [0.0, 1.0]], trainable=False) 36 | b1 = tf.Variable(initial_value=[-1.0, 0], trainable=False) 37 | w2 = tf.Variable(initial_value=[[1.0], [-1.0]], trainable=False) 38 | 39 | l1 = tf.nn.relu(tf.matmul(X, w1) + b1) 40 | out = tf.nn.relu(tf.matmul(l1, w2)) 41 | session.run(tf.global_variables_initializer()) 42 | return X, out 43 | 44 | 45 | def min_model(session): 46 | """ 47 | Implements min(xi) 48 | """ 49 | X = tf.placeholder("float", [None, 2]) 50 | out = tf.reduce_min(X,1) 51 | session.run(tf.global_variables_initializer()) 52 | return X, out 53 | 54 | 55 | def min_model_2d(session): 56 | """ 57 | Implements min(xi) 58 | """ 59 | X = tf.placeholder("float", [None, 2, 2]) 60 | out = tf.reduce_min(tf.reshape(X, (-1, 4)), 1) 61 | session.run(tf.global_variables_initializer()) 62 | return X, out 63 | 64 | 65 | def simple_multi_inputs_model(session): 66 | """ 67 | Implements Relu (3*x1|2*x2) | is a concat op 68 | : 69 | """ 70 | X1 = tf.placeholder("float", [None, 2]) 71 | X2 = tf.placeholder("float", [None, 2]) 72 | w1 = tf.Variable(initial_value=[[3.0, 0.0], [0.0, 3.0]], trainable=False) 73 | w2 = tf.Variable(initial_value=[[2.0, 0.0], [0.0, 2.0]], trainable=False) 74 | 75 | out = tf.nn.relu(tf.concat([X1*w1, X2*w2], 1)) 76 | session.run(tf.global_variables_initializer()) 77 | return X1, X2, out 78 | 79 | 80 | def simple_multi_inputs_model2(session): 81 | """ 82 | Implements Relu (3*x1|2*x2) | is a concat op 83 | : 84 | """ 85 | X1 = tf.placeholder("float", [None, 2]) 86 | X2 = tf.placeholder("float", [None, 1]) 87 | w1 = tf.Variable(initial_value=[3.0], trainable=False) 88 | w2 = tf.Variable(initial_value=[2.0], trainable=False) 89 | 90 | out = tf.nn.relu(tf.concat([X1*w1, X2*w2], 1)) 91 | session.run(tf.global_variables_initializer()) 92 | return X1, X2, out 93 | 94 | 95 | def train_xor(session): 96 | # Since setting seed is not always working on TF, initial weights values are hardcoded for reproducibility 97 | X = tf.placeholder("float", [None, 2]) 98 | Y = tf.placeholder("float", [None, 1]) 99 | w1 = tf.Variable(initial_value=[[0.10711301, -0.0987727], [-1.57625198, 1.34942603]]) 100 | b1 = tf.Variable(initial_value=[-0.30955192, -0.14483099]) 101 | w2 = tf.Variable(initial_value=[[0.69259691], [-0.16255915]]) 102 | b2 = tf.Variable(initial_value=[1.53952825]) 103 | 104 | l1 = tf.nn.relu(tf.matmul(X, w1) + b1) 105 | out = tf.matmul(l1, w2) + b2 106 | session.run(tf.global_variables_initializer()) 107 | 108 | # Define loss and optimizer 109 | loss = tf.reduce_mean(tf.losses.mean_squared_error(Y, out)) 110 | train_step = tf.train.GradientDescentOptimizer(0.05).minimize(loss) 111 | 112 | # Generate dataset random 113 | np.random.seed(10) 114 | x = np.random.randint(0, 2, size=(10, 2)) 115 | y = np.expand_dims(np.logical_or(x[:, 0], x[:, 1]), -1) 116 | l = None 117 | for _ in range(100): 118 | l, _, = session.run([loss, train_step], feed_dict={X: x, Y: y}) 119 | 120 | return np.abs(l - 0.1) < 0.01 121 | 122 | 123 | class TestDeepExplainGeneralTF(TestCase): 124 | 125 | def setUp(self): 126 | self.session = tf.Session() 127 | 128 | def tearDown(self): 129 | self.session.close() 130 | tf.reset_default_graph() 131 | 132 | def test_tf_available(self): 133 | try: 134 | pkg_resources.require('tensorflow>=1.0.0') 135 | except Exception: 136 | try: 137 | pkg_resources.require('tensorflow-gpu>=1.0.0') 138 | except Exception: 139 | self.fail("Tensorflow requirement not met") 140 | 141 | def test_simple_model(self): 142 | X, out = simple_model(tf.nn.relu, self.session) 143 | xi = np.array([[1, 0]]) 144 | r = self.session.run(out, {X: xi}) 145 | self.assertEqual(r.shape, xi.shape) 146 | np.testing.assert_equal(r[0], [2.75, 5.5]) 147 | 148 | def test_simpler_model(self): 149 | X, out = simpler_model(self.session) 150 | xi = np.array([[3.0, 1.0]]) 151 | r = self.session.run(out, {X: xi}) 152 | self.assertEqual(r.shape, (xi.shape[0], 1)) 153 | np.testing.assert_equal(r[0], [1.0]) 154 | 155 | def test_training(self): 156 | session = tf.Session() 157 | r = train_xor(session) 158 | self.assertTrue(r) 159 | 160 | def test_context(self): 161 | """ 162 | DeepExplain overrides nonlinearity gradient 163 | """ 164 | # No override 165 | from deepexplain.tensorflow import DeepExplain 166 | 167 | X = tf.placeholder("float", [None, 1]) 168 | for name in activations: 169 | x1 = activations[name](X) 170 | x1_g = tf.gradients(x1, X)[0] 171 | self.assertEqual(x1_g.op.type, '%sGrad' % name) 172 | 173 | # Override (note: that need to pass graph! Multiple thread testing??) 174 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 175 | for name in activations: 176 | # Gradients of nonlinear ops are overriden 177 | x2 = activations[name](X) 178 | self.assertEqual(x2.op.get_attr('_gradient_op_type').decode('utf-8'), 'DeepExplainGrad') 179 | 180 | def test_mismatch_input_lens(self): 181 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 182 | X1 = tf.placeholder("float", [None, 1]) 183 | X2 = tf.placeholder("float", [None, 1]) 184 | w1 = tf.Variable(initial_value=[[0.10711301]]) 185 | w2 = tf.Variable(initial_value=[[0.69259691]]) 186 | out = tf.matmul(X1, w1) + tf.matmul(X2, w2) 187 | 188 | self.session.run(tf.global_variables_initializer()) 189 | with self.assertRaises(RuntimeError) as cm: 190 | de.explain('grad*input', out, [X1, X2], [[1], [2], [3]]) 191 | self.assertIn( 192 | 'List of input tensors and input data have different lengths', 193 | str(cm.exception) 194 | ) 195 | 196 | def test_multiple_inputs(self): 197 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 198 | X1 = tf.placeholder("float", [None, 1]) 199 | X2 = tf.placeholder("float", [None, 1]) 200 | w1 = tf.Variable(initial_value=[[10.0]]) 201 | w2 = tf.Variable(initial_value=[[10.0]]) 202 | out = tf.matmul(X1, w1) + tf.matmul(X2, w2) 203 | 204 | self.session.run(tf.global_variables_initializer()) 205 | attributions = de.explain('grad*input', out, [X1, X2], [[[2]], [[3]]]) 206 | self.assertEqual(len(attributions), 2) 207 | self.assertEqual(attributions[0][0], 20.0) 208 | self.assertEqual(attributions[1][0], 30.0) 209 | 210 | def test_supported_activations(self): 211 | X = tf.placeholder("float", [None, 3]) 212 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 213 | xi = [[-1, 0, 1]] 214 | Y = tf.nn.relu(X) 215 | r = self.session.run(Y, {X: xi}) 216 | np.testing.assert_almost_equal(r[0], [0, 0, 1], 7) 217 | Y = tf.nn.elu(X) 218 | r = self.session.run(Y, {X: xi}) 219 | np.testing.assert_almost_equal(r[0], [-0.632120558, 0, 1], 7) 220 | Y = tf.nn.sigmoid(X) 221 | r = self.session.run(Y, {X: xi}) 222 | np.testing.assert_almost_equal(r[0], [0.268941421, 0.5, 0.731058578], 7) 223 | Y = tf.nn.tanh(X) 224 | r = self.session.run(Y, {X: xi}) 225 | np.testing.assert_almost_equal(r[0], [-0.761594155, 0, 0.761594155], 7) 226 | Y = tf.nn.softplus(X) 227 | r = self.session.run(Y, {X: xi}) 228 | np.testing.assert_almost_equal(r[0], [0.313261687, 0.693147181, 1.31326168], 7) 229 | 230 | def test_original_grad(self): 231 | X = tf.placeholder("float", [None, 3]) 232 | for name in activations: 233 | Y = activations[name](X) 234 | grad = original_grad(Y.op, tf.ones_like(X)) 235 | self.assertTrue('Tensor' in str(type(grad))) 236 | 237 | def test_warning_unsupported_activations(self): 238 | with warnings.catch_warnings(record=True) as w: 239 | warnings.simplefilter("always") 240 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 241 | X = tf.placeholder("float", [None, 3]) 242 | Y = tf.nn.relu6(X) # < an unsupported activation 243 | xi = [[-1, 0, 1]] 244 | de.explain('elrp', Y, X, xi) 245 | assert any(["unsupported activation" in str(wi.message) for wi in w]) 246 | 247 | def test_override_as_default(self): 248 | """ 249 | In DeepExplain context, nonlinearities behave as default, including training time 250 | """ 251 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 252 | r = train_xor(self.session) 253 | self.assertTrue(r) 254 | 255 | def test_explain_not_in_context(self): 256 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 257 | pass 258 | with self.assertRaises(RuntimeError) as cm: 259 | de.explain('grad*input', None, None, None) 260 | self.assertEqual( 261 | 'Explain can be called only within a DeepExplain context.', 262 | str(cm.exception) 263 | ) 264 | 265 | def test_invalid_method(self): 266 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 267 | with self.assertRaises(RuntimeError) as cm: 268 | de.explain('invalid', None, None, None) 269 | self.assertIn('Method must be in', 270 | str(cm.exception) 271 | ) 272 | 273 | # Failing on Python 2 on Travis !? But the warning is actually there 274 | # def test_gradient_was_not_overridden(self): 275 | # X = tf.placeholder("float", [None, 3]) 276 | # Y = tf.nn.relu(X) 277 | # with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 278 | # with warnings.catch_warnings(record=True) as w: 279 | # warnings.simplefilter("always") 280 | # de.explain('grad*input', Y, X, [[0, 0, 0]]) 281 | # assert any(["DeepExplain detected" in str(wi.message) for wi in w]) 282 | 283 | def test_T_is_tensor(self): 284 | X = tf.placeholder("float", [None, 3]) 285 | Y = tf.nn.relu(X) 286 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 287 | with self.assertRaises(RuntimeError) as cm: 288 | de.explain('grad*input', [Y], X, [[0, 0, 0]]) 289 | self.assertIn('T must be a Tensorflow Tensor object', 290 | str(cm.exception) 291 | ) 292 | 293 | def test_X_is_tensor(self): 294 | X = tf.placeholder("float", [None, 3]) 295 | Y = tf.nn.relu(X) 296 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 297 | with self.assertRaises(RuntimeError) as cm: 298 | de.explain('grad*input', Y, np.eye(3), [[0, 0, 0]]) 299 | self.assertIn('Tensorflow Tensor object', 300 | str(cm.exception) 301 | ) 302 | 303 | def test_all_in_X_are_tensor(self): 304 | X = tf.placeholder("float", [None, 3]) 305 | Y = tf.nn.relu(X) 306 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 307 | with self.assertRaises(RuntimeError) as cm: 308 | de.explain('grad*input', Y, [X, np.eye(3)], [[0, 0, 0]]) 309 | self.assertIn('Tensorflow Tensor object', 310 | str(cm.exception) 311 | ) 312 | 313 | def test_X_has_compatible_batch_dim(self): 314 | X = tf.placeholder("float", [10, 3]) 315 | Y = tf.nn.relu(X) 316 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 317 | with self.assertRaises(RuntimeError) as cm: 318 | de.explain('grad*input', Y, X, [[0, 0, 0]], batch_size=2) 319 | self.assertIn('the first dimension of the input tensor', 320 | str(cm.exception) 321 | ) 322 | 323 | def test_T_has_compatible_batch_dim(self): 324 | X, out = simpler_model(self.session) 325 | xi = np.array([[-10, -5]]).repeat(50, 0) 326 | Y = out * np.expand_dims(np.array(range(50)), -1) 327 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 328 | with self.assertRaises(RuntimeError) as cm: 329 | de.explain('saliency', Y, X, xi, batch_size=10) 330 | self.assertIn('the first dimension of the target tensor', 331 | str(cm.exception) 332 | ) 333 | 334 | def test_use_of_target_weights(self): 335 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 336 | X, T = simple_model(tf.identity, self.session) 337 | xi = np.array([[1, 0]]) 338 | yi1 = np.array([[1, 0]]) 339 | yi2 = np.array([[0, 1]]) 340 | yi3 = np.array([[1, 1]]) 341 | yi4 = np.array([[0, 0]]) 342 | 343 | a1 = de.explain('saliency', T, X, xi, ys=yi1) 344 | a2 = de.explain('saliency', T, X, xi, ys=yi2) 345 | a3 = de.explain('saliency', T, X, xi, ys=yi3) 346 | a4 = de.explain('saliency', T, X, xi, ys=yi4) 347 | np.testing.assert_almost_equal(a1+a2, a3, 10) 348 | np.testing.assert_almost_equal(a4, np.array([[0.0, 0.0]]), 10) 349 | 350 | def test_use_of_target_weights_batch(self): 351 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 352 | X, T = simple_model(tf.identity, self.session) 353 | xi = np.array([[1, 0]]).repeat(20, 0) 354 | yi1 = np.array([[1, 0]]).repeat(20, 0) 355 | yi2 = np.array([[0, 1]]).repeat(20, 0) 356 | yi3 = np.array([[1, 1]]).repeat(20, 0) 357 | yi4 = np.array([[0, 0]]).repeat(20, 0) 358 | 359 | a1 = de.explain('saliency', T, X, xi, ys=yi1, batch_size=5) 360 | a2 = de.explain('saliency', T, X, xi, ys=yi2, batch_size=5) 361 | a3 = de.explain('saliency', T, X, xi, ys=yi3, batch_size=5) 362 | a4 = de.explain('saliency', T, X, xi, ys=yi4, batch_size=5) 363 | np.testing.assert_almost_equal(a1+a2, a3, 10) 364 | np.testing.assert_almost_equal(a4, np.array([[0.0, 0.0]]).repeat(20, 0), 10) 365 | 366 | def test_wrong_weight_len(self): 367 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 368 | X, T = simple_model(tf.identity, self.session) 369 | xi = np.array([[1, 0]]).repeat(20, 0) 370 | yi1 = np.array([[1, 0]]) # < not same len as xi 371 | 372 | with self.assertRaises(RuntimeError) as cm: 373 | de.explain('saliency', T, X, xi, ys=yi1, batch_size=5) 374 | self.assertIn('the number of elements in ys must equal ', 375 | str(cm.exception) 376 | ) 377 | 378 | def test_explainer_api_memory(self): 379 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 380 | X, T = simple_model(tf.identity, self.session) 381 | xi = np.array([[1, 0]]).repeat(20, 0) 382 | prev_ops_count = None 383 | explainer = de.get_explainer('saliency', T, X) 384 | for i in range(10): 385 | explainer.run(xi) 386 | # de.explain('saliency', T, X, xi) < this would fail instead 387 | ops_count = len([n.name for n in tf.get_default_graph().as_graph_def().node]) 388 | if prev_ops_count is None: 389 | prev_ops_count = ops_count 390 | else: 391 | self.assertEquals(prev_ops_count, ops_count) 392 | 393 | 394 | class TestDummyMethod(TestCase): 395 | 396 | def setUp(self): 397 | self.session = tf.Session() 398 | 399 | def tearDown(self): 400 | self.session.close() 401 | tf.reset_default_graph() 402 | 403 | def test_dummy_zero(self): 404 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 405 | X, out = simple_model(tf.nn.sigmoid, self.session) 406 | xi = np.array([[10, -10]]) 407 | attributions = de.explain('zero', out, X, xi) 408 | self.assertEqual(attributions.shape, xi.shape) 409 | np.testing.assert_almost_equal(attributions[0], [0.0, 0.0], 10) 410 | 411 | def test_gradient_restored(self): 412 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 413 | X, out = simple_model(tf.nn.sigmoid, self.session) 414 | xi = np.array([[10, -10]]) 415 | de.explain('zero', out, X, xi) 416 | r = train_xor(self.session) 417 | self.assertTrue(r) 418 | 419 | 420 | class TestSaliencyMethod(TestCase): 421 | 422 | def setUp(self): 423 | self.session = tf.Session() 424 | 425 | def tearDown(self): 426 | self.session.close() 427 | tf.reset_default_graph() 428 | 429 | def test_saliency_method(self): 430 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 431 | X, out = simpler_model( self.session) 432 | xi = np.array([[-10, -5], [3, 1]]) 433 | attributions = de.explain('saliency', out, X, xi) 434 | self.assertEqual(attributions.shape, xi.shape) 435 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [1.0, 1.0]], 10) 436 | 437 | def test_multiple_inputs(self): 438 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 439 | X1, X2, out = simple_multi_inputs_model(self.session) 440 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 441 | attributions = de.explain('saliency', out, [X1, X2], xi) 442 | self.assertEqual(len(attributions), len(xi)) 443 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 10) 444 | np.testing.assert_almost_equal(attributions[1], [[2.0, 2.0]], 10) 445 | 446 | def test_multiple_inputs_explainer_api(self): 447 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 448 | X1, X2, out = simple_multi_inputs_model(self.session) 449 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 450 | explainer = de.get_explainer('saliency', out, [X1, X2]) 451 | attributions = explainer.run(xi) 452 | self.assertEqual(len(attributions), len(xi)) 453 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 10) 454 | np.testing.assert_almost_equal(attributions[1], [[2.0, 2.0]], 10) 455 | 456 | 457 | class TestGradInputMethod(TestCase): 458 | 459 | def setUp(self): 460 | self.session = tf.Session() 461 | 462 | def tearDown(self): 463 | self.session.close() 464 | tf.reset_default_graph() 465 | 466 | def test_saliency_method(self): 467 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 468 | X, out = simpler_model(self.session) 469 | xi = np.array([[-10, -5], [3, 1]]) 470 | attributions = de.explain('grad*input', out, X, xi) 471 | self.assertEqual(attributions.shape, xi.shape) 472 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [3.0, -1.0]], 10) 473 | 474 | def test_saliency_method_explainer_api(self): 475 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 476 | X, out = simpler_model(self.session) 477 | xi = np.array([[-10, -5], [3, 1]]) 478 | explainer = de.get_explainer('grad*input', out, X) 479 | attributions = explainer.run(xi) 480 | self.assertEqual(attributions.shape, xi.shape) 481 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [3.0, -1.0]], 10) 482 | 483 | def test_multiple_inputs(self): 484 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 485 | X1, X2, out = simple_multi_inputs_model(self.session) 486 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 487 | attributions = de.explain('grad*input', out, [X1, X2], xi) 488 | self.assertEqual(len(attributions), len(xi)) 489 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 10) 490 | np.testing.assert_almost_equal(attributions[1], [[6.0, 2.0]], 10) 491 | 492 | 493 | class TestIntegratedGradientsMethod(TestCase): 494 | 495 | def setUp(self): 496 | self.session = tf.Session() 497 | 498 | def tearDown(self): 499 | self.session.close() 500 | tf.reset_default_graph() 501 | 502 | def test_int_grad(self): 503 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 504 | X, out = simpler_model( self.session) 505 | xi = np.array([[-10, -5], [3, 1]]) 506 | attributions = de.explain('intgrad', out, X, xi) 507 | self.assertEqual(attributions.shape, xi.shape) 508 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [1.5, -0.5]], 10) 509 | 510 | def test_int_grad_higher_precision(self): 511 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 512 | X, out = simpler_model( self.session) 513 | xi = np.array([[-10, -5], [3, 1]]) 514 | attributions = de.explain('intgrad', out, X, xi, steps=500) 515 | self.assertEqual(attributions.shape, xi.shape) 516 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [1.5, -0.5]], 10) 517 | 518 | def test_int_grad_baseline(self): 519 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 520 | X, out = simpler_model(self.session) 521 | xi = np.array([[2, 0]]) 522 | attributions = de.explain('intgrad', out, X, xi, baseline=np.array([1, 0])) 523 | self.assertEqual(attributions.shape, xi.shape) 524 | np.testing.assert_almost_equal(attributions, [[1.0, 0.0]], 10) 525 | 526 | def test_int_grad_baseline_explainer_api(self): 527 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 528 | X, out = simpler_model(self.session) 529 | xi = np.array([[2, 0]]) 530 | explainer = de.get_explainer('intgrad', out, X, baseline=np.array([1, 0])) 531 | attributions = explainer.run(xi) 532 | self.assertEqual(attributions.shape, xi.shape) 533 | np.testing.assert_almost_equal(attributions, [[1.0, 0.0]], 10) 534 | 535 | def test_int_grad_baseline_2(self): 536 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 537 | X, out = simpler_model(self.session) 538 | xi = np.array([[2, 0], [3, 0]]) 539 | attributions = de.explain('intgrad', out, X, xi, baseline=np.array([1, 0])) 540 | self.assertEqual(attributions.shape, xi.shape) 541 | np.testing.assert_almost_equal(attributions, [[1.0, 0.0], [2.0, 0.0]], 10) 542 | 543 | def test_multiple_inputs(self): 544 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 545 | X1, X2, out = simple_multi_inputs_model(self.session) 546 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 547 | attributions = de.explain('intgrad', out, [X1, X2], xi) 548 | self.assertEqual(len(attributions), len(xi)) 549 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 10) 550 | np.testing.assert_almost_equal(attributions[1], [[6.0, 2.0]], 10) 551 | 552 | def test_multiple_inputs_different_sizes(self): 553 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 554 | X1, X2, out = simple_multi_inputs_model2(self.session) 555 | xi = [np.array([[-10, -5]]), np.array([[3]])] 556 | attributions = de.explain('intgrad', out, [X1, X2], xi) 557 | self.assertEqual(len(attributions), len(xi)) 558 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 10) 559 | np.testing.assert_almost_equal(attributions[1], [[6]], 10) 560 | 561 | def test_intgrad_targeting_equivalence(self): 562 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 563 | X, out = simple_model(tf.nn.relu, self.session) 564 | xi = np.array([[5, 3]]) 565 | self.assertEqual(out.shape[1], 2) 566 | a1 = de.explain('intgrad', out * np.array([[1, 0]]), X, xi) 567 | b1 = de.explain('intgrad', out * np.array([[0, 1]]), X, xi) 568 | a2 = de.explain('intgrad', out, X, xi, ys=np.array([[1, 0]])) 569 | b2 = de.explain('intgrad', out, X, xi, ys=np.array([[0, 1]])) 570 | np.testing.assert_almost_equal(a1, a2, 1) 571 | np.testing.assert_almost_equal(b1, b2, 1) 572 | 573 | 574 | class TestEpsilonLRPMethod(TestCase): 575 | 576 | def setUp(self): 577 | self.session = tf.Session() 578 | 579 | def tearDown(self): 580 | self.session.close() 581 | tf.reset_default_graph() 582 | 583 | def test_elrp_method(self): 584 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 585 | X, out = simpler_model( self.session) 586 | xi = np.array([[-10, -5], [3, 1]]) 587 | attributions = de.explain('elrp', out, X, xi) 588 | self.assertEqual(attributions.shape, xi.shape) 589 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [3.0, -1.0]], 3) 590 | 591 | def test_elrp_epsilon(self): 592 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 593 | X, out = simpler_model( self.session) 594 | xi = np.array([[-10, -5], [3, 1]]) 595 | attributions = de.explain('elrp', out, X, xi, epsilon=1e-9) 596 | self.assertEqual(attributions.shape, xi.shape) 597 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [3.0, -1.0]], 7) 598 | 599 | def test_elrp_epsilon_explainer_api(self): 600 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 601 | X, out = simpler_model( self.session) 602 | xi = np.array([[-10, -5], [3, 1]]) 603 | explainer = de.get_explainer('elrp', out, X, epsilon=1e-9) 604 | attributions = explainer.run(xi) 605 | self.assertEqual(attributions.shape, xi.shape) 606 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [3.0, -1.0]], 7) 607 | 608 | def test_elrp_zero_epsilon(self): 609 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 610 | X, out = simpler_model( self.session) 611 | xi = np.array([[-10, -5], [3, 1]]) 612 | with self.assertRaises(AssertionError): 613 | de.explain('elrp', out, X, xi, epsilon=0) 614 | 615 | def test_multiple_inputs(self): 616 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 617 | X1, X2, out = simple_multi_inputs_model(self.session) 618 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 619 | attributions = de.explain('elrp', out, [X1, X2], xi, epsilon=1e-9) 620 | self.assertEqual(len(attributions), len(xi)) 621 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 7) 622 | np.testing.assert_almost_equal(attributions[1], [[6.0, 2.0]], 7) 623 | 624 | def test_elrp_targeting_equivalence(self): 625 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 626 | X, out = simple_model(tf.nn.relu, self.session) 627 | xi = np.array([[5, 3]]) 628 | self.assertEqual(out.shape[1], 2) 629 | a1 = de.explain('elrp', out * np.array([[1, 0]]), X, xi) 630 | b1 = de.explain('elrp', out * np.array([[0, 1]]), X, xi) 631 | a2 = de.explain('elrp', out, X, xi, ys=np.array([[1, 0]])) 632 | b2 = de.explain('elrp', out, X, xi, ys=np.array([[0, 1]])) 633 | np.testing.assert_almost_equal(a1, a2, 1) 634 | np.testing.assert_almost_equal(b1, b2, 1) 635 | 636 | 637 | class TestDeepLIFTMethod(TestCase): 638 | 639 | def setUp(self): 640 | self.session = tf.Session() 641 | 642 | def tearDown(self): 643 | self.session.close() 644 | tf.reset_default_graph() 645 | 646 | def test_deeplift(self): 647 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 648 | X, out = simpler_model( self.session) 649 | xi = np.array([[-10, -5], [3, 1]]) 650 | attributions = de.explain('deeplift', out, X, xi) 651 | self.assertEqual(attributions.shape, xi.shape) 652 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [2.0, -1.0]], 10) 653 | 654 | def test_deeplift_batches(self): 655 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 656 | X, out = simpler_model( self.session) 657 | xi = np.array([[-10, -5], [3, 1]]) 658 | xi = np.repeat(xi, 25, 0) 659 | self.assertEqual(xi.shape[0], 50) 660 | attributions = de.explain('deeplift', out, X, xi, batch_size=32) 661 | self.assertEqual(attributions.shape, xi.shape) 662 | np.testing.assert_almost_equal(attributions, np.repeat([[0.0, 0.0], [2.0, -1.0]], 25, 0), 10) 663 | 664 | def test_deeplift_batches_explainer_api(self): 665 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 666 | X, out = simpler_model( self.session) 667 | xi = np.array([[-10, -5], [3, 1]]) 668 | xi = np.repeat(xi, 25, 0) 669 | self.assertEqual(xi.shape[0], 50) 670 | explaoiner = de.get_explainer('deeplift', out, X) 671 | attributions = explaoiner.run(xi, batch_size=32) 672 | self.assertEqual(attributions.shape, xi.shape) 673 | np.testing.assert_almost_equal(attributions, np.repeat([[0.0, 0.0], [2.0, -1.0]], 25, 0), 10) 674 | 675 | def test_deeplift_baseline(self): 676 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 677 | X, out = simpler_model(self.session) 678 | xi = np.array([[3, 1]]) 679 | attributions = de.explain('deeplift', out, X, xi, baseline=xi[0]) 680 | self.assertEqual(attributions.shape, xi.shape) 681 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0]], 5) 682 | 683 | def test_multiple_inputs(self): 684 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 685 | X1, X2, out = simple_multi_inputs_model(self.session) 686 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 687 | attributions = de.explain('deeplift', out, [X1, X2], xi) 688 | self.assertEqual(len(attributions), len(xi)) 689 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 7) 690 | np.testing.assert_almost_equal(attributions[1], [[6.0, 2.0]], 7) 691 | 692 | def test_multiple_inputs_different_sizes(self): 693 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 694 | X1, X2, out = simple_multi_inputs_model2(self.session) 695 | xi = [np.array([[-10, -5]]), np.array([[3]])] 696 | attributions = de.explain('deeplift', out, [X1, X2], xi) 697 | self.assertEqual(len(attributions), len(xi)) 698 | np.testing.assert_almost_equal(attributions[0], [[0.0, 0.0]], 10) 699 | np.testing.assert_almost_equal(attributions[1], [[6]], 10) 700 | 701 | def test_multiple_inputs_different_sizes_batches(self): 702 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 703 | X1, X2, out = simple_multi_inputs_model2(self.session) 704 | xi = [np.array([[-10, -5]]).repeat(50, 0), np.array([[3]]).repeat(50, 0)] 705 | attributions = de.explain('deeplift', out, [X1, X2], xi, batch_size=32) 706 | self.assertEqual(len(attributions), len(xi)) 707 | np.testing.assert_almost_equal(attributions[0], np.repeat([[0.0, 0.0]], 50, 0), 10) 708 | np.testing.assert_almost_equal(attributions[1], np.repeat([[6]], 50, 0), 10) 709 | 710 | def test_multiple_inputs_different_sizes_batches_disable(self): 711 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 712 | X1, X2, out = simple_multi_inputs_model2(self.session) 713 | xi = [np.array([[-10, -5]]).repeat(50, 0), np.array([[3]]).repeat(50, 0)] 714 | attributions = de.explain('deeplift', out, [X1, X2], xi, batch_size=None) 715 | self.assertEqual(len(attributions), len(xi)) 716 | np.testing.assert_almost_equal(attributions[0], np.repeat([[0.0, 0.0]], 50, 0), 10) 717 | np.testing.assert_almost_equal(attributions[1], np.repeat([[6]], 50, 0), 10) 718 | 719 | def test_deeplift_targeting_equivalence(self): 720 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 721 | X, out = simple_model(tf.nn.relu, self.session) 722 | xi = np.array([[5, 3]]) 723 | self.assertEqual(out.shape[1], 2) 724 | a1 = de.explain('deeplift', out * np.array([[1, 0]]), X, xi) 725 | b1 = de.explain('deeplift', out * np.array([[0, 1]]), X, xi) 726 | a2 = de.explain('deeplift', out, X, xi, ys=np.array([[1, 0]])) 727 | b2 = de.explain('deeplift', out, X, xi, ys=np.array([[0, 1]])) 728 | np.testing.assert_almost_equal(a1, a2, 1) 729 | np.testing.assert_almost_equal(b1, b2, 1) 730 | 731 | 732 | class TestOcclusionMethod(TestCase): 733 | 734 | def setUp(self): 735 | self.session = tf.Session() 736 | 737 | def tearDown(self): 738 | self.session.close() 739 | tf.reset_default_graph() 740 | 741 | def test_occlusion(self): 742 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 743 | X, out = simpler_model( self.session) 744 | xi = np.array([[-10, -5], [3, 1]]) 745 | attributions = de.explain('occlusion', out, X, xi) 746 | self.assertEqual(attributions.shape, xi.shape) 747 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [1.0, -1.0]], 10) 748 | 749 | def test_occlusion_explainer_api(self): 750 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 751 | X, out = simpler_model( self.session) 752 | xi = np.array([[-10, -5], [3, 1]]) 753 | explainer = de.get_explainer('occlusion', out, X) 754 | attributions = explainer.run(xi) 755 | self.assertEqual(attributions.shape, xi.shape) 756 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [1.0, -1.0]], 10) 757 | 758 | def test_occlusion_batches(self): 759 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 760 | X, out = simpler_model( self.session) 761 | xi = np.array([[-10, -5], [3, 1]]).repeat(10, 0) 762 | attributions = de.explain('occlusion', out, X, xi, batch_size=5) 763 | self.assertEqual(attributions.shape, xi.shape) 764 | np.testing.assert_almost_equal(attributions, np.repeat([[0.0, 0.0], [1.0, -1.0]], 10, 0), 10) 765 | 766 | def test_window_shape(self): 767 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 768 | X, out = simpler_model(self.session) 769 | xi = np.array([[-10, -5], [3, 1]]) 770 | attributions = de.explain('occlusion', out, X, xi, window_shape=(2,)) 771 | self.assertEqual(attributions.shape, xi.shape) 772 | np.testing.assert_almost_equal(attributions, [[0.0, 0.0], [0.5, 0.5]], 10) 773 | 774 | def test_nan_warning(self): 775 | with warnings.catch_warnings(record=True) as w: 776 | warnings.simplefilter("always") 777 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 778 | X, out = simpler_model(self.session) 779 | xi = np.array([[-10, -5], [3, 1]]) 780 | attributions = de.explain('occlusion', out, X, xi, step=2) 781 | self.assertEqual(attributions.shape, xi.shape) 782 | np.testing.assert_almost_equal(attributions, [[0.0, np.nan], [1.0, np.nan]], 10) 783 | assert any(["nans" in str(wi.message) for wi in w]) 784 | 785 | def test_multiple_inputs_error(self): 786 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 787 | X1, X2, out = simple_multi_inputs_model(self.session) 788 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 789 | with self.assertRaises(RuntimeError) as cm: 790 | de.explain('occlusion', out, [X1, X2], xi) 791 | self.assertIn('not yet supported', str(cm.exception)) 792 | 793 | def test_occlusion_targeting_equivalence(self): 794 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 795 | X, out = simple_model(tf.nn.relu, self.session) 796 | xi = np.array([[5, 3]]) 797 | self.assertEqual(out.shape[1], 2) 798 | a1 = de.explain('occlusion', out * np.array([[1, 0]]), X, xi) 799 | b1 = de.explain('occlusion', out * np.array([[0, 1]]), X, xi) 800 | a2 = de.explain('occlusion', out, X, xi, ys=np.array([[1, 0]])) 801 | b2 = de.explain('occlusion', out, X, xi, ys=np.array([[0, 1]])) 802 | np.testing.assert_almost_equal(a1, a2, 10) 803 | np.testing.assert_almost_equal(b1, b2, 10) 804 | 805 | 806 | class TestShapleySamplingMethod(TestCase): 807 | 808 | def setUp(self): 809 | self.session = tf.Session() 810 | 811 | def tearDown(self): 812 | self.session.close() 813 | tf.reset_default_graph() 814 | 815 | def test_shapley_sampling(self): 816 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 817 | X, out = min_model(self.session) 818 | xi = np.array([[2, -2], [4, 2]]) 819 | attributions = de.explain('shapley_sampling', out, X, xi, samples=300) 820 | self.assertEqual(attributions.shape, xi.shape) 821 | np.testing.assert_almost_equal(attributions, [[0.0, -2.0], [1.0, 1.0]], 1) 822 | 823 | def test_shapley_sampling_batches(self): 824 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 825 | X, out = min_model(self.session) 826 | xi = np.array([[2, -2], [4, 2]]).repeat(20, 0) 827 | attributions = de.explain('shapley_sampling', out, X, xi, samples=300, batch_size=5) 828 | self.assertEqual(attributions.shape, xi.shape) 829 | np.testing.assert_almost_equal(attributions, np.repeat([[0.0, -2.0], [1.0, 1.0]], 20, 0), 1) 830 | 831 | def test_shapley_sampling_batches_explainer_api(self): 832 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 833 | X, out = min_model(self.session) 834 | xi = np.array([[2, -2], [4, 2]]).repeat(20, 0) 835 | explainer = de.get_explainer('shapley_sampling', out, X, samples=300) 836 | attributions = explainer.run(xi, batch_size=5) 837 | self.assertEqual(attributions.shape, xi.shape) 838 | np.testing.assert_almost_equal(attributions, np.repeat([[0.0, -2.0], [1.0, 1.0]], 20, 0), 1) 839 | 840 | def test_shapley_sampling_dims(self): 841 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 842 | X, out = min_model_2d(self.session) 843 | xi = np.array([[[-1, 4], [-2, 1]]]) 844 | attributions = de.explain('shapley_sampling', out, X, xi, samples=300, sampling_dims=[1]) 845 | self.assertEqual(attributions.shape, (1, 2)) 846 | np.testing.assert_almost_equal(attributions, [[-.5, -1.5]], 1) 847 | 848 | def test_multiple_inputs_error(self): 849 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 850 | X1, X2, out = simple_multi_inputs_model(self.session) 851 | xi = [np.array([[-10, -5]]), np.array([[3, 1]])] 852 | with self.assertRaises(RuntimeError) as cm: 853 | de.explain('shapley_sampling', out, [X1, X2], xi) 854 | self.assertIn('not yet supported', str(cm.exception)) 855 | 856 | def test_shapley_targeting_equivalence(self): 857 | with DeepExplain(graph=tf.get_default_graph(), session=self.session) as de: 858 | X, out = simple_model(tf.identity, self.session) 859 | xi = np.array([[5, 3]]) 860 | self.assertEqual(out.shape[1], 2) 861 | np.random.seed(10) 862 | a1 = de.explain('shapley_sampling', out * np.array([[1, 0]]), X, xi, samples=10) 863 | np.random.seed(10) 864 | b1 = de.explain('shapley_sampling', out * np.array([[0, 1]]), X, xi, samples=10) 865 | np.random.seed(10) 866 | a2 = de.explain('shapley_sampling', out, X, xi, ys=np.array([[1, 0]]), samples=10) 867 | np.random.seed(10) 868 | b2 = de.explain('shapley_sampling', out, X, xi, ys=np.array([[0, 1]]), samples=10) 869 | np.testing.assert_almost_equal(a1, a2, 3) 870 | np.testing.assert_almost_equal(b1, b2, 3) -------------------------------------------------------------------------------- /docs/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/docs/comparison.png -------------------------------------------------------------------------------- /examples/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/examples/data/.gitkeep -------------------------------------------------------------------------------- /examples/data/images/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/examples/data/images/.gitkeep -------------------------------------------------------------------------------- /examples/data/images/0c7ac4a8c9dfa802.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/examples/data/images/0c7ac4a8c9dfa802.png -------------------------------------------------------------------------------- /examples/data/images/1c2e9fe8b0b2fdf2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/examples/data/images/1c2e9fe8b0b2fdf2.png -------------------------------------------------------------------------------- /examples/data/images/4fc263d35a3ad3ee.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/examples/data/images/4fc263d35a3ad3ee.png -------------------------------------------------------------------------------- /examples/data/images/5b3a8c63e41802e7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/examples/data/images/5b3a8c63e41802e7.png -------------------------------------------------------------------------------- /examples/data/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marcoancona/DeepExplain/87fb43a13ac2a3b285a030b87df899cc40100c94/examples/data/models/.gitkeep -------------------------------------------------------------------------------- /examples/mnist_tensorflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "## DeepExplain - Tensorflow example\n", 10 | "### MNIST with a 2-layers MLP" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "collapsed": false 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "from __future__ import absolute_import\n", 22 | "from __future__ import division\n", 23 | "from __future__ import print_function\n", 24 | "\n", 25 | "import tempfile, sys, os\n", 26 | "sys.path.insert(0, os.path.abspath('..'))\n", 27 | "\n", 28 | "from tensorflow.examples.tutorials.mnist import input_data\n", 29 | "import tensorflow as tf\n", 30 | "\n", 31 | "# Download and import MNIST data\n", 32 | "tmp_dir = tempfile.gettempdir()\n", 33 | "mnist = input_data.read_data_sets(tmp_dir, one_hot=True)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": { 40 | "collapsed": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "# Parameters\n", 45 | "learning_rate = 0.005\n", 46 | "num_steps = 2000\n", 47 | "batch_size = 128\n", 48 | "\n", 49 | "# Network Parameters\n", 50 | "n_hidden_1 = 256 # 1st layer number of neurons\n", 51 | "n_hidden_2 = 256 # 2nd layer number of neurons\n", 52 | "num_input = 784 # MNIST data input (img shape: 28*28)\n", 53 | "num_classes = 10 # MNIST total classes (0-9 digits)\n", 54 | "\n", 55 | "# tf Graph input\n", 56 | "X = tf.placeholder(\"float\", [None, num_input])\n", 57 | "Y = tf.placeholder(\"float\", [None, num_classes])\n", 58 | "\n", 59 | "# Store layers weight & bias\n", 60 | "weights = {\n", 61 | " 'h1': tf.Variable(tf.random_normal([num_input, n_hidden_1], mean=0.0, stddev=0.05)),\n", 62 | " 'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2], mean=0.0, stddev=0.05)),\n", 63 | " 'out': tf.Variable(tf.random_normal([n_hidden_2, num_classes], mean=0.0, stddev=0.05))\n", 64 | "}\n", 65 | "biases = {\n", 66 | " 'b1': tf.Variable(tf.zeros([n_hidden_1])),\n", 67 | " 'b2': tf.Variable(tf.zeros([n_hidden_2])),\n", 68 | " 'out': tf.Variable(tf.zeros([num_classes]))\n", 69 | "}" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": { 76 | "collapsed": false 77 | }, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "Step 1, Minibatch Loss= 2.2395, Training Accuracy= 0.328\n", 84 | "Step 100, Minibatch Loss= 0.3232, Training Accuracy= 0.922\n", 85 | "Step 200, Minibatch Loss= 0.2276, Training Accuracy= 0.938\n", 86 | "Step 300, Minibatch Loss= 0.2242, Training Accuracy= 0.922\n", 87 | "Step 400, Minibatch Loss= 0.1208, Training Accuracy= 0.977\n", 88 | "Step 500, Minibatch Loss= 0.0622, Training Accuracy= 0.984\n", 89 | "Step 600, Minibatch Loss= 0.1084, Training Accuracy= 0.953\n", 90 | "Step 700, Minibatch Loss= 0.3195, Training Accuracy= 0.883\n", 91 | "Step 800, Minibatch Loss= 0.1530, Training Accuracy= 0.961\n", 92 | "Step 900, Minibatch Loss= 0.1073, Training Accuracy= 0.945\n", 93 | "Step 1000, Minibatch Loss= 0.0360, Training Accuracy= 1.000\n", 94 | "Step 1100, Minibatch Loss= 0.0540, Training Accuracy= 0.984\n", 95 | "Step 1200, Minibatch Loss= 0.1075, Training Accuracy= 0.969\n", 96 | "Step 1300, Minibatch Loss= 0.0771, Training Accuracy= 0.961\n", 97 | "Step 1400, Minibatch Loss= 0.1078, Training Accuracy= 0.977\n", 98 | "Step 1500, Minibatch Loss= 0.0581, Training Accuracy= 0.969\n", 99 | "Step 1600, Minibatch Loss= 0.1255, Training Accuracy= 0.969\n", 100 | "Step 1700, Minibatch Loss= 0.0392, Training Accuracy= 0.992\n", 101 | "Step 1800, Minibatch Loss= 0.0728, Training Accuracy= 0.984\n", 102 | "Step 1900, Minibatch Loss= 0.0733, Training Accuracy= 0.977\n", 103 | "Step 2000, Minibatch Loss= 0.0292, Training Accuracy= 0.977\n", 104 | "Done\n", 105 | "Test accuracy: 0.9617\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "# Create and train model\n", 111 | "def model(x, act=tf.nn.relu): # < different activation functions lead to different explanations\n", 112 | " layer_1 = act(tf.add(tf.matmul(x, weights['h1']), biases['b1']))\n", 113 | " layer_2 = act(tf.add(tf.matmul(layer_1, weights['h2']), biases['b2']))\n", 114 | " out_layer = tf.matmul(layer_2, weights['out']) + biases['out']\n", 115 | " return out_layer\n", 116 | "\n", 117 | "# Construct model\n", 118 | "logits = model(X)\n", 119 | "\n", 120 | "# Define loss and optimizer\n", 121 | "loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(\n", 122 | " logits=logits, labels=Y))\n", 123 | "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n", 124 | "train_op = optimizer.minimize(loss_op)\n", 125 | "\n", 126 | "# Evaluate model (with test logits, for dropout to be disabled)\n", 127 | "correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))\n", 128 | "accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n", 129 | "\n", 130 | "# Initialize the variables (i.e. assign their default value)\n", 131 | "init = tf.global_variables_initializer()\n", 132 | "\n", 133 | "# Train\n", 134 | "def input_transform (x): \n", 135 | " return (x - 0.5) * 2\n", 136 | "\n", 137 | "sess = tf.Session()\n", 138 | "\n", 139 | "# Run the initializer\n", 140 | "sess.run(init)\n", 141 | "\n", 142 | "for step in range(1, num_steps+1):\n", 143 | " batch_x, batch_y = mnist.train.next_batch(batch_size)\n", 144 | " batch_x = input_transform(batch_x)\n", 145 | " # Run optimization op (backprop)\n", 146 | " sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})\n", 147 | " if step % 100 == 0 or step == 1:\n", 148 | " # Calculate batch loss and accuracy\n", 149 | " loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,\n", 150 | " Y: batch_y})\n", 151 | " print(\"Step \" + str(step) + \", Minibatch Loss= \" + \\\n", 152 | " \"{:.4f}\".format(loss) + \", Training Accuracy= \" + \\\n", 153 | " \"{:.3f}\".format(acc))\n", 154 | "\n", 155 | "print(\"Done\")\n", 156 | "\n", 157 | "# Calculate accuracy for MNIST test images\n", 158 | "test_x = input_transform(mnist.test.images)\n", 159 | "test_y = mnist.test.labels\n", 160 | "\n", 161 | "print(\"Test accuracy:\", \\\n", 162 | " sess.run(accuracy, feed_dict={X: test_x, Y: test_y}))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "### Use DeepExplain to find attributions for each input pixel" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 6, 175 | "metadata": { 176 | "collapsed": false 177 | }, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "DeepExplain: running \"saliency\" explanation method (1)\n", 184 | "DeepExplain: running \"grad*input\" explanation method (2)\n", 185 | "DeepExplain: running \"intgrad\" explanation method (3)\n", 186 | "DeepExplain: running \"elrp\" explanation method (4)\n", 187 | "DeepExplain: running \"deeplift\" explanation method (5)\n", 188 | "DeepExplain: running \"occlusion\" explanation method (6)\n", 189 | "Input shape: (784,); window_shape (1,); step 1\n", 190 | "DeepExplain: running \"occlusion\" explanation method (6)\n", 191 | "Input shape: (784,); window_shape (3,); step 1\n", 192 | "Done\n" 193 | ] 194 | }, 195 | { 196 | "data": { 197 | "image/png": "iVBORw0KGgoAAAANSUhEUgAABWAAAAC9CAYAAAA9b9SLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXeYZGWV/7+nqjrHyRGGjCIGEJgcGJIgQYKKIMiuqJgD\niq66v8U1rboGxLiuK4giIjlKGoaZ6QkgSpIcZpgZJvd0T+fuqnp/f7y3uqrrnOquDtVV1f39PE8/\nU3Pqvfe+995zzz313veerzjnQAghhBBCCCGEEEIIIWTkCeW7A4QQQgghhBBCCCGEEDJW4QAsIYQQ\nQgghhBBCCCGE5AgOwBJCCCGEEEIIIYQQQkiO4AAsIYQQQgghhBBCCCGE5AgOwBJCCCGEEEIIIYQQ\nQkiO4AAsIYQQQgghhBBCCCGE5AgOwI4CIvJVEfnfkW6bxbqciBwyEusiJBMi0iAiR+W7HxYicqWI\n/CHLtjeLyKm57hMZPCJyoYjcn/J/xjaSF0Rko4icGHwesfs1IYWMiBwQxN1InvvB66+I4fkjhBAy\n3uEA7BAQkUtE5GkRaReR7SLySxGpz9TeOfcd59yl2ax7MG1J4REklx0i0iIiTSKyVkQuE5GcX2si\nskxEtmT47hoR+VbwOfFDqjXl78kgGU78v1NEYin//2eG9Z4BoMU594/g/1eKSE+wTGL/5+dsp0eW\n7wH4Vr47MRZIuQ5SfexnQ12fc+6PzrmTR7KPFkFsX5Phu5XBddEqIrtF5BYRmZHyfTH7/phARM4X\nkQ0i0iYiO4PPnxARycX2Rup+PZjBLRG5ZoDvV4pIznOI/q6V8ULqYFIWbUflvPSz/awfRg5jG2P+\n+iOAiCwK7m/NItIYPIQ/drDr4e8dUqyMRDwVkf2DfDE8Uv0K1pv6G++jI7nuDNtbEeTG4zofKFTo\nq322VzC+ygHYQSIil8MP1HwJQB2AeQDmAHhAREqN9kzoxh9nOOdq4P3ivwB8GcBv89slk3rnXHXw\n9/YgGa52zlUDuAzAupTv35JhHZcBuC7N9udgHZMBPAzgL7nbhZHDOfcogFoROSbffRkjnJHiP9XO\nuU/lu0MjwKcC3z4EQDWA/077PuH7UwCsAXBLrgYfSF+Ce/NVAH4AYDqAafDxaSEAdW8OlhnRZDJX\nBMnvj0WkMvj/W0Uk3ffIOKUQ8syxfP2RJCJSC+AuAFcDmAhgFoBvAOjKZ78IyYQMctLUaOGcez3I\njWM52kS9c+5/AEBESkXkpuChoRORZdmuRESOFJH7gokHLv1759xy+FhPhgl9FRCRI0TkbyKyN/h7\nUESOyGYlwUPgF4KHgztF5NrgnpXYj4LxVQ7ADoLgJH4DwKedc391zvU45zYCeB+AAwB8MHjScJOI\n/EFE9gG4JP3pg4hcLCKbRGSPiPy79H0lp7dtypOBD4nI60Hw+1rKeo4TkXXiZ1ttE5GfWYPAJD84\n55qdc3cAeD+AD4nIkQAgImUi8t/BOd0hIr8SkYrEciJyuog8IclZdG9L+W6jiPybiDwbBKbfiUj5\n6O+dv6EDWA7gEet751wUwB8BzBKRKSnL9bd/XxaRreJnEL8gIicE9rD4GbqvBN89LiL7Bd9dJSKb\nRWRfYF/cT5/nBdtsEj/rd1lak5UA3j20I0IGIkguGoJY1SwizyfOccr3rwbn+DURuTDFnmlmap2I\n/F5EdgVx9esSzDhPLBdcb3uDdQ67zIRzrgnAbQDekeH7HgDXwg9ETBru9kj/iEgdgP8E8Ann3E3O\nuRbn+Ydz7kLnXFfQ7pogob1HRNoAHC8i7xaRfwTxY7OIXJm27otS7tdfS/su/d6eMb6InwH5zcD/\nW0TkfhGZHHy9Kvi3SfxMgD4zp51zrwO4Gf5h1/EAPg7/IHig47JMRLaIyOXik9FtIvIvKd9fE9x/\nHgj69IiIzAm+U7MCg324VETeDOBXAOYH/W0aqC9jnf5ijYh8G8BiAD+TlLcBRORNwbFvDO5370tZ\n3yQRuTPwy8dE5FupMTA4N58UkZcAvBTYzHuhiLwLwFcBvD/Y/pOBvU5Efhv4xdZgG+Hgu3CwL7tF\n5FX0c18ci9efiBwSXA/NwTH4c78OMH44DACcc39yzsWccx3Oufudc08BgIgcLH6W0Z7guP1RMgwg\njOD5S52V2xT40SUicqz4HDuc0u6chP8b/blGRH4hIvcGftAgItNF5CfBNf28pJTbEpGvSDInfVZE\nzk75bki5DhlZZJCTpsY4awB8EMD2QS7XA+BGAB8e8R6RXuirvbwB4Dz4B3yTAdwB4IYsl20AsNA5\nVwfgIAARFOibrRyAHRwLAJQDuCXV6JxrBXAPgJMC01kAbgJQDz8A1Yv4UfxfALgQwAz4i2zWANtd\nBOBwACcA+H/if/wAQAzA5+EddH7w/SeGsF8khwQzK7fA/wAD/KzYw+AHbw6BP///DwCC5O7/AHwM\nfuDm1wDuEJGylFVeCOAUAAcH6/l67vfC5FAAcedcprIHpQAuBrAHwN7AlnH/RORwAJ8CcGwwg/gU\nABuD1X0BwAcAnAagFsC/AmgPvnsM/lhOBHA9gL+IMSgtIrMA3A0fjCcC+CKAmyVlcBjAcwDePtgD\nQQbFXACvwMet/4CfJTpRRKoA/BTAqcH5XwDgiSzWdzV8HD0IwFJ4n/uXlO/nAngh2N73AfxWZHiz\nUkVkEoBzALyc4fsyAJcA2Oyc2z2cbZGsmA+gDMDtWbS9AMC3AdTA/yBpg/eZevhBpo+LyHuA3vv1\nLwFcBGAmfMyaba00y/hyAbxvToWfFfjFwL4k+DfxVsK6AfYhBkDNRMnAdCTzjA8D+LmITEj5/kIA\n34S/Pp5AWs5i4Zx7Dn3fksj7DI0CwYw1zrmvAViNYAa9c+5TQbx7AP6eNRXA+QB+IcmZHj+H983p\nAD4U/KXznmCbiWXMe6Fz7q8AvoNghr5zLnGPuwZAFD4POQrAyQASr4R/BMDpgf0Y+B9EmRiL1983\nAdwPYEKwzauz2LfxwIsAYuJnFp2aFksAQAB8F/58vRnAfgCuHGilwzl/4h8a3Qt/jqbAXwNPOOce\ng88/U8sXXQTg9/105X3wOfVk+Fm96wD8Pfj/TQB+lNL2Ffi8vg5+cs4fJKUsEUY+1yGDQLKYNBW0\n62+Cx1sk+ZBsh4h81djOMkkr/yZ9J1YdJ35G375gHT8K7H0ecorITBG5I9jWyyLykZT1XSkiN4qf\nbNAiIv+UQbyt55zrds79xDm3Bj5/SO1rqfhJMZ9OOR4NIvL/gmVfcM79FoBZio4MH/pqEudck3Nu\no3POwd9PYvA5Sja+mv6bq3fZQoMDsINjMoDdwcy+dLYF3wP+R8ltzrm4c64jrd15AO50zq1xznXD\nD7wN9EPqG8FT5icBPIlggMg597hzbr1zLhpcqL+GH4AghccbACYGAz8fBfB551yjc64F/ofR+UG7\njwL4tXNuQzC74Fr4JHBeyrp+FgSZRvgfMh8YYp92i58t0CQiXxy4uaIeQIthf5/4GVEd8D/izku5\nZvrbvxj8j7gjRKQkCMCvBMtdCuDrQSLgnHNPOuf2AIBz7g/OuT3BdfDDYB2HG/36IIB7nHP3BNfm\nAwD+Bj+om6Al2C8yfG5L8a+mlBv0TgA/CRKMP8MPWCRmV8UBHCkiFc65bc65fhM+8TNbzgfwb8Gs\nq40Afgj/IyvBJufcb4JXZ66Ff/A1bYj79FMRaQawGz7efzrt+4TvbwbwTgBng4wG6t4sydlQHSKy\nJKXt7c65hiAGdDrnVjrnng7+/xSAPyF5Hz0PwF3OuVXBLL5/h/dRi2ziy++ccy8GecGNyDCDOh0R\n2R/AufB+/TCA38CXtsmGHgD/GVxv9wBoRd/4eHfK/n0Nflbrflmum/RlMLHmdAAbnXO/C+5d/4Cf\n5fzeIK6dC+A/nHPtzrlng/Wl890gj+gABnUvhIhMg/fNzznn2pxzOwH8GMlc5H3wcTqRa3y3n/0e\ni9dfD/zso5lBP/NeM64QcM7tg58U4uDj0K7gh/i04PuXnXMPOOe6nHO74Acss/ldMpzzdwGAB52f\nldsTXAOJAc1rkRy8mAj/YP/6fvpxa/DbqhPArQA6nXO/D67pP8M/kEgci784594I+vtn+Jnox6Ws\na8RyHTIksp00ZU7wEJEaAA8C+Cv8A4VDADw0hH5cBeAq51wt/MSZGzO0uwF+ss5M+Nj3HRFZnvL9\nmUGbevhZgUPWVUglGIv4IID/FD/B6ysAwvC/L8noQF9NI/gt1Qn/YO07QHa+Kv5tiGb43/PnAvjJ\nYLc9GnAAdnDsBjBZ7HpbM4LvAf/jOxMzU793zrXDP6Htj9TXBdrhaw9CRA4TkbvE1wnZB++gk60V\nkLwzC0Aj/NP5SgCPJwam4ANm4in/HACXpw5cwc8gmJmyrlT/2pT23WCY7JyrD/6GUk9wL/wslnRu\ndH5G1DQAz8APRCXIuH/OuZcBfA5+tsROEblBRBL7th/8TAKFiHxRRJ4T/5pXE/xsBOs6mAP/4zZ1\n24vgr90ENQDG/eu0I8R7Uvyr3jn3m8C+1TmX+tBpE/z5b4Mv13EZgG0icreIvGmAbUwGUBKsI3V9\nqW8V9MbPIN4CQLWILJYBROYMPuP8qy1vQ3JmVCo3Bvs61Tm33Dn3eJbrJcNjD9Luzc65BUEc2oO+\nuU6f+7OIzBWRh8WXsGiG979E/Ei/X7ch8/06m/hi3ssHwvn6W59P+K9z7inn3JeyWRbAHtf3oXH6\ndlP3rxX+PjXUe8p4x4w1GdrOATA3zV8uhJ/xOgX+1blUX7XyynRfzvZemNh+CXysTWz/1/CzC4E0\n30ffGJvOWLz+roCfffNoMIPnX/tpO65wzj3nnLvEOTcbwJHw5+kngB/YD3K3rcHvkj8gu98lwzl/\nGfPDYPtniJ91+j4Aq51z2/rpx46Uzx3G/3t9Rnw5uSdS+nsk+u7rSOY6ZPBkO2kq0wSP0wFsd879\nMHgI0+Kc2zCEfvQAOEREJjvnWp1z69MbBA89FwL4crCtJwD8L/zbAQnWBA8oYvDliEbsbT3n3DPw\ns89vg59ZfpHLXb1PoqGvphHkD3Xwb8b+I8Xer686P8GxDv732Q+QfJO2oOAA7OBYBz9b75xUo4hU\nAzgVyacN/c1o3YaUH+3ia38OtUbgLwE8D+DQ4GnFV+ETRlJAiFeHnQX/ut1u+CTuLSkDU3XOC/cA\n/sfGt9MGriqdc39KWWXq7KT94WfX5oOXAYj4V8cUzr8G8FEAV0rytax+9885d71zbhF8Mu6QrHO4\nGf5pXB/E17i7Aj6xnhAE7GbY18FmANelbbvKOfdfKW3eDD/LnOSOWcFM8AS9Puycu885dxL8j67n\n4WfY9MduJGcqpa5v60CdcM6tdgOLzGVa9mn4BODnaftC8kPi3nxWFm3T78/Xwz+h3y9I2n6FZPzY\nhpR4K14EK9P9Opv4km2fMjd07pJs22ZJ6v5Vw7/++wb8q+GAf2CYYHpqV0a4H2Od9OO1GcAjaf5S\n7Zz7OIBd8KUBUh/wWLOSe9eZxb3Q2n4X+j6IrU2JhX18Hz6uZmLMXX/Oue3OuY8452bCl0z6hYgU\n5KuM+cQ59zx8KYsjA9N34I/nW4PfJR9Edr9LhnP+zPww6N9WeP88B/4NgnTR2CEhvuzBb+AHByYF\n19sz6LuvI5nrkMGT7aSpTAP4/Q3sD4YPw5eLe158Pe/TjTYzASTeikyQcTIB/AOI8gz7NlSuhc+l\n73HOvTSC6yUDQ181CB5W/QrA70VkaspXA/pqEPv/iuzrx44qHIAdBM65ZvgaHVeLyLtEpEREDoCf\nor0F2d3Yb4J/GrtAfI3MKzH0QdMaAPsAtAZPTz8+xPWQHCAitUHwugHAH1zwmh18ovXjRDARkVki\nckqw2G8AXBbMChERqRIvUpE60/STIjJb/OtUX4N/LSp1u+VpfzkZIHL+VYAH0c/rZc65FwDcB//D\nEOhn/0TkcBFZLr5+Zif8QHXidcP/BfBNETk0WO5t4utw1sD/UN0FICK+DkwtbBIzIU4RXzemXHw9\nnNQfuUvha4mR3DEVwGeC+Ple+EHve4KZM2cFM1W64F+VzvS6KQAgeOp5I4BvBz40B/4VnT/0t9wA\nSPo1lKHdtfCzvM8cxrbICOC8KNo34AdJzgt8ISQi7wBQNcDiNfDJZKeIHAf/OmuCmwCcLv6VplJ4\noaFMeVM28SUTu+B9/aAs2o40p6Xs3zcBrHf+tfNd8A8yPhjsz7+i7yDHDgCzZXyJQwyHHeh7fu8C\ncJh4kamS4O9YEXlzENdugX94WRnkdxdbK01hoHvhDgAHSCBQ6PwswPsB/DDIVULiBZQS9/Mb4eP0\nbPF1Pr+SacNj8foTkfemLLsXflCx3/vReEC8cNzliWMjfjbUBwAkZkrVwN+7m8U/nM92pv5wzt8f\nAZwoIu8TkYh4AbvU8hK/h89B34q0V3yHQRW8T+wCAPHihkemtRmxXIcMiWwnTWUawN+M7O7JbUh5\nUCm+hExv7WLn3EvOuQ/A+8P3ANwUnPtUEmXqUn/rZTWZYAT5Bfx96RQRWTSK2yX01f4IBX1OHeDN\n1lcjyPBwLt9wAHaQOOe+Dz/T9L/hBz83wDv+CS5Qeh1g+X/C1w28Af7pfit8naABlzX4Inyy2gI/\nsEWV1sLgThFpgfeLr8HXwEoVBfoy/OzR9eJf0XoQQZ0259zf4Oum/gw+6X8ZXswnlevhfzi9Cv/E\nK1Xhbxb8wGXqXy6Dz6/Rt96mxQ8AfFREpg6wf2XwAmW74Z+eTQXwb8F3P4L/QXg//HX3WwAV8IO7\nf4UXhtgEP3BrlgBxzm2Gn6HzVfikeTP8j4MQ0DtTudV50TQyfO6U5Cv+rSJya2DfAC/gthu+bs95\nwSs0IfjB0zfgX4NeiuweKn0aPql4FX6W+fXwQm9DZQHSriExntwGDyCugq9LSPJMcG/+AvwP7R3B\n36/h4+3afhb9BHw9qRb4muy9Na+C+/Un4X1qG3zMMkUHB4ovA/S9Hf5aaBD/Ouu8gZYZQa6HF4hp\nhC8X88GU7z4Cvw97ALwFfY/jCnhRju0iQqG5gbkKwHni1dR/GsweORm+5uob8Pe878HfBwE/s64u\nsF8HXxu1vzxxoHvhX4J/94jI34PPF8OLGT0L79s3IfnK92+CdT4JL0LU78DVGLz+jgWwQURa4Wfo\nftY59+pA6xoHtMCLS20QkTb4gddnAFwefP8NAEfDz76+G1kOeA7z/L0OXxPxcvg49gT6vvJ6K/xs\nqVtdsjTIsHC+LvMP4QdOdsAP7jakNRvpXIcMgkFMmso0weMuADNE5HPihYJrRGSusakX4Wf4vVtE\nSuBF3HqFk0XkgyIyJZiAkyhx1mfAPfD/tQC+Gzx8eBv8bMThTCboQ7APiQkFpZIyQUdELoK//18C\n4DMArg0G/xAck3L4e0Vikk+Z2gAZMvTVJCJykogcFTyIq4X//b8XXiR7IF+9ULxmQuIthW9jaLVw\nc49zjn95/IOvJxQFcGC++8K/wv+Dr2VyYr77kdanBgBH5bsfI7AfNwM4Ld/9GMt/8DfMNfnuB//4\nxz8H+FeHv5XvfvAvq3P1PQDX5rsf/ONfMf7BT1YYtdyZuU7h/MEPDj2DZE3fX8OXaUl8H4YfiHoN\n/gHDYwBmB98dCT+Asxf+YdhXAvuV8G82pp7vbfATqr6Y+lsNfmBqJ/yEq3/C6yMAXt3eAYgE/58N\nP5DWGPjrZSnrT99en2XT9tf8LuiTS/s7AH724h4AC1Pa/hnAb9LWl/q3kf5OX82FrwJ4L3xZllb4\nh3F3A3hb8N1Avvpt+AHrtuDf/4EvEVNwvipBZ8goIiJnwF8kAv8EdS6Aox1PBhkAEdkI4FLn3IP5\n7gshg0VELoH3X77eREieEZFrAGxxzn09330hfRFfdqAUwNPwszHvgY+dt+W1Y4QUGSJyLvwDjMOc\nn9k1Gtu8BMx1SB4IZv69AP8WxJdcUgA3V9t7AMA8AI86507I5bbI2GI8++pIFm8m2XMW/HRyAfA3\nAOdz8JUQQgghhMDX0vwTvODFDviH9bfntUeEFBkishLAEfBK2ayzSsY8zrlNADJpF+RieyeN1rbI\n2GI8+ypnwBJCCCGEEEIIIaRoCGo+Ppvh6yOcrxFMSN6hr5IEHIAlhBBCCCGEEEIIIYSQHDHaJQg4\n2kuGiozmxrY3tSlfjRkPK2pKw8pWYmi2xjN4vmXujOq3pMIhvfs1e19TtujE/fU2QvoyD3VrIdgW\n0W8BlBjbtfoCZN7HdKzj2B3TtrpSfSAl2qls7YYYZ0VIry/UusvsT7xqkrI1dut9nBQyBKjjUWUq\nmzBtVH11R7P2Vet4TijXvhoxemosCsD2VWs7YWOdVY2vKJvpq+FSZQt1tSpbe6TK2G72h9168CjG\n8lHDqc04YFz0lq92Gr5aNhhfrZygbI09+rxaviqxHmUrnTidvpoGfdUz3n21aHOACbP1NixfZQ7Q\nS7HnAN173lA7ax2nWPUUvbCRHyLTm/qGXXr0dhDW6wzv3aJsVlyF4atixFWI9g1nbDcTYuyLM9Yp\n0W5ti8eULV5RZyxrHBsD8/ocRFwNdTTrdZbq+04hxNVd+7LLAerK8pgDDCeu5iAHsLBCsLV/w8kB\nukL6flAq+roZqzmA5avWebN8KNM90sK6b5rrRJbx14rfRmy0/DcGI67qtWXOZwx/y7bui5E2mcfR\nOjZWPLfuzRkxYr95bzTWad0jyuomDegA1v4SQgghhBBCCCGEEEIIGQE4AEsIIYQQQgghhBBCCCE5\nggOwhBBCCCGEEEIIIYQQkiM4AEsIIYQQQgghhBBCCCE5QqyCuTmEIlxkqIxq8e2ulr1a1KBHFwp3\nJVr8IRrWhcsjXfvM7bRFapStMmwIVDijMHaW127EKGJtCbY0d+ki1lPLdbumHvu5TY0hmGHVIZeY\nLljdGtfFrrMV57GK8FsFwsuiWnTEd1L3Ox7R53Bftz4+E6Ja/KB00sz8+2p3h2rnIoavllQqW6Tb\nELyALSRgCZ1Ejed61vmwxAFKDYfpNn1Vi2BM16fMFFIBbIEH01eNguutMb1/lq9OzNJXLVtOfLVn\nr7KVTp5NX02DvuqhrxZ+DmBh5QXMATzMAbSvuhK9n2IItPkvjDzUWN4UgbEEvGI6Zpn9MYRmQm17\nlC1WN0O3a9exBLCF10whFkt0xTg+EtWxIW6JnmUpFmMdGwCmrzrDV839NrZdEHF1ODlAhrjaXqLj\n6nBygGzjqpUDtIxWXDX8qD2u46UVV+vL9DayzgHi2YnNAXZcbe3Rx6eupwDi6t4dhgKqcY6yFQS0\nlgVMsafBCF9ls2zMFGm1tmGIVxrCY5k05CzRzmyx+mgJeGUr1mWJ9mUUmLTIdL7SMfKZ8uo6inAR\nQgghhBBCCCGEEEJIvuAALCGEEEIIIYQQQgghhOQIDsASQgghhBBCCCGEEEJIjuAALCGEEEIIIYQQ\nQgghhOQIinCRYmFUi293N25XvtpdXqfaZSvYUul0gX6/Al0QujNckUUPgTdadcH16dVGMW+jj40d\netk6Q9xid7sWkJlSqdsBQJdR0bvVEK2oN7YTNs5ue49VkFvbppfpbbQYgh61ogtlA4AzCm23uxJl\nqwzpY2FRXlWTd1/tLNO+amHF//JMvmqIQnQaYkcWW1t7lG1GtT7Glq/uatfbtXxoj+HTUyuNQvgA\nOqKGUEK3Pr8Ty/XyVvH5jqj2QWtfZpRn6atxWwTFEijpcPpYVEAfb6ugPH1VQ1/10FeZAwDMAVIp\nphwgXpqdD8lgREksYShLbMYgsvtVZYtOPEA3NPoTbn5D2SyBq/C+bXob9bPN/oS627TNEK6K1UzT\nCxv+Il2GIKSxL6ZQmCVCFbKvMVOkLKr9OlsRoELwVSuuWlgiPbnIAYYTV3cbcdUSDhxuXN1nCHtN\nrLCXT2fEcwBni3C5SKmydRk5QFmB5gCWYJyzBLMMm8VgYq0lutkV08v3WEpaBrYwl7ZZ91eLUIYh\nmeEun07Y0jwz2lWU6IaWIJgpzDUYrHNo+WpFBUW4CCGEEEIIIYQQQgghJF9wAJYQQgghhBBCCCGE\nEEJyBAdgCSGEEEIIIYQQQgghJEdwAJYQQgghhBBCCCGEEEJyBEW4SLEwqsW3Ozo6lK9aYigVRkXn\nXUZx9Zoy+1mHJUDSAS3+YBWT7jQKqRsms4h1RUQbrQL329t0Qfn2HruQ+O52LQQwrbpM2XqMyt+V\nJbowe3uPPo6WGE6Jod5RE9bbkB67UPxeVCpbva4db9IV19uuq64cVV/tNHy1PUtftc6vJZACZO+r\nlqd3G75lideY2zV8NWosbAkgtRpiRYDtqzNqtLiFhdEd7OvU29m/TjtR2LiO6yLaV0NdWiAEABpD\nNcpWX5JdUXirqH9t1cCF4kcS+qqHvppCgfoqcwDPDsNX2wxhLaB4c4Am0TlAnd6MSaHmAJYolCur\nVrZw4yZlswSuAFvYSSxhrpAhUtVjiCVlKUoTL9PiSZbwlCXWJe3N5jpje3cqW2SqLdiVjgsbooyt\nTcoWn36oXtgQ7ImX61iZKa4iro93vMIQsTLiqrVseXVd3n012xzAikV1ZdnnANGQvs9ZMa/HsGWp\ndYQqQwzIWl9jh45tbZl+W7VpX59arffFGs4pNWKjlWvMqtHrs0Q8rRxAum0hzqaQjjfW8tY1Yd1i\nCiFftfpliW5afmXlVAAQtg60gbUdK78cDt2Go1t+mSkHsO7Z1v4ZtwhURPS1HDGOmXWNVZbodlaO\nY4l9AtkPTtqiq9qWja9yBiwhhBBCCCGEEEIIIYTkCA7AEkIIIYQQQgghhBBCSI7gACwhhBBCCCGE\nEEIIIYTkCA7AEkIIIYQQQgghhBBCSI7QlY8JIWYh6n1G4fKOqCH+YIhtNBnCJwBQFerQtuheZXNG\nkf3SqBY1MIv5d2gRgh3R2qz6OL1ah4hMRe/ve2mXstWWa7GCyZXaVmYUyy4xlEOmGwIcNSFLBEIX\nlO8p1QXhAWBCjy4g75CdyM3WVl3ov87eTM6waqG3GsYO4zBZ53Iwvlrds0fZYlWTlK0iqpfN2ldj\n2ld3G2I5PKc/AAAgAElEQVQMs2u1b2Ty1b++qH11crXu46RK7UeWGI7lq/sZwkb1IS2mEA9pX+ss\nM0Q1AEyIaxEZZyxviW28vk/bjqyqMLeTK+irHvpq6koL1FeZAwCwc4Da0sLPAZyRA0Qz5AD1w8gB\n3mg1hIFGOQdATF+roU59zl23FnaKGzEw1KrPIwC4inplE+PYWXHVUiSxxKMsXw236dgdatqmtzvl\nYL0Roy8A0LbqXmUrnfiK3vaEqcomEe2DEjGUW6fp/rhSI45ZwlxWOwBiCZcZy1txtWTni7pd9bHm\ndnKFlQNYIj9WDmDFneauDHFVdFwtNeJqvHKCsmWdA3S2KNuuuL749xiCWzONuFpdas+Je/CV3cpW\nb8TVSUZctWKolRdY4oYTI7rfVlx15fpeAgD10aHnAFtatE8cMco5gCUOZ+UFcUMcyxLcsu+aQKlk\nKWxmXfuG2J4zYlHMOG+txnVn6GOaWPsM2EKcezv072SLCRXaB2fUaBHPihJLpNQQ3IJxXA3xRsC+\nP7mI9lVLcGuoQmicAUsIIYQQQgghhBBCCCE5ggOwhBBCCCGEEEIIIYQQkiM4AEsIIYQQQgghhBBC\nCCE5ggOwhBBCCCGEEEIIIYQQkiM4AEsIIYQQQgghhBBCCCE5QlwGJbMcMaobKwQ6O7UK4MaNG5Xt\nzjvvNJe/4oorlC1kqL197GMfU7YDDzxQ2T7xiU8oW1VVlbntAsMSqcsZna3N2lcNpdGOuO5WVWej\nXl+FrcRqqR+G925WNolpFcHolEOUbWe7Vq18tUn7YNg4mkdN135gqZPWBQKECxctTuuQVmR2HVol\nNFShVUJdxCsdNqxa2Wvb1Z1JM7IvlqJnqbGDksGDqrr3KVtPuVbnNRUVu1qVrWzCtIL01S7DVysG\n46vG47rIno3aaCiadk3WvrrLkLl9tVH7asg4v8fO0L7aEdXnJyE0m52v6nMphq8iUBhN9dWmqKFi\nbNxbSyy/1FtAKIOzVkX19WQq0BpqqZZib+nE6fTV9G3TVwHQV0c6B+gon2hupyxkKChbOYDhB9Gp\nhynbeMwBrH2xcgBLmRoYPzmAxLQCtHS1KZul+A4ACGsl7WzjanTyQXp1TVv1sluf17aQ9oPYm5fp\n9RnXjSvxStbzTz4zrY+GwrsRV2NtWo08UuP9d+0vv9FrkyOXqHaWQrl1DJ0VQ61lATNejsW4mm0O\n0F1p5wARY6/CjZuUTQw/6JqkfdXKATY26Xhn/R45ZrI+ly1O+0HGHMC4Rl2ntkmVjllWDmDFVWso\nKGK4YEWJNmYaRhpODlAIcbW5tV3tWcQ4v9Z9wdon0wZAevT92frNb10n3cZ9ale79tXWbr1ta1/q\nyvT5nVCu/SXS7GP3vFPO7vtFVN9jENP9kRIjDgaxcd2K+3ptrSXaXzqN/NnKQ8uMIFCaIQcY1jk0\nYnp5RcWAvqrPJiGE9EN6ctCwZjUAILL7VdW258lHlK30rQt1uyA5X7hkWa/ttgdXD6ebhGT2VeNH\nY89Tq5St5MgFyhaddIBfd4qv3r1izdA7SQjoq6R4GJSvPvGwsjEHIKNF+sDruvvvAACEjEEtK642\n/u0fyjZ58SIAwIKP/0dyvau1nxMyGDLG1dceVW27nlmvbGXzT1e26MT9/boZV8kIkj7wuv6+WwEA\nskc/BIs371G2yNTZyhadeAAAYP7yU3ptD6xeN5xuFjQsQUAIIYQQQgghhBBCCCE5gjNgCSEmSxYt\nRMx6LSWYWJ94OjuSNKxa2fuk9t3L/Stdd6/QsxIISWXxooWIG74aHiVfPeV4P3Phvoc5s4D0D32V\nFAvMAUixsHDRYsAosZB4lT8x83UkWfvLb/TOgk3MtM3FdsjYYsHS5Yg541V2xlVSYJy6fIlZ4idh\nSsx8HUnWrbivdxbs8qX+bYMVj4y9N7c4A5YQQgghhBBCCCGEEEJyBEW4hkgspgt3X3fddcp25ZVX\nKtuWLVuy3o51fiSTmlAWfP7zn1e2H/zgB0Ne3ygyqsW397V1qANvFba2ClZXh3WR5paY/axjwt6X\nlU06DSGAyVpQrbW0Xtlq9r2ubFvLZynblIq+k98XL1qIUFoh8A13XIf2Vfpp/ku32jVZJh42Xdlm\nf/JyZYvXTFM26dQiGHNPP7/386o1DQCA8t36eO2p0wXz60PGTAirODiAULdR4N4oqh2r0sX+xRCb\nKK+qGVVfbW3XvrqjTffLKqReKbrQe0vcfjEiW1+NTtUiRk0hLboyoWuXsm0WfYynV1m+2tHHtuG2\na9DysPbV529oUDYAmHT4VGU76ItfVrbohP2VTQx/mfuuc3o/P7Lab7N625OqXcv0tylbeY8WKpAM\n92Xp1sfbRcqVzfJVi4osCsWPJPRV+mo6heqrzAF8DtDRcJda9qVb1iobAEw8fIayzbpM55zMAUaW\nzrYWLRazV//OiFVPVjYX0fsZ6u5QNgAItexQtmzjqiV2ZM1Ytbbttr3U5//zLv4CkCZGvOH6n6Fl\nva7D+tz19sypyUdoHzz4in9Ttmzj6vyTkvU31997CwAg3K5Fo3qmHa47Y/jlcONqPMu4mo1YzEhi\n5QB2XM0uB2jNkAPUN2lNCus675l8sLI1SaWyWTnA1pA+xlMrs4urnY/er5Z9+SatmQEApdX6Gj38\na19VNtNXjbh63BkX9H5O5ACVjfp47ak9QNnqnRals64HABBLMM7KASonmMunUwg5QNR6fcmgzJg2\nGunRxw4AQp3NWa3TuidZgolWO4twR9/tzl9+CpDuqzf8Ai1rH1TL7vnna+Y6S2v1tTPp7Trmlb3p\nGL3whJnKlJoDrHvoXgBJwc5UrPuYdBu+agiFArZApXUcXUmFYdM+nU0OwBIEhIxjFi9KimFsuEM/\nQMgniR9cSxYtxJKgn4/edm0+u0TySB9fve2a/HXEIJHELl28EEsX+34+fuMv8tklkkfoq6RYYA5A\nioV5F3+h9/OG63+Wx55oEoOu8049B/NO9Q+5Hrv5f/PZJZJHCjmuWjnAY7cXVh/J6JEqerXhhsLK\nBRODrvNPOBXzTzgVALD2kRX57NKIwRIEhBBCCCGEEEIIIYQQkiM4AEvIOGTxooV9ntCuXmO//loI\nrCrgvpHcU0y+mphZQMYn9FVSLBSTrzIHGN/Mu/gLfWa/rv/9j/LYm/5JzIQl45NiiqvMAcY385ef\n0mf267oV9+WxN/2TmAk7luAALCGEEEIIIYQQQgghhOQI1oAdIrfffruyXXrppSO+nbPPPlvZbrvt\ntiGv78c//rGyFYkI16hSajyaSBdYyUTMeK5RG9OF0AEgOkWLFUSzfC6y1xCv6aycrbcdSa5v+dJF\nAID4Ni/UceclXjhg9zc/jtkf/bRatvzEDyjblFe2mv2Zcelnlc0SqIhX1Ol2Jbqodk9KsfPEp+jz\nG3R/DtIF4KO7tylbePoBygbYRcwh+hyEop26XVwXmR9tItB9yNZXIbrIeF3XXrNpdOphytYV0wXp\nrSOyc58WT2gv1QIGdSkX3olLF/vtbvXF3m859wQAwNbLL8JBn/qkWrbmhHOUra7haaM3wEGXf0kb\nu/X57SrRgkylZdrW1ZPc64R2RvfLT6l2VW06DsT27lS28P6GUAeAeJnhq8aptgrKiyVAU6ELyucS\n+qqHvpqkUH11pHOAuqgttDGcHKCxVfvqSOcAZct1DjD5xc1mf5gD5Amjr9GJWpAn22XhbFEdK65K\nlsck3LhRNyvXftC94e7ez4uuuMrbmr34301nLQMAvPKr/8Mhn9O+Wnvyedq2Ssc2AJh94jxtNOKq\nJe7ijLgaak+5F4kXkLLiamSU4qolcFYIcdXKAdKFqzJi5AC13dnH1WxzgD1GXI2WaQG7OiOuuh1e\n/O7ej58LAGj+0RdQ/sGPqmVLF5+rbNM2vmH0Bph8vh5bsPrdWa7FrEoNgSsrB4i99Lje7qwmvd19\ne5QtNFmLPAJA3Ni2FS/MHCCm7xuFkAOUhLS2kqm2ZAqQ2eJYsVotXmn5aqdhi/ZoW73V77akiFxC\nMDARDzb85X/8FzteRLxTx/6qt75T295+nN4IgHCdzpVdpY7z8eopemHrXtTHX/y+lux6STWLNWu/\ndF1a0DFUYwu+SbUWNLXyFEsAbKhwAJaQcUIiQQCSP7oIKUQSA1pAcjCLkEKEvkqKBeYApFhIDL4C\nyYFXQgqR1LiaGHglpBBJDL4CKQOvJC+wBAEhhBBCCCGEEEIIIYTkCM6AJWSMk/p0dsUjawAALd/5\nRE62Nff08wEAG+66ISfrJ2Ob1NmEDz6yGgDQ+OUP5WRbc8+6GACw4fbf52T9ZGxDXyXFAnMAUiyk\nznxd831f0uL1e9bkdFuJ7RAyGKy42nnV5TnZ1twzLgAAbLjz+pysn4xtUme+rnvgLv9h98acbGvu\ne33JDc6w7R/OgCWEEEIIIYQQQgghhJAcIc7pAr45ZFQ3NlJs375d2Y4++mhl27lTF1LPlmuvvda0\nn3/++cp21VVXKdsVV1wx5G1Ho0ax68LDrHOdKxpb2pWvitGDqnZ9zveW6eLSEzIIcMAQqJAeXaS8\ntVYLa2w2xGLm1JYo2/FHvRkA0PCTr/TanFGMPzJFF1J//oCTlO2wej9xfuGixX3sDsDqNQ0AgMWL\nFgLwJ61hzeo+7ZoMHYCJPbqAtnS3K9uxp/unwIkZZwBQ37xRr9Ao5t1Ua4tS1DhDRMIooG6tc09c\nF+SeOaF6bPqqcUysgvqNFbqg/KbmLmU7qF4fu5OPehOAvjNSovt0f0pnzVG2Le94n7LNrvHXw8KF\nC/vY4wBWrPIzFpYv8bMYQtC+uqtT37KmxbUwgXUcjjvN92flquTsnerdL6h2zvCrtkmHKhsAlEcN\ncRSrcH1I2/bGtADA9Poq+moa9FUPfZU5AJAhB5jgz0+6r2abAzR2aV+dHG1UNuYA2dHVtEsfUKOv\n0tWiFw7bwjAmWcZVV1KpbKHtOp7Ep+nYcdzRxwAAHrriomQXy3Ufy2fruBo5Yr6yxSb462bB0uVp\nnXRY+/AD/rvjvY9LrBsb7vhj3z5aQizGMbOOw7zT3gsAWPvIimQf92zU6zPOVSyTiJolTpRlXJUu\nHZNLJ88u7rgaM3wayNpXmyqmKdsWK67W6XN+0juPAAA0/OI/kpvt0Mc4PGGqsm0+/FRlm1Ud/LZK\ni6sA0LD6Ef/d4qVJW1pc3WPE1SlxLfpmHYdj3+UF7PrE1X2vq3YuFFa2jrr9lA0AyqI6fpsUaA7Q\n2dGhDqg1iGUJTVpY9/VMWIJd3YYInTWGV96jr4mE32y4+be9NlP8zxDRik05SNkScTHVHwGfrz6y\n2ucASxcnc4BEXpDA0DIzj6NlW7DkeADojd8AIFGdt5v38EwYMd0ZwqAuUq5s3cZmaqsqBvRVzoAl\nhBBCCCGEEEIIIYSQHMEasISMUZYFs6VG+pFh+szXxFPYWMrznMTTriWLFva2T39aS0iCRD3NkX4i\nmD6ToKHB+2VHLPnUODG78MQli+irZEDoq6RYyFkOkMFXU1yVOQAZFPNPPC0n602f+do7GzVlJmli\nJtXCJUsx98wLAUDNhCUkQaL2q54LOjwyxdXUtyRSZ8IyrpKBWLD8ZAA5yAHSZr4m/LI7ZVgxMRN2\n2eKFvW/EpM+EHc9wAJaQMUjihxfQ97XD4fKBU5ehPOxDeTY3/dQ2iWQhGvxIu+/hoScNiUGQ1Ndl\nSHGSKmY0kmIYZ5+8FKXBey69iWw/WL7aE7xacveKVUPuR+JaTH29mxQn9FVSLIxKDjBEX03kAPcO\nw1eZA4wdUgdfU0sPDJe5p58PBK/yppYByETqoGtiINaJH2Zbf//tQ+5HYhA4mz6QwiZVeCu19MBw\nec9JS1EaPNUdbly9h3GVIDn4CvQtPTBc5p94GhDyw4eJgdf+SB10TQzEJsjG1zORKB2TWoqgmGAJ\nAkIIIYQQQgghhBBCCMkRFOFKwxLc+spX9OyB6667TtnEqCR+6KG6yPzDDz+sbNOm6ULgmdYZi8WU\nbcuWLcq2YMECZduxY4eyHXPMMcq2fv16sz95ZFSLb2/d26p8dUrPbtUuXqPPW6hDi23ESyvM7USa\n3lA2q8i5K6/VCxuF1OeffFbv53Ur7gMAyGuP6/X1aAGONw49UdnOOe6IPv9f8/3P4u8/vk33BcBL\nz+rjc9pXT1G2U395s7LdsO5ZZTtAtIBMT5Uvwp/6FO3RW3+n2lnn5fWoFoYAkjPHUjmwNruXiyxB\nsdEuFP+G4auTh+OrZVXmdsJNW7XRKBQfL9eiFdLToWzzT3x37+d1D93rPzynZ97FO7Wowc6j3qts\n75n75j7/X/ODL+DJn9q+uuWfu5Tt+K9pUYSTf/pnZfvTWu2rc6r0ra017p8Qp86afPyGq1W72AQt\nfPNGWBfCB4CuqN7O/obojlW4vqlHP2+lr9JXAfqqxZjKATb9Q6+vSy+bixzg1C9rEa/T/udWZbN8\n9cDQyOYAGzPkALEizwG6d72uLzZDgCSfcTXUquPY/Hed3ft5w23XAACiW19V7ZwRVyNvmqtsRx27\nrM//b7/gFLzxmBYSAoBtL2rRtyVXaF9d9k3tW8/8Xc+46pn2JmULte8FAMw/+cxe26N//qVqZ8VV\nMUTyAPv+FJ16mG5oxFXrXI+2CJeZAxgCfPFqLbg1KF9t3qZslrCRK6tRtgFzgCCuhrY9p9rF9ugx\nhN1vPUPZTn9H37GBFV+9BKv+4w7VDgAau/Vv/hM/tUjZzrr+PmW7Ye0/lW1OqRYsagn52JiaA/zt\nNu37sbqZyrapW4sVAYA1vDQnS33CQsgB9rVpEa5SMcTdYsa1aly/kmm8zVjeyosssT6r3byUHGDD\nXTf4dj1aEM11alu8WQtiH/veT/T5/18/+V68fOcTqh0AbH15r7IdtlALiF76tL7fb2hYqftTMUHZ\nXDgQq1uyrNe2dsX9etmyamVrs37wwx6cLDWUwiKGzVq2upIiXIQQQgghhBBCCCGEEJI3WAOWkDFI\n4gntSDGS9Q5X/9enez8v/oqeZZXVOoKaMun1ZEjx0TujcIRY84MvjNi6Vn/nk72fF3/150NaR6KW\nVurMAlKc0FdJsVDQOcAI+CpzgLFDYvbrSHH7BfrNq6Fy7bzkG4IfWv+3Ia1j3f1+ZmPqTFhSnIx0\nXF3x1UtGbF2p+cSiL/1oSOtgDjB2SMx+HSn++kn9NtdQGYlxgIZVKwH0nQlbTHAAlhBicuYJS7F+\nBH909cc5J3tFxVvuH7igdyqr1zRgbpAgb7jz+hHvFykOTj9hCdaP4GBWf5wb+OrNg/TVBx9ZjbkL\n3glgZAvik+KCvkqKhdHMAYbqq8wBCOAFt0Zy4LXfbb3nEgCDHzhed/8dmLv8XX5ZxtVxy2nLl2Dl\nCA689sdQf1s9+MhqzF10HADG1fHM/H/98ogOvPbHvFN8SZr19+kSRf3RsGolFiz2fl5MglwsQUAI\nIYQQQgghhBBCCCE5giJcaaxcuVLZTjjhBGWLx3Uh39JSXeD7d7/TRawvuOCCoXVukPzwhz9UNktQ\nLBrVRZytdt/97ndHpmNDI+/Ft0uMxxWRZi1K0F2rC+pnorRps7LFJuynbNKti2W3h5OiHsuX+ILs\nj957k2q3vWSqsm1s0gIcx030u7zg+KQQwfdjWjRl4iETlQ0A6uZoMZaSGi16Me3kk5Vt3mf+CwDQ\n0NDQa+uO6XBRvVWLiRz3gU/3+X/DmtXY06WXfWG3LqwPAMfM1IW6Lcp6WrXRKIpeVjdpVH21tV37\nasToQXiv9rWuesPXMmzH8lVr+UhcF5Rvd0k/Wr408NX7tJDLroj2rZcaDaGZSdpXfxgqV+0mHmr7\nau3+k5WtrF6LMUxadryyzfvUdwD09dVWo7D7xMYXle24cz7c5/8NDQ1oN8SKXt6r9xkA3jRJC/lY\n9/Dynha9sOWr9VPoq2nQVz30Ve2rYaMHYz0H+G9ogZX6A+uVDcg+B5h6os6p53/uBwCYAwyFzrYW\nvbMh/YKjFVdjdTP0Co19AmwRLnN5QwCsZNfLvZ/nnnkRAGDtQ/eY21Hd2fSUsm26zs/KO/vPSfGV\nq2drMapcxNX5H/8GAGD9/bf32lxEx/RwixY9VnF1zWrz2g5vfcboNRCdc7SyiXG8LRGgQo2rZg5g\nxNUuI65mzAH2Zbe8lQN0ICUHCOLqBiMHaC3XPrS7Q/+eToheDfTbqnKSLRI45a36GquarmPtpOVa\nRHHeJ74JoG9c7TDu43W7tFjXced9rM//Gxoa0NSlfe2VDDnAW6fq/ck6BzDi12jH1Y4O7avWtWaK\ncA0GYzwJoezmR6bGnYULfTme1HOdwLiVItK1T292498BAPMu/Eyv7bbzlqp2mx95weyPGP2eOe9A\nZZt+0nJlW/C57wHoWzohXmWIvBpxbP6Jp/X5f0NDg5nrtvfYw5AVJdq1KiJ6OxGnr2+rP+WVAwvG\ncQYsIYQQQgghhBBCCCGE5AgOwBJCAPins6lPaEezlkpDQwMaGhp6n+ANatk1q9GwZnXv/xcuYvH4\nsU7R+mqwbIKhrIMUF/RVUiwUra8yBxh3nP3n+/vMfr31/frNqlyx/v7bsf7+2zHv5LMGvSx9dfxR\ntHGVOcC4Y96Fn+kz+3X9H386atvecNcN2HDXDZh7+vmDXrYYfZUiXIQUOYlXZEaKfBexTgTOh1et\nGdRyiaSWCW3hknile6QoFF+9b+XqAVr2JZEoFEOSMF6hr3roq4UPcwAPc4DCJ1F+YKQYzYFXi8Qg\n7LoV9w1qOfpq4TNW4+qDjwwyrjIHKHhG+tyM5sCrRWIQdt0gr5li8lXOgCWEEEIIIYQQQgghhJAc\nwRmwadx8883KJqJr6YaMQsPnnHOOso2W4JbF5Zdfrmw33aQFGh599FFls4o4jyfajeLN9RFdfLmx\nQhdHb2ntUbYZ1brgOgDEKydoo1HQW6JdynZi8EpLCMnz1W5U2t64Uxf4T2/2sfeciNDrrwIAnnz/\n+3vtS//0fbVsqEyLqwDA9huuVbbJy7XYRniaFhgJb38WAPDYzb/pnSFRJjHVzlVq8Y9u1/dajAO4\n56U9qt3mvbYAx+w6LTIyvUqfrx0xvd/15WFl02vLLa3dhq+WaFtjlRYgaDZ8dVYmXzWKoUegtxPq\n1AX1Tzz5TP8dgNVrvK/ui+pln93epmzhtPj78XNO6vXVpy5Ivqqy+HotOoiIfYvbddMflG3Syacp\nm0zeX9saNwEANtx5Peaddh4AoNpQ6IsZ13ZH2oUXB7BiY5Nqt3WfLWowsVzvz1TDV7fFtfiBtSx9\nlb4K0FctTF8dRg4wqyb7HECi+pjmKwdYcv0PdF9KtOAsMLwcAFYOENL74qq0qNJ4zwFCHc3KFq+o\nUzZXpsXGIns2Klt08kHmdrIVRLH6c9y5l/oPJWVoWLXS96fH8PPn9Az9V6+5vs//z731YcS6fX64\n9+VdvfbRiqsuELhaf89feuNqtrm8OjbxGNzf7lLtOre/bva7vFYL6kUnzFa2sHEOYrXTzHWOJlnn\nAMP9bVWhz4eZA3Tp+/gJJ50OwAt8JeJq+v0QAF7bq2Nye0/f3y0fOetERLZtBAC8+vmkANvSu/5P\nLZtJzKnxNu2r9cv0zO9sc4AKQ/XMlerY0GXkAOu2aOGmTDnANCOGWjnAjniVstWX5j+uxg29pu64\nPnZRZ98P0ykN2bpMZSV6Q5YvOCPWLkiZ6bkqyFdbjPELa8vV5bV9/r9w4UIg4vcl1Zemn32uWnbK\nAi3oCADxLn2PlYg+PqE6414SCJxtuPN6zD3jgj62Ptso12KJuzr6Xnc9ceDlRn19NnfpGAIAkyt1\nHydV6nuHJcxVYpxXLcmo4QxYQgghhBBCCCGEEEIIyRGcAUtIkTNSs5WvnXfMiKxnuGy44zoAwHGL\nlwIAGlY/ks/ukBEkMaNwuPx+4bEjsp7hsv4e/0bB3OAp9Hh/c2AsQV8lxcKYzQGC+pipokWkuEnM\nfh0uhRZX5zGujjlG6lz+5YwlI7Ke4cIcYOyyaoTy1Q133zgi6xkuG+70bz3MPdG/lbDuwXvy2Z2c\nwBmwhBBCCCGEEEIIIYQQkiM4AEsIIYQQQgghhBBCCCE5YlyXIOjo0MWCH3rooSGv77LLLhtOd0aF\nT33qU8p28cUX56EnhU2ZUaQcYV2kubtbC0VNrNCXVYch4gIANWHdtlv0dspje3W7YJVbW5PCIPt3\nb1XtQjJF2eZV+vXNO+VsAMAzD9yKnobbVLt4sxaycBW6YDoArLpKlwqwXryZeslnlK378fvNdQLo\nI/KwrWqO+npaZ99jE4rHcOJBWqhj/RYtSgAAO1p1sfP9DMGUKZW6KHynUZh/tLF81YUMv4ppX51i\nFBlvz+CrtYbwYLr4CQBUGAIy3cFx2tKSLIB+oDQaW9GCPIurAl891YscPnPfLejZcLdqF9uzTdnM\nQu8AHrvqYWU7Lqz3ZcKHtJBh/PG/KlviDEhKwfiN0D44J9b3nhNycSydo8VSHntD+xoA7G7XIkCz\nDV+dbpxXS0RitKGveuirSYrJV4eTA7QZwhiAnQN0hbSEQ3lM37+yzQFKQlq4Z1659+k+OcC6O1W7\nmJEDSJktMdHwU50DLDCu5Wkf/ryy9ZsDpLCtUgt45TMH6C4AX3UlhjCqEVctcbdY3Uy9qCFO6Ldj\nyODE9XUuXa26XZcXgots+nty0SkHmttJZ+o7DwcALP/ONQCA1d+6DGVTtU/nM64iEMmL7HwxaTNE\ncywhs9DRpypb6bOr9DYAoGm7thmiabEJ+jqRbi3GN9pkmwNE4zpuWKI4g8kBosacs7JOLSrVHagv\nbU7JAQ6AzgFKwlroa16Jj5eJuPr0A7ci1LxD9/thLYhdMk2LjwLA367WvnqssX+1H/i0ssWf0sta\nOcCOKu0vk3v6xoGQi2HuLC2A9PcM4lKNndnlAFONuJouAJYPrD4YmuymwFXIaGh7qi2u5SL6HiuG\nIN4wz1MAACAASURBVJXVB0to1Vo21OoFDOcHonPrHrgL2PmKatf1Ty3Uvv3R58x+xDr1vbT2AC2o\nNyGsz3lJ/WRznQrjeKWnXCEBphiCb9Vldr6aLqAL2OJapeHs2mUDZ8ASQgghhBBCCCGEEEJIjhjX\nM2AJGa8kns6uv+/WPPckMwnhjYWLFlOEYxyTmE24/t5b8tyTzNBXCUBfJcUDcwBSLCRmvq746iV5\n7Ud/9IrGnHFB72cy/mBcJcVCn5mvBUpCfGv+iaeNOSEuzoAlhBBCCCGEEEIIIYSQHMEZsIQUGfNP\nPhMAcNsDDXnuCSH9k5gNcPP99FVS2NBXSbHAHIAUC/NOOw8AsOHGX+W5J4T0T8JXb3lgbZ57Qkj/\nLFy0GACweg1zgGJlXA/AvvKKLjb84osvGi01Z555prLNnTt32H0qFLZv14XeW1t1Yf3q6urR6M6o\nY04NN4pYT5U2ZYuGa5WttccuKB41Cm1vb+1RtlCKSEo3fBHpaVX68g216v68c0qyAPbCJcv8h07f\nLv78+t7vyo6cr5btfEILAbyx6h/KBgAxp/fxP//9XmWr/cZ9yvbv/3uRsklQaHtNQwMWLFwIAPjF\nLQ+qdtOq08+LQ2u3Ple/XbPR6ja+fPJhymbVf484XVC+Ymi1t3OOGMIY06B9oyekRXVicdtXraLw\ne9r0dhCZ1vuxW3wR9EkVuvB5qFn3Z/F+yWUXBuccHT7uxJ5J+mLZEcepZbue1YXid99jv7LSY5zg\nb11xu7JN/Lp+NeeLv/6gspUEvtrQ0NDb7+/8Sfv+nGk6suzr0r76P6tfM3oNfGbZwcoWtXzVKPdf\nIZkL+OcT+qqHvpqkEHx1pHOAdmvnAUTD2eUAYUmK93QHqXu2OcA7piR9ut8cwPDVzqfWKNu21XYO\n0G1cj9+8Uotr1X5L38eLNQcoyyitkmeMuOpKtHCgKzEE1YxlATuuhhs36XavPZX8T5cX84vNPEIv\n27JT2aLvSApSJQYZJOL9vGxy8hootLiKRQcAABrWrsXc4Bpb99N/0+2O0NK04WYtnPfqn7RIEwAc\ncImxbet8mcI+WkRwtLHiqiUQNNnIAaIhHVeNnx3ebvjqLiMHiFTM7v3cEwggW4KfoSYt2n3YJCMH\nCDok0WQMd1EtTNT4vL5u3rhmpbIBwK7mLmX7/hd1iYPpX9ciipf9/uPKFjJygJ8bcXVKTZoQmgNa\nDTHJGx7fYvb70gUHKFu2OUBZAeQAZYbgkmECQsZOGaJOxi0JgC3QaYk7hlPEnhLH0XL/cLsW7ZYU\nsbnEgwYEPhra/lLvd13PrFPLvnK7tj3zsPZfAHjS8FXgCWVZOl2Xvpj7xZOUrf49XiR+w+3XYe7y\nUwAAK9Y+pto1d/X11ZgDdrXpPGp3u74WAaCuXF/z5REtOlmtdb1MEbZsYAkCQgghhBBCCCGEEEII\nyRHjegYsIeOVtb/+z3x3IWese/AeHHXCMgDADfeuzGtfyPBp+NnX892FnLH24QdwzLKlAIDbHngk\nz70hw4W+SooF5gCkWFjz/c/muws5Y92K+zB36XIAwIY7rstzb8hw2XDPX/LdhZyx7qF7cfQJPge4\n5X7mAMXOhht+ke8u5IzbHngEp5/g36D43R0r8twbG86AJYQQQgghhBBCCCGEkBzBAVhCxgm9td8I\nKXB662kRUuDQV0mxwByAFAuJ+q+EFDrMAUix0Fv/leSdcV2CYMOGDUNe9nvf+56ylZcbxeyLlFdf\nfVXZNm3SRZff8pa3jEZ3Rh2rKDZK9PMKKdVF4UuNguJvtOhi0ACwrVUXrD6gXvvRnNpk8fySoBB3\n+KHfqnbxmQcq23FnfyTRW9zxkH9tROK6SPcbV+tXEnf+Y6OybX9CCycAwMRaXbA6tEcXrt8XNQqu\nH/pWZXPd7b2fE0X6j67cp9rFn3hYLxtUJo+miIK8aUaN0WugPKLPa2NHTNnqyrQ4T7MhSrN/lbmZ\nnNFlVLUvKdWhPWb4aplRPfyNVluAo7lT+/DMGn3Op6cIwyQumepHb9QrnHGAMh13RlKI5e4VXsgo\n0r5Ztdvycx1/dz6p2zW+pP0cACbM1OKB4d3tyrajS/tB2cE65rmuFIHCwFeXTdfnQJ7TonbxwFnb\nU2LOQVNtJyoxFABsX9U+vdeojX/AGPXVli5tn16tBUjoq/TVTHSYaiF6n7LNAbZl8NUt+zqVzcoB\nZtUm1R8SvjqsHMA1q3bbf65zgB1mDrBD2QBg8oQKZcs6Bzj8HcpWHDmA9on9R1mbVnr0MbYEl1yJ\nvl9bYk0le2xhvXibPvaYtJ9eZV1SLAthH1vcY1ogCDO0ry467b1+HQA23Oz921VNVO0KLa5u/NKH\nez93bdkIAHBvPUG1k39qX421NAEA2lbe1mur2X+aagcAiGgVmHCLvh7jVZOULdSyS6+vyr4mcoUV\nV8OlxlywMh1XjfCLrRl+WzUZ+eo0IweYmiK4lQgDlY9qATSxcoDTL+z9nMgBWox92Wf46st3/1PZ\nntmoYzIAvGU/fSzWNuprvrRb+2rE6LeVA7yzvFG1i65/QC8bnL5Ukag5kzPkAKHizgEsIUbbpvez\ny7jH7TPOD9D3HpWgplTfayaFkgJSEfHLlDduVO2cIQA2793v8x8khHUPBMKCKeJbCXY/9bKyPbfy\ndWX7e5POWwCgzThAk419KavNThDQ7U4ReAvE7Kr26vtTec1UZYuEvF9NS1HOau+xz4ElymiNA0UM\nn7b8vErrXSo4A5YQQgghhBBCCCGEEEJyBAdgCRknJGa+EFLoJGYTEFLo0FdJscAcgBQLidmvhBQ6\nzAFIsdA7+5XkHQ7AEkIIIYQQQgghhBBCSI4Y1zVgCSkmFi/yhd5Xr2nwBqP+GyGFgPLVtTfksTeE\nZIa+SooF5gCkWJh7/icAABtu+AUAILrNritLSL5RcXX9X/LYG0Iys2DpcgDA2kdWeEPz9jz2hgwH\nzoAlhBBCCCGEEEIIIYSQHDGuZ8C2tbUpm3OGxJ3BYYcdNtLdyRvxuFZ6C4XG99h8qaHgbKnfGSY4\n47nG7FqtXAoArzdrJUtDEBFhxCEpnwGg5C1zVbuup9cqW09wfne2JVVBZ9RpOcmJbzlY2TY+8Jyy\nRTOoCDZsb1W2D8ybpWz7trQomxx8tLLFS1KUoAOlXtn0lGpnqXyWBnKpc+qSSouXLz1IdxrAi3u0\n8m2NocrZbZyYaku9dZSxfNVwS7OdFe3mZPDVx9u6lS1mxEvLVyOH6vMbfenvypZQWW9JUQudaigg\nTznmSGXb8Xet1BnL4KsPPavVZs89bqayWb4amqMVkKOGr4Zff0K1kyl6Gwm1+Fk1yeP+6YUH6E4D\neM5Qac7aV0vGka9uo68C9NXhEDG6MJwcYEa1nW6/tlfnXyOdAyRUlve0R5P9qa1Q7eqP0DnAphU6\nB4hlUHNe/Yb2waxzgAPfrmyFlgN0GkrJhZADuEiZNorul4uU63YGPdMON+2h51frzcSjyua6O3vV\n1V23V8uOHH6samfF1d5+lyb7GiuCuHrw5z7f+7l89TMAMsTVafvp9dVNAgBUnPYvvbaKqL6HAYB7\nTfu/q5igt9OjZeTjFXXmOkeTbHMAO65qUu9FqWxv1fsf06HWjKvhg3Usir72tLIlfltta02eq6PK\nmvR2p09StsbtevyhImzHkqde36dsH1o2R9nadup1yv7aV6242nnfdapd6X76flBq5AAfOW620Wvg\npT36921dljlAbWnYXOdoEjX6ZeUAJYbNGd2vypDXdEb1dkz/D0WQuFr8Z8CV6NhvXfsIfDXUlfSR\nePMe1SzWaceddEqtDsL24QUzqpVtv4X7K1vVIYfoFVpjUft2KlOkR9/DS0P+uM4sS96jIpN03gMA\n7T06OJRFsjvXFUa7bBjXA7CEFAsLFy1Gwxqd/BJSaNBXSbFAXyXFAn2VFAvzLvoc1l/3k3x3g5AB\nYVwlxcLCxUvRsJpCmmOF/D+6JYQQQgghhBBCCCGEkDEKB2AJIYQQQgghhBBCCCEkR3AAlhBCCCGE\nEEIIIYQQQnLEuK4Bu379emUTGVox3WLGEtwaj8chFVMExBAb6Hb6EiqFbvd6s7YBwBNv6OLqx86s\nMbYdQ6ijr2hAdJIWlIgsmKb7E74WADCrJilGEd7ymGr38t0NyhY2BCamHK4FEQDglL2dyjbjWC04\nUDl5t7I1XX+1slVf9p3ez4kS5VKqRSQkUqptiX9T3NgqdA4As2r1Ojc26ULkR5Y0KltXzQxznaNJ\nZdjaL23rjutzWer0fm7SLgkAeGK7/uK4WVq0Ai4O6e5bEL19shYtLKvXhfsjIe8HUyuT11V40zOq\n3Wv3rFE2y1cnH67FDwBgcYmukD/1rfraEUMwYue1P1O2CZf/qPdzr69W6sLziOo4YPlqa7ehFgFg\n/7osfTWki9R31+trcbShrwbt6Ku9FKqvmjmA0/tq+upo5QBTtGiFlQOUBDnAjOqkcIqZA9xpiCwZ\n+WEmXz1xsha4GPEcoFwLiI73HMCVVmbXMJadwErJjhdMe9uzj+u2RxxvrKAcmP3mPqZolfYZMeKq\ni/jasdGU78JbCj+uTvz893s/J85HqEL7qus2BHIMQp1aJAwAMPNAbdv+vDLFa6ZoW70WxBttRjoH\n2Ko10gAAL+zWglTvmK7PB1wc0tP3t0vnFJ0DlBi+Whr299LZKb+tYhseVO2evkb76oQp+pqdUm4P\nyYQNQarZiw9VtuZXtmnbX36tbFUf/Vbv58SRL52j1xeq0eJuFplygFm1WiBq074eZXtzaJeyddfm\n31ctwTjDBEO/CZ1Rbcx0nLoNccewcd91AKKS5iNGXA11a/EzCXIXSRGr6nrln6rdnhd26L4Y+/x2\nI78DgMlTtV9PMmJw+aRaZYvv06Jgri5lWUMwvrddqZHDpgmWAYChxetbGvsYMuQBLdHNiLNzu4Hg\nDFhCCCGEEEIIIYQQQgjJERyAJYQQQgghhBBCCCGEkBzBAVhCCCGEEEIIIYQQQgjJERyAJYQQQggh\nhBBCCCGEkBwxrkW4SGZqa3WBZMs2Vgm16qLgHRWTlc2ojY5Qhy7+HhK7YHUsritCN3XFjJWG4Er7\nClxIp64+vzuiBbISBaKnxJqS23jwLr3s87oA9r4tehuVhtAGAEQ7dL97WtqVrXZ/3cf6Cz6tbF3B\noVmyaGHSONEozN7epG0B4ZTK2vvXlphtXjeKwk+v1aHR9ehzGIlbohb28ckVofa9ytZRro+x7au6\nWPtg+t9onHNICK6sb0F0LZECbOrS5yMc9+eirj0pJtC04h7VrmWb9su9r2o/KKmwb3HRTl00vXOv\n9tX6OXXKNuVfPqtslq+6EkMYpVOLzyRI9dU5ddYRA7a0aH/bz/Br16ML0kdiWiSPvkpfBeirFkWR\nA3Tp7fSXA0weIAdofElfn02btBhQ9TRDzAbDzAE+8EllM3OA+ul6w8wBlM2VGQJQEd3/dGG3wRJu\n0aItCJcgVjuwOFm48XVtDIRWQl3JuFkMcRXPrwIAzLvoc72meKlxDpp1no2QDyLx8uS2Yoa4DgCU\n7HpZ2aJTtRifxPW1KKYIm30t54rh5ADSad0XMsRVQ20nYw6QJmJXalwTe8NaGDHi/Pomx5NCirv/\n9g/Vrn23zl2aOnTMmbmf/RvbisGtm7V4ZfUs7TN17/+4sllx1RLckvLM4n5946qdA2xt0fs4u8aI\nq916O4WQA5QYfXBhva8lhmBWiaFclUnbPGrkAJYwlwPQk9a2xBAGTRXa6l025s9FvCnpN02vbFXt\nmg1f22HkI5YwFwBU7NG+PtHYv+pZWiSw9JC3KVtokr+PzD3rYkhJcOxrdB7mrIPbKzyWPI91xn0R\nAMpienlLhC1ibEZiFOEihBBCCCGEEEIIIYSQgoIzYAkhBc+qNQ3+w55XB2w799wPA6WZn9wSkkt6\nfXW3nqmSztz3XQZkeCJLSK6hr5JigTkAKRbWX/eTrNvOu+hzjKskb/TG1edWDNiWOQDJJxtu/z0A\nPxN4IOaddh4QKuwhzsLuHSFk3JJ4PaY3QRgEj6we/DKEDBX6KikW6KukWKCvkmIhUXpgMIOvCdau\nfGiku0NIRhhXSbEw96yLASQHXwfD2ocfGOnujCgsQUAIIYQQQgghhBBCCCE5gjNgxxlXX311Vu1+\n+ctfKtt+++030t0pWOIVuvC+VZA5bIkfhLXQQ3OXXaT5lMN0IepIyCom7YC0Qs/y+tOq2bSJhkBF\ngnhy+R2Pv6C+nv4OLZrQ06aLqJfVlpmrf8dH5ytbxX6zla30nScrW9QQHEiICJR2JouCu5CuzN82\n6ygAwLIli4INVGLdln2q3bEzdRF9ACgxHkNVG5GxO6KL1Bu1xTNIAuQO01cNHwobwgLW8Ww0xAEA\n4PgDbVEIjQPSxB5Kdr6oWh1oikwEBzTF13c/9YpqNe3oOcrWta9Lr80oZA8Ax3z2eGWrmq6FIMoX\nv0fZovVaCE6Ca6u03RDYSKFlzlwAwPKEr5ZVYc1m7avvnKGFiYC+wgcJLF/tDGtftcKKfSXnDvqq\nh76apKh8tdBygM1GDjAhuxxg1xMv6WWNHKC7TQv3lGUQXXnbh5coW7Y5gCU6xBwgOyxftV6/tAS3\nrOMZs4SiAFTO0+dNS8AAiMfUtsJNm3WzMRRXURJEqBmHpjTU/Ynt8sI3Cz/z7V6bPK1nwMqbF+lt\nAHDGeU0XkgQAFzfijeR/zpXlq2VO38ct0U3LV/dliKtL5uhzaYsg6RwAz69RrSaWWQJQgW+lHOvW\nrVq88R2XzlO2F297UtnK6+3Icdh73q5sZfU6lk04+Sxl66mZpmxWXMX0g1W7trr9AaTE1bIqM64e\nnSkHMNytMmwITlXouGox2jlAQhAwFYkb4mARfT8Mi/bVEuu+DqCqVB+oujK9fCgeQ1V33+NvxtU9\n25QNUe/jsR3J9iVV2qenHKEFrsKGIp5k2Jc6Q7Rw+jFaJLDqTUcoW0JwK5VEvEu9V8Qrtb+0i/eO\n3nw1XIptMX09uaghxAc757LOV5dxPwmH9Xay8dX8R2NCCCGEEEIIIYQQQggZo3AAlhAypli5Sj+9\nJqQQWUFfJUUCfZUUC8wBSLHQ8NOv5bsLhGQF4yopFoohX2UJAkKKhAXHnwSg8AtLjzaJ12OYHBQO\nC5csAwA0rFqZ134UGonXY4ohORgv0Fdt6KuFB3MAG+YAhcf85acAANatuC/PPSksEqUHOPBaODAH\nsGFcLTwYV22KMV/lDFhCCCGEEEIIIYQQQgjJEeKcXaA8R4zqxv5/e3caHld5JXj81KbS6kU2qzFh\nCziEBKYTkKyyLO/GBuwOCUvD4KSn5wGSTtJNmu6hmw48oZsHZkhn5gNJIOkJ6WSgE4gJNoux8SJZ\nLlkinTyZZMg4DTH7gG1ZlqxdqmU+3Lq16JySSpavpJL/vy8uv3Vv3beqTr331a23zhnL+++/r9qu\nuuoq1Xbo0CHVdtddd6m2Rx555OR0zEPnnnuuaovHdVLit99+W7UFg1O6YNrO+OyRQ129KlbnJPvU\ndv7BbtVmJYh++vUe8zh+Iyv85Wfq5Oohv08+u6Yhp+3XL/wvtV0ypFM/166/IX3b/Ya3O6Hfyxaj\nwMq1wYOqbfB3r6o2EZHw5bpogFVk4UjJ6aptKKuSxYaVzvMMpr4eikaj6ftCb7SofRfd/DUREXnk\npy+n2w736uIH116sk4s7x9HvgZVf3NrOKjQzq6JsUmP1yHEdq7PiOt6sWI1X6gIwPzmgC3WIiISD\n+vu6K8+eZW77mRGx+qvtm9U2iZIK1Va38ur0bTdWjwzpBPCv/KFDtX1+vi5+MPjrvWb/Sj+hC8bF\nZuniNYeCOmYGY5nE/O7zdAtJRZub0vcFDuhjL9r0dyIi8tBTL6XbDhtFbjZcot8XEbsIkFkYyIjA\ngBG/leXEKrFKrFpm+hygK17YHGBjqZ4LDrzWptpERMKfqFNt1ue2PaRjdSrnANZ53CogM13nAEMd\nH6pYTRr9KnRcTez6oXkcv1GIKPjxiN6w432pueXLOU2tO7bq4xifk7plq7L22WL2Q0RkuEXfF75C\nF4Gb6LjqSxqFeIYyBaJqrrnRuRF0PnfZ46rs/5nat+6L9zvbfetv0m2Jnk61Xcnym81+J4NGYZmg\nUfIlYBTrMgp4lVZUTf0cIFHYuGoV6tv8ut5OxJ4DXHqaPo+X+H0FzQHcIkbZaq/7k/Rtd1ztjOnX\neN+7ep7ymSp9XWH4wC/0cUXEXzlHNwZ1oUffeZ9Qbe8FM0W4NqxynmfIGlcPtqp9F934lyKSO662\n9+nX4eqL7KKn1vneKmxkzQGssbZikucA/f39KlZ9Iwu2iZgFD7uM+rCHeu2CcVYx2eoy/f7OLwvI\nuhW5Y9yre3UBPytW61auS9923/e+mL4kZ73upb06Vv2HdGFEEZG4UQDMX6Hn3j6j6Fv29QJ3pa+k\nCu9lx6rV76X1znno2R2Z8bd3SI/dISvYRGSWUWisPFRg/J7gfJUVsAAAAAAAAADgEXLAAkVic+qb\nHXcVjPvNe9uLT4+6n/vNbGTpsnSuo5cbp3+elOxvvFBcfp6KVXdlQe2ajSIy+ooWETtWn9vZ7E0n\nT6KcVS8oKsQqisXJnAO8tIc5ALzT9tSjIiLplbC16z8nIiKtL+lVodnccbd2zcaCx+LpgHG1eJ3M\nOcCLuxlX4Z1tu50V/e5K2LqGFSIi0tK0e9T93Pc8EolIJOKsFn2liVidSlyABYqM+0fYLav1T6hG\nE93bmJ4kzBQ1N39JHvlp41R3A3m4E9ubVhs/UxzFTIzV2tv+Uh7618ap7gbyIFYziNXpjTlABnOA\n6S19IXbT18a1X+uOLekLYTNF5EvfyEk9gOmFOUBGzY13Mq5OY+6F2OtW1o9rv2g0mr4AO1OsaFiS\nk3qgWJCCAAAAAAAAAAA8ckqvgF2wYIFqe/LJJ1XbypUrVdtjjz2m2r761a+qtoULF55g7ybunnvu\nUW1W4bE777xTtU1xwa0pN6tEfzcxmKhUbbGATuo+FNcJoueVl5jHOWYk3379qE5IXx7S/YkFnMT7\nn9pwm+xodH76WvKTf1TbhSoySfuTnU4y7cWLnWIZ2cnVW948qva9rPaTqi2wVLeJiGx/Q+9f2aXj\nqL3vsGq74wqnEESkviGdIL5jQBeHm3/uFbkNJWVyUXW52m7BLF2UYE5YJ9kWEemP6UTdZUYC/w96\n9Ht1eoVOlD7ZqqzYCOqk5/0BI36NnPALZukiDyJ28v0D7TpWK8P6PY/7ndfpyqs/J680ObEa3vyw\n2i5Qmnnfkp1OnCxLJVfPLgbU9O+6iFHtwotUWzDyUdUmIrLDiNWQbpIPj+vx8r8sdgp1RJYuE3eY\n6I7r92DWhSMKOoYrZNFp+j1YMEu/CXNK7VgdiutYDRtJ5YlVYlWEWJ2ImT4HiNQVNgf4ZO3HVVtg\niW4TYQ4wVRJls3WjUTxKt4j4B3QRI/8ZuliviEiiU49l8Td/o/c3iq5IzHntatdslLaffEdERPq3\n/VhtlswaNxKd7SIi8qnLrxQRka2brknf5zOqpJ1jFOEq+dQq1SYi0h81fl5uFDYabNdFFKs2/pmI\niNSuu14kVezMP9irtgtclBurvtJyCX/s02q7RK8ufhcziqOJiPiGB1RbskTHf/DoW/ox555jPuZk\nMucAST2uDlrjqlF8ZzxzgIMd/aqtPKTHBGsO4HviPrVdMGtcTXR8KCIiSyLGuHpQj4sfr9FFiEJ/\npNtERH57WMfW/HJj3NGnDbnydOfzlD0H6BrUI8HcBZflNoxjXJ2bdw6g3y9rDvChUZzqtPKpvwaR\nMMrGx336PN4/rF/P7kF97uoetItwdQ3o80qPsW3Ir98PSTqdrFu6XFoanYJcwWPvqM0SpVnniFQh\nseVLncLZ2/dkUmf933b9GRmM6c+iv/xy3RcREaPdmrtUGp+72T6nbf2KpSKp211GIa1kUv/fKo41\nyzjfW4XhRmsfacCI6ZARJ4VgBSwAAAAAAAAAeGTqv2IAcELcFS9rltXLmmVOHpjGO9eOus/eB+4Q\nEZErHnrW286NU6S+IX3bLWZwVH+BjSLlriJY3VAvqxucWN371WtG20WaH/yiiIh88h+f8bZz45Sd\n68stwtBtLS1CUSJWUSyYA6BYuKtea27+ktTc/CUREdl11+dG3Wf3331BREQa7v+ep30br9p116dv\nt25zPkcnuAgK05A1B9j5heWj7rPn7/+TiIh8+r+/4G3nxsmaA3TaCzFRhNxVr3XLVkrdMufX2q9u\n/v6o+7hj1qevudnbzo3T+hWZXzC8lMpxO5OxAhYAAAAAAAAAPMIFWKDIuatgZgJ35ctY3G/73G//\nUBzclQUzgbuaYCx1y1dL3fLV0rLnFW87hJOKWEWxYA6AYuGuhJ0J3JVkY6nZcKvUbLhV2rbqGiOY\nvpgDoFjMpPNgoatf1y6vl7XL63Ny2BYTUhCMcPnlOoHwl7/8ZdX26KOPqrZoNKrabr55cpZ4Hzhw\nQLU98cQTqu2ss85Sbffdp5OLn+qMnNoSjukM52UDXaotUT5XtV0yz0ieLSJGHQE5biRIHzYKmgxm\nJaR/fqczYK1a6SSA3/9K5mcwHT94RO1bUeIkp65bmCmWYB3jl/9PFwd4rOmg7nQeSSOL+bdvznzG\nbly7TEREwj4nYbm/71j6vq64Lixx3frVzg2f88INh8rlkpAuDNFTNV+1dRlJ0UVEQkby7YBRruJc\n0cUYjsftQgmTadBICl4a00n7Zw/qYhs5idlTLphbZh5n0Xwdw/le05GGs/ronlyXL9exevT7/1Xt\nW5mK1Ybz5qTbeob0caPvdKq277/8e7M/4TJdwCBuFGJ57POZghl/sm6ZiIiUipMwP3D8UPq+jsCZ\nat8161I/BU7F6kCwQhaF2tV2xyvnqTbr+YkUHqsfSegiN92Js83HnEzEqmOisRozHvPxP70y8fzC\nOwAAGEVJREFUfZtYnTjmAA5rDvDdPX/QnRYRn1EIgzmA93wxXZhJ4vp3xv4hPdZa42rgnEX2cc7X\nfx/5e42KgIbs47S+8ryIiNQud8adlu9+I33f0Zb9at/hVKG6w//7vXTbZXdep7Yb/NUe1fab79o/\nCQ8bxYRiA/o1+9gtmZ/FLr3vcRER8VU6n29fPFM8x3/sbbXvVZ9z0n1IwCncE5tzjviGdZGb5Lzz\nVZu/X48r2Y+Vs7+xWTKkn58vNvU5Pcw5QFy/JiFjDpAMV6m2j8zWz1NE5KPVujhXR7/+/MeM8Wkg\ne1zd5Yyra1ekxtWdmQKb7Y89qPa1xtWBmD7urz7Q4+rju95QbSIiJaX6Uo05B7j1P6Rvf3aNk9LF\nHFeT1WrfdcYc4JISPd71VOl9j+cZV63CRgEjWhcm9RjSmzzdfMypFk7osTZsnMNnVen4q8pTBPLM\nSv2ZtgqYGbWsZDCYmVfs2eeMnVdFakQkdw5gFe9z357KrIKjAZ9+z159T89XW42CmyIicePztOhs\n/bldc0nm/f3z69eIiEhpqvDlcNZjWMUwVzY4xcPc6UZ5yJ/zHFxWETXr8y6SeS2yWfHrSxj5O3wn\ntpaVFbAAAAAAAAAA4BFWwAIz0OLV1+Z8+zVSdvEO170/fN7zfuUzWl8te/fp1eYoTicSq3/xvec8\n71c+hf7s0NW4d59HPcFkI1ZRLJgDoFjUffH+nFWwI/1LrfNLlM+3/lu67ZfGCtjJ0vbi0+Panp9z\nzxyLV63PWQU7kjWu3vPEVs/7lc/+HeM7NnOAmWOsOYB7Dl26JJJu+86zOz3vVz5bdxWWfsjV1Fzc\ncwBWwAIAAAAAAACAR1gBC8wg7rddi1dfW9D22d/W/u0t60VE5KGn8n+7ezK4Od9ERJ7e3pi6lSff\n1Yh9Wve3eNInTL6JxOoDm5x97vvR+FZNjZebS1NE5F+3NaZuHbI2TXNzb1k5wVGciFUUC+YAKBbu\nqte6L95f0PbZK2GX/PW3RERk3yNf86ZzKW7eVxGRvQ/cUdA+NX/8BRERaWkurJgMpj931eviVesL\n2j57XL03Na4+6PG46p7PRUQ273BXE+o8rtY+zAFmjvHOAbJXwv7phhUiIvLE1t3edC7FzfsqIvLt\nZ3cUtM+KVO7XvUW+8tXlSybthLQemdSDnSydnToBcU1NjWo7ckQXAvj617+u2u66666Cj93ergth\n/O53v1NtN910U0H9efjhh1Xb3XffXXB/ppCRItk7vX39KlatJM09RqUOK2l0Pt1Den9rdytx/e/b\ndeL6cNbOf7ZhpYiIvLbl+2o7q/BC7bW5BeOizU1yx3P/rra7/gq7QIqV2/qSrGI46WIbg07y+f0/\nyBSx8V34R2rfoapMwbj61E8kfvHk/1DbxY68r9p8H4uotqEKu1iGlezcyEMu5UaxoFhJpWqrKC+b\n8lg1aqHIgFX8wAjqPDnKzQImFUZWeOMw8sYxKwF85ti3XeOc9H+79Qf68YyCNouvvj7n/9F9zbLp\np6+p7TbmidWQ8QJdenrmvbzp6mUiIlLa7xQuiH4n8wdi4GOL1b5Dcz+Svp2O1We+q7aLva8L2Fmx\nOlBux6qVQJ5Y1YjVVL+J1RNW6BygL6afZ9jaMA9rDmAV5howjvP60ek1B7BcPK8iffuGtanCMCd5\nDhA/+qE+8CW1+vEmOgcY1oWBYuFZqm2yY3Wgr9d4AjqIfEO6iFyyxCgOlzQq0IldcCtZqgusWAXA\nAu//H9WWOJb5G2Xxf/5bERFp2/zP+hjzFqq22oY1Of9vefwBeX+LvsB11qp61SYi0vuGLnj0blOm\nj7fuaRURkdJq5/k1P/yV9H3hxRv0A8YzBa5qrrtFRERe/bk+RxQ6ribK9OdTJE8hLasIjPEeWu91\naXnFlI+rXswBrHHVKqozYBRffadLv8bZhfrcOcBvnv+h7o/xvllzgLu36fhbdrE9PpUaJ4QLqjMF\nSNN/Ww0541Prj76V2fDcy9S+5rhqzAHih97VnfnoVfrxZujfVv39OlatOaPFist819us0dZ67fqN\nWLUeMnv+sW6FU0SwrUVftLSey+K63LHohV175fdH9XkjX/HVsyp1Ubzz5mTa1i53xuNg6lm3NO5K\n3+c3Cpr6hgfTt2vW3yAiIm3PP6m3M8a7hHFuSlToQrIiIsMBXTTNEiwwAkvLxo5VUhAAAAAAAAAA\ngEdIQQDMcDUbN4mISNuWH426XbTZ+clKpL4h/e9lf6VXzoxX9s8NRXJXvRSifoleHYCZqWbDrSIi\n0rZVf8OZLbrP+XlXZEl9+t8Lv/LYhI/vriZMHydrNWEhiNVTB7GKYjHVcwB35auLOQDyqbnxThER\naXt69DGy5fEHRESk7o770v8+s16veh4vd+WrK3vlayHc1a+Y+dz3uu35p0bdzpoDLL73iQkff+Tf\nVjkrXwvAuHrqiESc93qsVBMv7HLSply7cmn63396+uUJH99d+erKXvlaCHf160zCClgAAAAAAAAA\n8AgrYIEZ6n9udb5h+osNTuLqmo2bxlwBI5K7CubFr9+Wc981//DjvPv99U1Xq7bSrHx06WIbB18Z\nsw8iud/ONqeShMvbvyhoXxSXH7/oJHy/4xrnPa/ZcOuYKwtFclcW7PzG53PuW3X/v+Tdzy02ky07\nVn/6cqNz47Vnx+yDSJ5Y/eA3Be2L4kKsolhM9hzgr24cfQ7wzPZUYZiDhRXdYA5w6tj/zw+JiMji\nO5y6GTU33jnmKliR3JWwG5/annPfllvW5t1v1T/pMXuoO5NT+cnl41tNm73yNb0iMqbziaP4uXOA\n269JjavX3TLmKliR3DnA03+Tu1L6xv+Wf/+7btRxHA4Yf1u926S2szAHOHVs2+2saL0utaI1EokU\nVHAteyVsdsEskdGLZrlFvLKVZMXq9j3NqVuFjY3ZK1/bXnrGuRE38mAXIYpwnaCenh7V9pWv6J+q\nbNmyRbVdeumlqu2ee+4xj3P77bertsOHDxfSRdm0aZNq++Y3v6naqqurC3q8KTapybcHO4+oWO30\n6QT2s0v0IvKOAZ2c2kr+LmInil9gVANOhitUW6/oZNdBI5v9UNw5xqqGzE8A2hr1RVDfa3t024jC\nGLVrNko8EFLbiYgEEnHZvzt3Ahzo0gWyjsxdZO7vSifpjjvJt9tefDp9X/z1X+o+flIP+Mmgfm0O\nDui2fM6frZ+jv3/0Ks2ukuozpzxWj/t1rFYZRYisWA3lidUuY9tz/cdVWzJUptr6ArrNKqjg1rRr\nqM9MEFub9E9VfL812i64Iuf/tWs/M0qsxtKVQtNt7W+p7brO+bRqi2edM1enPlOBhDMhaN2WuQiW\neMOIVaPYhvV65YtVq1jFRyp1o39AF4tJ+gOqLTz3DGJ1BGI11UdidcbMAdyCaG4VYRGRtj3MAVxm\nrFbpGLRi1TId5gC+hI5Bq0CQv0cX65VAiXkcqwiX9ZjWOGGxikL5Uhct6xoy7+n+7frvqORvG1Vb\n8PyP5/y/ZsNtkswq4pIjXC6tO3If19f+ttosvkAXMQoeeyfnGM7Ozj/ZX8bF3jmg9i10XA12vmd2\nO2kU3IrNv0C1FdO4WugcoHOwsIKFIiLHjW0X+k58XC0P6MsXgwnnpVu2NDOuWnMA/4Fm1TayONbo\n42pM9u/MLTAX6D6ktuuYN/q4mpkDDDt93f7z9H3Jg7/WO1ysv4SwxtW3Bgv/26qY5gADRhEu43Rt\nFteyishZ5xkR+5xtzTl9CV3c0CzAZ/ClLlqONa5ahRolmHs+qF2z0a6m5jyC7N+1bcw+JkNGzGQV\nb6xbvjq1ofOC79+xNX2fGS/G5zheqYvDdcft18sqmmYVVC9J6ou/PqPoZHj2PIpwAQAAAAAAAMBU\nIQUBcIrY2dScXgVbt2yliIw/EXbrji3SXX6meV/lsF5hNhHZq15wamlqjqZXFrrf2LY07R7XY7Ru\n/7l0V5xt3lc50D6xDo481rbCfv6NmYdYRbHY3bQvvQqWOQCms5am3enxdPHqa0VE1C8BxtK29ccS\n/+BN876RK7snqpA0NJiZGvfuS6+CPeE5wGjj6mDHxDo48lhZK19xajkZ42rrji2SzPMriUJX5BYq\ne+XrTMMKWAAAAAAAAADwCCtggVPIziYnF9HahtSKrdQqGBGR/d/++ynpk0gm55tIVpLu7rempjOY\nFpqanUTxy+vrRGRE7qJH752SPolk8miJiLyS+jxJj85ziFMHsYpisbtpn4iIrF6ailXmAJim3FWE\ndfUNIpJZsSUi0vKtu6ekTyJZeV/FWWULNO51xtUVS3N/DSMi0vr4N6akTyJ55gB9H0xRbzAdjDau\ntk7hr07SeV9FpMXNUT/UO0W98R4rYAEAAAAAAADAIz6repuHJvVgk21gYEC1HTqkKxXed999qu3J\nJ+38QQ8//HBBx77hhhtU28KFC1Wb31+019wntfphv1H90OqAVTGw168rmuYpgCylw0ZVXSOHiv/N\nf1Nt8UuWqja3Kme28t4P9b5VZ+T8P7KkXgZHVG18ZnuTXFhiVERMGmUgxa6GO/IYIiLW3lte2ava\nho2yhKVB/fzmhHW1zCP9uiphdam94L8/pns0O6D39/frCqrHSuaptjNmVxRNrFoV3/Mpj+lvIpNB\nnQco+Ic21Ra7uF61WdVEy7r16jwVq/UNdqyGetS+bkXlsR5zJDdWjUKm8rPtjaPu6yozqvhWl+pY\nPdSnY222EdMiIkNGh+b4jaqcRpXQjpCO1TPnEKsjEasOYpU5wOYdTXJ+iDmAa7rOAaxq3RazwvV4\nGO/7RMZVievxINBzRG9mjKvS15nT1vK9ByV4zkVq30SFfn9E7HEnewyu+eMviIhI0q9jpu2Fn5iP\nqY5tVOZOGON54Ni7ert8/Y4NqrZkiR5vrOdnjSsl888pmlj1ZA5gjavn6fzAiWCpagsdfUu1xeae\nk/N/aw6weUeTnB/UcwDr8yAikjCqueccY5Q5wOYdTaPu6yozxtW5xhzgSF9ctVWF7esKBc8BBvV7\ndSw0V7VN+rja06WfgDEeWJ8rK8jzdd43aMwHjZWfvmFjjhgosD+hctWWGDFuRJYuk+SIXkajUfEZ\n477ZFxER4xpTLJD57NQvcVaIG6d2eX6XngNYlygDRriV5JtgjdzXZ29n7R62Gq1zoPF6l5WVjdkh\nUhAAkOi+ZnnzeO6J9Ya1DRL2OYPN/t3bT/ix3cmBa3fqpzrZugftP+iAkaLNTfJmd+5ZOSdWd207\n4cceGavpn2xl6RrQE1DAQqyiWET3Ncvb3bnx8tk1zAEw/USbm8T3y9zCMXW33yu+UFhERNq2/OiE\nH9u98OoyCxYN64uggCXa3CRv9+TOAXLGVY/nAMcZV1Gg6N5GVVwrEomkL8lG9+n4KpR74dVlxepA\nbEav0VSKdjkkAAAAAAAAAEx3rIAFYHpme5P8x6udb1gXr1ibbh/5je3ileucG377p6gjvzXr5wtZ\nnGQ5serGo+jVK7VrNjo3AiHzcUbGarf+FSowIcQqisXmHU1yy1rmAJj+Wr73oET+/AEREanZuCnd\nvn/niznbpcdV62e1yYS0PffDnDZCFSdbzriaNQfYv2NLznaL12xwblg/exfmAPBeNBqVJRFn9Wr2\niut0kayUdDFP4yf+SfFJ875oTlufkfbnVMMKWAAAAAAAAADwCEW4UCwmNfl2b59OFO83euBLFPaV\no99KiC8iidIqva2RkDxprILqFJ0UflaJ/k4l1H5QtcXmnafaOgb1x7M6lJsTLrJ0mdrG9eJundet\n3SiEYSXBthJoH+kd1v0p198En+frVG2Jcp3A3UpMLiLSbxQtKTP6Yw1eVpL52ZXl0zNWrcIpRlve\nWDUKrFjJ48UodHA0puPXKt4TPmrE6vwLVNthoxDA6eW5jxeJRNQ2rud26mTvxwrMl2klZv+gRxcW\nOKNCvw4X+jtUmxmreVY8mLHq1zFoJYUnVolVEWK1UKfiHKDTqAEzJ5j7/MY7B+gY0K+Pz5jOTbs5\nQBHF6kBfr+6E9VyncFy1ikJZhaaCVmEjY1wNdBlFEGcvyPn/aONq67af6cfsPpx3+2yJkC4GlWx/\nT7X5552t2uJGXI5nXLWK3yRDehyw3n+r0GO4au6UFze0TOUc4Fhct1UZ42q44y3VZo2rR/r1+fq0\nssLnAFZxoqPGY1rFiYLGScsaV+eX6/PL+X5jXLUKLXowBxg23v5ZFWMXNjqZJjQHsFbYx8exRNko\nZjUyP2ve48SMQmdx/Z5bc4pCxpJIJJL3Yt4uI797n/FmWvleB+N6O+t6QblRNNaao5f4Cl9la8Wg\nVSjMet7WdoXEKikIABQsurfRnhiJyDHqvWAaiUajInkujhyhhgamEWIVxYI5AIrFaOOq9ByZ3M4A\noxgtVjuML8aAqRKNRsX43lFERIasq5EwkYIAAAAAAAAAADzCBVgAAAAAAAAA8AgXYAEAAAAAAADA\nIxThQrGY1OTbh7p0UYM5icISuIuRaHsgWGEep3RYJ5AP9OriJ/Gq0/XORr6g3qAu6FHu14nZ+hI6\nYXVF3Cj8YTy/IZ/xnPOwxher2IYvrpMc9SV1gvCE8XhVPr1vu1FIZ16pfs4iIv1GMpvyeL9qswqe\nzNaHkdLyikmN1Q87dazOTRYYq0ay9oGQjiERkbJBnZDf39+l2uKV8/XORs7A7kClaqsM6O164vp7\nwqpEnz6EkTx+KE+a87gRR9ap0CpiZDRJn5FQ3orVSr/+zB4d1nE5L2yHUF9ct1uf205fuWqbHdSv\nbWlFFbE6ErEqIsTqhOYA44jVaTUHiOm+WLEa89tzALNoBXMAzw21v6eegFVgpdBYNbcTEd+Qfk0m\nMq5akiV6PPAN6THUKjJlFpDJc1yzII6xrdUfs8CVVeDJejyjj/6+Y6otUXmafjzJ81pY2xlFd6yC\nf9NiDiDGOdIq5GbE6mBIn5tFRMJDx1VbweOqNQfw6zgoeA6QNAqnGc8vmaeY1UC+hJsjlBgVogqd\nA1jKfTqGxjMHsIpwlRvnu06/fg+nwxygxyjCFRRjPDHOw/nypJusolkJI6G6NeYVWkirwP5YRf7s\nDe31m1ZcF3psq5hb3FgnGjMmGj6jWJc19whYVdTEnrtYn51CiwYXMgdgBSwAAAAAAAAAeIQLsAAA\nAAAAAADgES7AAgAAAAAAAIBHuAALAAAAAAAAAB6xMz4Dp7i5MV1YoKdkjmorC+o8y8cTOiHznOSg\neRwrwX8iZmzr19+VDIdm6f5YSaN79XMpq5inN0zoY8SNYhslSaN4gYjEfXo46RgsLPn2/DJ9nP4B\nnYS82iiikTQKY8yP64IICbGT9ZcambaPDof1sXWTJPMkIp9M1TFdwKEnXK3ayo1Y7TAKlVSLTrwv\nIpIIG4XkjOTzVnL24RL92pcZx/D3HlVtlVYxiiEjVo34K0kYBUbEjusOI956jZdifpmOwf6Y8ZkP\nWwVfdB9PS+qiBHHfbGNfkTLjjN0e06/kPCsPvvH6TDZi1UGsZkzXWJ3YHEA/qeKYAxhFMIz4CzIH\nyBx7GswBTFbBlqBRAGpAf6YTVhEXEUlOYFw1C8MY21njqlWQqtDCXPkKwFj9CXQf0hsOGmPerLOM\n/ugCg/kKaSnG+5K3yE1Yx7Df6Ld57GkQq9VxPRb1ho1x1fhMHovp8WVO0j5vWu+vOS8wFDwHMIqn\nVVrjqtFFa6wMGIUIRURKCpwDWKw5wIAxB5hrFShM6j7OF2Nc9enzkIiI9ZBHh/V7UG3UC5wOcwCr\n4JY13lvnSKuoU0mej59VrSlpFQkM6NfELNxmjHk+Y5y2notZRMs4Rr7z3pBRMG7IurZgDMsB4yGD\nfqOIZ4HF5sQsZGbHlTWjsYp92fTB7TNorqkfjQEAAAAAAABghuICLAAAAAAAAAB4hAuwAAAAAAAA\nAOARLsACAAAAAAAAgEd8yWShSWYBAAAAAAAAAOPBClgAAAAAAAAA8AgXYAEAAAAAAADAI1yABQAA\nAAAAAACPcAEWAAAAAAAAADzCBVgAAAAAAAAA8AgXYAEAAAAAAADAI1yABQAAAAAAAACPcAEWAAAA\nAAAAADzCBVgAAAAAAAAA8AgXYAEAAAAAAADAI1yABQAAAAAAAACPcAEWAAAAAAAAADzCBVgAAAAA\nAAAA8AgXYAEAAAAAAADAI1yABQAAAAAAAACPcAEWAAAAAAAAADzCBVgAAAAAAAAA8AgXYAEAAAAA\nAADAI1yABQAAAAAAAACPcAEWAAAAAAAAADzCBVgAAAAAAAAA8AgXYAEAAAAAAADAI1yABQAAAAAA\nAACP/H/duSrGOYsMDAAAAABJRU5ErkJggg==\n", 198 | "text/plain": [ 199 | "" 200 | ] 201 | }, 202 | "metadata": {}, 203 | "output_type": "display_data" 204 | } 205 | ], 206 | "source": [ 207 | "# Import DeepExplain\n", 208 | "from deepexplain.tensorflow import DeepExplain\n", 209 | "from utils import plot, plt\n", 210 | "%matplotlib inline\n", 211 | "\n", 212 | "# Define the input to be tested\n", 213 | "test_idx = 13\n", 214 | "xi = test_x[[test_idx]]\n", 215 | "yi = test_y[test_idx] \n", 216 | "\n", 217 | "# Create a DeepExplain context. \n", 218 | "# IMPORTANT: the network must be created within this context.\n", 219 | "# In this example we have trained the network before, so we call `model(X)` to \n", 220 | "# recreate the network graph using the same weights that have been already trained.\n", 221 | "with DeepExplain(session=sess) as de:\n", 222 | " logits = model(X)\n", 223 | " # We run `explain()` several time to compare different attribution methods\n", 224 | " attributions = {\n", 225 | " # Gradient-based\n", 226 | " 'Saliency maps': de.explain('saliency', logits * yi, X, xi),\n", 227 | " 'Gradient * Input': de.explain('grad*input', logits * yi, X, xi),\n", 228 | " 'Integrated Gradients': de.explain('intgrad', logits * yi, X, xi),\n", 229 | " 'Epsilon-LRP': de.explain('elrp', logits * yi, X, xi),\n", 230 | " 'DeepLIFT (Rescale)': de.explain('deeplift', logits * yi, X, xi),\n", 231 | " #Perturbation-based\n", 232 | " '_Occlusion [1x1]': de.explain('occlusion', logits * yi, X, xi),\n", 233 | " '_Occlusion [3x3]': de.explain('occlusion', logits * yi, X, xi, window_shape=(3,))\n", 234 | " }\n", 235 | " print ('Done')\n", 236 | "\n", 237 | "# Plot attributions\n", 238 | "n_cols = len(attributions) + 1\n", 239 | "fig, axes = plt.subplots(nrows=1, ncols=n_cols, figsize=(3*n_cols, 3))\n", 240 | "plot(xi.reshape(28, 28), cmap='Greys', axis=axes[0]).set_title('Original')\n", 241 | "for i, method_name in enumerate(sorted(attributions.keys())):\n", 242 | " plot(attributions[method_name].reshape(28,28), xi = xi.reshape(28, 28), axis=axes[1+i]).set_title(method_name)" 243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": "Python 3", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3.0 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.5.2" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 0 267 | } -------------------------------------------------------------------------------- /examples/multiple_input_keras.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "## DeepExplain - Toy example of a model with multiple inputs\n", 10 | "### Keras with Functional API" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": { 17 | "collapsed": false 18 | }, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "Using TensorFlow backend.\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "from __future__ import absolute_import\n", 30 | "from __future__ import division\n", 31 | "from __future__ import print_function\n", 32 | "\n", 33 | "import tempfile, sys, os\n", 34 | "sys.path.insert(0, os.path.abspath('..'))\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "import keras\n", 38 | "from keras.datasets import mnist\n", 39 | "from keras.models import Sequential, Model\n", 40 | "from keras.layers import Dense, Dropout, Flatten, Activation, Input\n", 41 | "from keras.layers import Conv2D, MaxPooling2D, Concatenate\n", 42 | "from keras import backend as K\n", 43 | "\n", 44 | "# Import DeepExplain\n", 45 | "from deepexplain.tensorflow import DeepExplain" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": { 52 | "collapsed": false 53 | }, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "Output: [[ 0.98570454]]\n", 60 | "DeepExplain: running \"grad*input\" explanation method (2)\n", 61 | "Attributions:\n", 62 | " [array([[ 0.0108474]], dtype=float32), array([[ 0.04880605]], dtype=float32)]\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "# Define two simple numerical inputs\n", 68 | "_x1 = np.array([[1]])\n", 69 | "_x2 = np.array([[2]])\n", 70 | "\n", 71 | "# Define model\n", 72 | "# Here we implement y = sigmoid([w1*x1|w2*x2] * w3)\n", 73 | "def init_model():\n", 74 | " x1 = Input(shape=(1,))\n", 75 | " x2 = Input(shape=(1,))\n", 76 | "\n", 77 | " t1 = Dense(1)(x1)\n", 78 | " t2 = Dense(1)(x2)\n", 79 | " t3 = Concatenate()([t1, t2])\n", 80 | " t4 = Dense(1)(t3)\n", 81 | " y = Activation('sigmoid')(t4)\n", 82 | " \n", 83 | " model = Model(inputs=[x1, x2], outputs=y)\n", 84 | " model.compile(optimizer='rmsprop', loss='mse')\n", 85 | " return model\n", 86 | "\n", 87 | "\n", 88 | "model = init_model()\n", 89 | "# This is a toy example. The random weight initialization will do just fine.\n", 90 | "# model.fit(...)\n", 91 | "\n", 92 | "# Make sure the model works\n", 93 | "print (\"Output: \", model.predict(x=[_x1, _x2]))\n", 94 | "\n", 95 | "with DeepExplain(session=K.get_session()) as de: # <-- init DeepExplain context\n", 96 | " # Need to reconstruct the graph in DeepExplain context, using the same weights.\n", 97 | " input_tensors = model.inputs\n", 98 | " fModel = Model(inputs = input_tensors, outputs = model.outputs)\n", 99 | " target_tensor = fModel(input_tensors)\n", 100 | "\n", 101 | " attributions = de.explain('grad*input', target_tensor, input_tensors, [_x1, _x2])\n", 102 | " print (\"Attributions:\\n\", attributions)" 103 | ] 104 | } 105 | ], 106 | "metadata": { 107 | "kernelspec": { 108 | "display_name": "Python 3", 109 | "language": "python", 110 | "name": "python3" 111 | }, 112 | "language_info": { 113 | "codemirror_mode": { 114 | "name": "ipython", 115 | "version": 3 116 | }, 117 | "file_extension": ".py", 118 | "mimetype": "text/x-python", 119 | "name": "python", 120 | "nbconvert_exporter": "python", 121 | "pygments_lexer": "ipython3", 122 | "version": "3.5.2" 123 | } 124 | }, 125 | "nbformat": 4, 126 | "nbformat_minor": 0 127 | } 128 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | from skimage import feature, transform 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plot(data, xi=None, cmap='RdBu_r', axis=plt, percentile=100, dilation=3.0, alpha=0.8): 7 | dx, dy = 0.05, 0.05 8 | xx = np.arange(0.0, data.shape[1], dx) 9 | yy = np.arange(0.0, data.shape[0], dy) 10 | xmin, xmax, ymin, ymax = np.amin(xx), np.amax(xx), np.amin(yy), np.amax(yy) 11 | extent = xmin, xmax, ymin, ymax 12 | cmap_xi = plt.get_cmap('Greys_r') 13 | cmap_xi.set_bad(alpha=0) 14 | overlay = None 15 | if xi is not None: 16 | # Compute edges (to overlay to heatmaps later) 17 | xi_greyscale = xi if len(xi.shape) == 2 else np.mean(xi, axis=-1) 18 | in_image_upscaled = transform.rescale(xi_greyscale, dilation, mode='constant') 19 | edges = feature.canny(in_image_upscaled).astype(float) 20 | edges[edges < 0.5] = np.nan 21 | edges[:5, :] = np.nan 22 | edges[-5:, :] = np.nan 23 | edges[:, :5] = np.nan 24 | edges[:, -5:] = np.nan 25 | overlay = edges 26 | 27 | abs_max = np.percentile(np.abs(data), percentile) 28 | abs_min = abs_max 29 | 30 | if len(data.shape) == 3: 31 | data = np.mean(data, 2) 32 | axis.imshow(data, extent=extent, interpolation='none', cmap=cmap, vmin=-abs_min, vmax=abs_max) 33 | if overlay is not None: 34 | axis.imshow(overlay, extent=extent, interpolation='none', cmap=cmap_xi, alpha=alpha) 35 | axis.axis('off') 36 | return axis -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='deepexplain', 4 | version='0.3', 5 | description='Perturbation and gradient-based methods for Deep Network interpretability', 6 | url='https://github.com/marcoancona/DeepExplain', 7 | author='Marco Ancona (ETH Zurich)', 8 | author_email='marco.ancona@inf.ethz.ch', 9 | license='MIT', 10 | packages=find_packages(), 11 | install_requires=[ 12 | 'scipy', 13 | 'matplotlib', 14 | 'scikit-image' 15 | ], 16 | extras_require={ 17 | "tf": ["tensorflow>=1.0.0"], 18 | "tf_gpu": ["tensorflow-gpu>=1.0.0"], 19 | }, 20 | zip_safe=False, 21 | test_suite='nose.collector', 22 | tests_require=['nose'], 23 | ) --------------------------------------------------------------------------------