├── lrp.pyc ├── README.md ├── lrp.py └── .ipynb_checkpoints └── LRP_Example-checkpoint.ipynb /lrp.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dshieble/Tensorflow_Deep_Taylor_LRP/HEAD/lrp.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow_Deep_Taylor_LRP 2 | Layerwise Relevance Propagation with Deep Taylor Series in TensorFlow. 3 | 4 | You can use LRP to visualize the relative feature importances of the input to a neural network. 5 | 6 | ## How to Use 7 | 8 | ### Step 1: Construct your tensorflow graph 9 | ### Step 2: Make sure your prediction layer (output layer) is named "absolute_output" 10 | ### Step 3: Make sure your input layer (shaped as [num_batches, height, width, num_channels]) is named "absolute_input" 11 | ### Step 4: `relevance_heatmap = lrp.lrp(prediction*label, lowest_value, highest_value)` 12 | -------------------------------------------------------------------------------- /lrp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tensorflow.examples.tutorials.mnist import input_data 3 | import tensorflow as tf 4 | import numpy as np 5 | from tqdm import tqdm 6 | from tensorflow.python.ops import nn_ops, gen_nn_ops 7 | import matplotlib.pyplot as plt 8 | from scipy.stats.mstats import zscore 9 | 10 | #Helper Method for 11 | 12 | def lrp(F, lowest, highest, graph=None, return_flist=False): 13 | """ 14 | Accepts a final output, and propagates back from there to compute LRP over a tensorflow graph. 15 | Performs a Taylor Decomp at each layer to assess the relevances of each neuron at that layer 16 | """ 17 | #Assumptions: 18 | #all conv strides are [1,1,1,1] 19 | #all pool strides are [1,2,2,1] 20 | #all pool/conv padding is SAME 21 | #only reshaping that happens is after a pool layer (pool -> fc) or a conv layer (conv -> fc) 22 | F_list = [] 23 | traversed, graph, graph_dict, var_dict = get_traversed(graph=graph) 24 | for n in traversed: 25 | val_name = next(I for I in graph_dict[n].input if I in traversed).split("/read")[0] + ":0" 26 | X = graph.get_tensor_by_name(val_name) 27 | if graph_dict[n].op == "MatMul": 28 | weight_name = next(I for I in graph_dict[n].input if not I in traversed).split("/read")[0] + ":0" 29 | W = var_dict[weight_name] 30 | if "absolute_input" in graph_dict[n].input: 31 | F = fprop_first(F, W, X, lowest, highest) 32 | F_list.append(F) 33 | break 34 | else: 35 | F = fprop(F, W, X) 36 | F_list.append(F) 37 | elif graph_dict[n].op == "MaxPool" or graph_dict[n].op == "MaxPoolWithArgmax": 38 | F = fprop_pool(F, X) 39 | F_list.append(F) 40 | elif graph_dict[n].op == "Conv2D": 41 | weight_name = next(I for I in graph_dict[n].input if not I in traversed).split("/read")[0] + ":0" 42 | W = var_dict[weight_name] 43 | if "absolute_input" in graph_dict[n].input: 44 | F = fprop_conv_first(F, W, X, lowest, highest) 45 | F_list.append(F) 46 | break 47 | else: 48 | F = fprop_conv(F, W, X) 49 | F_list.append(F) 50 | if return_flist: 51 | return F_list 52 | else: 53 | return F 54 | 55 | def get_traversed(graph = None): 56 | #Get the graph and graph traversal 57 | graph = tf.get_default_graph() if graph is None else graph 58 | graph_dict = {node.name:node for node in graph.as_graph_def().node} 59 | var_dict = {v.name:v.value() for v in tf.get_collection(tf.GraphKeys.VARIABLES)} 60 | return traverse(graph_dict["absolute_output"], [], graph_dict), graph, graph_dict, var_dict 61 | 62 | 63 | def traverse(node, L, graph_dict): 64 | #Depth First Search the Network Graph 65 | L.append(node.name) 66 | if "absolute_input" in node.name: 67 | return L 68 | inputs = node.input 69 | for nodename in inputs: 70 | if not traverse(graph_dict[nodename], L, graph_dict) is None: 71 | return L 72 | return None 73 | 74 | def fprop_first(F, W, X, lowest, highest): 75 | #Propagate from last feedforward layer to input 76 | W,V,U = W,tf.maximum(0.0,W), tf.minimum(0.0,W) 77 | X,L,H = X, X*0+lowest, X*0+highest 78 | 79 | Z = tf.matmul(X, W)-tf.matmul(L, V)-tf.matmul(H, U)+1e-9 80 | S = F/Z 81 | F = X*tf.matmul(S,tf.transpose(W))-L*tf.matmul(S, tf.transpose(V))-H*tf.matmul(S,tf.transpose(U)) 82 | return F 83 | 84 | def fprop(F, W, X): 85 | #Propagate over feedforward layer 86 | V = tf.maximum(0.0, W) 87 | Z = tf.matmul(X, V)+1e-9; 88 | S = F/Z 89 | C = tf.matmul(S, tf.transpose(V)) 90 | F = X*C 91 | return F 92 | 93 | def fprop_conv_first(F, W, X, lowest, highest, strides=None, padding='SAME'): 94 | #Propagate from last conv layer to input 95 | strides = [1, 1, 1, 1] if strides is None else strides 96 | 97 | Wn = tf.minimum(0.0, W) 98 | Wp = tf.maximum(0.0, W) 99 | 100 | X, L, H = X, X*0+lowest, X*0+highest 101 | 102 | c = tf.nn.conv2d(X, W, strides, padding) 103 | cp = tf.nn.conv2d(H, Wp, strides, padding) 104 | cn = tf.nn.conv2d(L, Wn, strides, padding) 105 | Z = c - cp - cn + 1e-9 106 | S = F/Z 107 | 108 | g = nn_ops.conv2d_backprop_input(tf.shape(X), W, S, strides, padding) 109 | gp = nn_ops.conv2d_backprop_input(tf.shape(X), Wp, S, strides, padding) 110 | gn = nn_ops.conv2d_backprop_input(tf.shape(X), Wn, S, strides, padding) 111 | F = X*g - L*gp - H*gn 112 | return F 113 | 114 | def fprop_conv(F, W, X, strides=None, padding='SAME'): 115 | #Propagate over conv layer 116 | xshape = X.get_shape().as_list() 117 | fshape = F.get_shape().as_list() 118 | if len(xshape) != len(fshape): 119 | F = tf.reshape(F, (-1, xshape[1], xshape[2], fshape[-1]/(xshape[1]*xshape[2]))) 120 | strides = [1, 1, 1, 1] if strides is None else strides 121 | W = tf.maximum(0.0, W) 122 | 123 | Z = tf.nn.conv2d(X, W, strides, padding) + 1e-9 124 | S = F/Z 125 | C = nn_ops.conv2d_backprop_input(tf.shape(X), W, S, strides, padding) 126 | F = X*C 127 | return F 128 | 129 | def fprop_pool(F, X, strides=None, ksize=None, padding='SAME'): 130 | #Propagate over pool layer 131 | xshape = X.get_shape().as_list() 132 | fshape = F.get_shape().as_list() 133 | if len(xshape) != len(fshape): 134 | F = tf.reshape(F, (-1, int(np.ceil(xshape[1]/2.0)), 135 | int(np.ceil(xshape[2]/2.0)), xshape[3])) 136 | ksize = [1, 2, 2, 1] if ksize is None else ksize 137 | strides = [1, 2, 2, 1] if strides is None else strides 138 | 139 | Z = tf.nn.max_pool(X, strides=strides, ksize=ksize, padding=padding) + 1e-9 140 | S = F / Z 141 | C = gen_nn_ops._max_pool_grad(X, Z, S, ksize, strides, padding) 142 | F = X*C 143 | return F 144 | 145 | 146 | def get_lrp_im(sess, F, x, y, xval, yval): 147 | #Compute LRP over the values and labels 148 | im = [] 149 | for i in range(0, xval.shape[0]): 150 | im += list(F.eval(session=sess, feed_dict={x: xval[i:i+1], y: yval[i:i+1]})) 151 | return im 152 | 153 | def visualize(im_list, xval): 154 | #Visualize the LRPs 155 | for i in range(len(im_list[0])): 156 | plt.figure() 157 | plt.subplot(1,1+len(im_list),1) 158 | plt.title("Image") 159 | plt.imshow(xval[i]) 160 | 161 | for j in range(len(im_list)): 162 | plt.subplot(1,1+len(im_list),2+j) 163 | plt.title("LRP for network {}".format(j)) 164 | I = np.mean(np.maximum(im_list[j][i], 0), -1) 165 | I = np.minimum(I, np.percentile(I, 99)) 166 | I = I/np.max(I) 167 | print "np.linalg.norm(I)", np.linalg.norm(I) 168 | plt.imshow(I, cmap="gray") 169 | 170 | plt.show() 171 | return im_list 172 | 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/LRP_Example-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stdout", 12 | "output_type": "stream", 13 | "text": [ 14 | "Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz\n", 15 | "Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz\n", 16 | "Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz\n", 17 | "Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz\n" 18 | ] 19 | } 20 | ], 21 | "source": [ 22 | "\"\"\"\n", 23 | " This is a toy example that demonstrates how we can use LRP on a convolutional neural network trained on mnist\n", 24 | "\"\"\"\n", 25 | "\n", 26 | "%matplotlib inline\n", 27 | "%load_ext autoreload\n", 28 | "\n", 29 | "%autoreload 2\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "from tensorflow.examples.tutorials.mnist import input_data\n", 32 | "import tensorflow as tf\n", 33 | "import numpy as np\n", 34 | "from tqdm import tqdm\n", 35 | "import lrp\n", 36 | "import pandas as pd\n", 37 | "from pylab import rcParams\n", 38 | "rcParams['figure.figsize'] = 8, 10\n", 39 | "mnist = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data', one_hot=True)\n" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "collapsed": false, 47 | "scrolled": true 48 | }, 49 | "outputs": [ 50 | { 51 | "name": "stderr", 52 | "output_type": "stream", 53 | "text": [ 54 | " 20%|██ | 1/5 [00:57<03:49, 57.36s/it]" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "#Build the Convolutional Neural Network\n", 60 | "batch_size = 50\n", 61 | "total_batch = int(mnist.train.num_examples/batch_size)\n", 62 | "num_epochs = 5\n", 63 | "\n", 64 | "tf.reset_default_graph()\n", 65 | "x = tf.placeholder(tf.float32, [None, 784])\n", 66 | "y_ = tf.placeholder(tf.float32, [None, 10], name=\"truth\")\n", 67 | "\n", 68 | "#Set the weights for the network\n", 69 | "xavier = tf.contrib.layers.xavier_initializer_conv2d() \n", 70 | "conv1_weights = tf.get_variable(name=\"c1\", initializer=xavier, shape=[5, 5, 1, 10])\n", 71 | "conv1_biases = tf.Variable(tf.zeros([10]))\n", 72 | "conv2_weights = tf.get_variable(name=\"c2\", initializer=xavier, shape=[5, 5, 10, 25])\n", 73 | "conv2_biases = tf.Variable(tf.zeros([25]))\n", 74 | "conv3_weights = tf.get_variable(name=\"c3\", initializer=xavier, shape=[4, 4, 25, 100])\n", 75 | "conv3_biases = tf.Variable(tf.zeros([100]))\n", 76 | "fc1_weights = tf.Variable(tf.truncated_normal([4 * 4 * 100, 10], stddev=0.1))\n", 77 | "fc1_biases = tf.Variable(tf.zeros([10]))\n", 78 | "\n", 79 | "#Stack the Layers\n", 80 | "reshaped_input = tf.reshape(x, [-1, 28, 28, 1], name=\"absolute_input\")\n", 81 | "#layer 1\n", 82 | "conv1 = tf.nn.conv2d(reshaped_input, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')\n", 83 | "relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))\n", 84 | "pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n", 85 | "#layer 2\n", 86 | "conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')\n", 87 | "relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))\n", 88 | "pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n", 89 | "#layer 3\n", 90 | "conv3 = tf.nn.conv2d(pool2, conv3_weights, strides=[1, 1, 1, 1], padding='SAME')\n", 91 | "relu3 = tf.nn.relu(tf.nn.bias_add(conv3, conv3_biases))\n", 92 | "pool3 = tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')\n", 93 | "#layer 4 \n", 94 | "pool_shape = pool3.get_shape().as_list()\n", 95 | "reshaped = tf.reshape(pool3, [-1, pool_shape[1] * pool_shape[2] * pool_shape[3]])\n", 96 | "y = tf.add(tf.matmul(reshaped, fc1_weights), fc1_biases, name=\"absolute_output\")\n", 97 | "\n", 98 | "# Define loss and optimizer\n", 99 | "cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_))\n", 100 | "train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cross_entropy)\n", 101 | "\n", 102 | "# Train the model\n", 103 | "sess = tf.InteractiveSession()\n", 104 | "tf.initialize_all_variables().run()\n", 105 | "for i in tqdm(range(num_epochs)):\n", 106 | " for i in range(total_batch):\n", 107 | " batch_x, batch_y = mnist.train.next_batch(batch_size)\n", 108 | " sess.run(train_step, feed_dict={x: batch_x, y_: batch_y})\n", 109 | "\n", 110 | "# Test trained model\n", 111 | "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n", 112 | "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", 113 | "test_acc = []\n", 114 | "train_acc = []\n", 115 | "for i in tqdm(range(total_batch)):\n", 116 | " batch_x, batch_y = mnist.test.next_batch(batch_size)\n", 117 | " test_acc.append(sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y}))\n", 118 | " batch_x, batch_y = mnist.train.next_batch(batch_size)\n", 119 | " train_acc.append(sess.run(accuracy, feed_dict={x: batch_x, y_: batch_y}))\n", 120 | "print np.mean(train_acc), np.mean(test_acc)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": { 127 | "collapsed": false, 128 | "scrolled": false 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "#Run LRP with Deep Taylor Decomposition on the output of the network\n", 133 | "F_list = lrp.lrp(y*y_, 0, 1, return_flist=True)\n", 134 | "im_list = lrp.get_lrp_im(sess, F_list[-1], reshaped_input, y_, np.reshape(batch_x, (batch_size, 28,28, 1)), batch_y)\n", 135 | "#Visualize the produced heatmaps\n", 136 | "for b, im in zip(batch_x, im_list):\n", 137 | " plt.figure()\n", 138 | " plt.subplot(1,2,1)\n", 139 | " plt.imshow(np.reshape(b, (28,28)))\n", 140 | " plt.subplot(1,2,2)\n", 141 | " plt.imshow(np.reshape(im, (28,28)), cmap=\"gray\")\n", 142 | " plt.show()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": { 149 | "collapsed": true 150 | }, 151 | "outputs": [], 152 | "source": [] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 2", 158 | "language": "python", 159 | "name": "python2" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 2 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython2", 171 | "version": "2.7.11" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 0 176 | } 177 | --------------------------------------------------------------------------------