├── .ipynb_checkpoints
├── LRP Implementation-checkpoint.ipynb
├── LRP old working-checkpoint.ipynb
├── SHAP & LIME (Issues & Observations)-checkpoint.ipynb
├── SHAP & LIME Implementation dummy-checkpoint.ipynb
└── SHAP & LIME Implementation-checkpoint.ipynb
├── README.md
├── images
├── LSTM Models.png
├── bidi.PNG
├── encdects_16_ts.PNG
├── encdects_1_ts.PNG
├── stackedlstmts_16_ts.PNG
└── vanilla lstmts_1_ts.PNG
├── models
├── bidi_model.h5
├── enc_dec_model_ts1.h5
├── enc_dec_model_ts16.h5
├── stacked_lstm_model_ts16_1.h5
├── stacked_lstm_model_ts16_2.h5
└── vanilla_lstm_model_ts1.h5
├── notebooks
├── .ipynb_checkpoints
│ ├── LRP Implementation-checkpoint.ipynb
│ └── SHAP & LIME Implementation-checkpoint.ipynb
├── LIME_and_SHAP_Implementation.ipynb
└── LRP_Implementation.ipynb
└── result plots
├── Bar plot - LRP relevance score for 1 instance.PNG
├── LRP relevance score for all instances.PNG
├── comp table new.PNG
├── shap_bidi_lstm.PNG
├── shap_enc_dec_ts1.png
├── shap_stacked_ lstm_ts16_1.PNG
└── shap_vanilla_ts1.png
/.ipynb_checkpoints/SHAP & LIME (Issues & Observations)-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "### Jithin Sasikumar\n",
8 | "\n",
 9 | "This notebook addresses the major issues when implementing **feature attribution methods (SHAP and LIME)** for an LSTM time-series model. It contains data pre-processing, model prediction, and the SHAP implementation for our model, along with the issues and my observations. I have also added a comparison of our model with two other models trained on different datasets for different problems. "
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "import shap\n",
22 | "import tensorflow as tf"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "metadata": {},
28 | "source": [
29 | "## Loading train and test data for pre-processing\n",
30 | "\n",
31 | "1. The data has **25 features**, each representing sensor information from a specific channel of the respective satellite. \n",
32 | "2. The data is one-hot encoded."
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 3,
38 | "metadata": {},
39 | "outputs": [
40 | {
41 | "name": "stdout",
42 | "output_type": "stream",
43 | "text": [
44 | "(2880, 25)\n"
45 | ]
46 | },
47 | {
48 | "data": {
49 | "text/html": [
50 | "
\n",
51 | "\n",
64 | "
\n",
65 | " \n",
66 | " \n",
67 | " | \n",
68 | " 0 | \n",
69 | " 1 | \n",
70 | " 2 | \n",
71 | " 3 | \n",
72 | " 4 | \n",
73 | " 5 | \n",
74 | " 6 | \n",
75 | " 7 | \n",
76 | " 8 | \n",
77 | " 9 | \n",
78 | " ... | \n",
79 | " 15 | \n",
80 | " 16 | \n",
81 | " 17 | \n",
82 | " 18 | \n",
83 | " 19 | \n",
84 | " 20 | \n",
85 | " 21 | \n",
86 | " 22 | \n",
87 | " 23 | \n",
88 | " 24 | \n",
89 | "
\n",
90 | " \n",
91 | " \n",
92 | " \n",
93 | " 0 | \n",
94 | " 0.999 | \n",
95 | " 0.0 | \n",
96 | " 0.0 | \n",
97 | " 0.0 | \n",
98 | " 0.0 | \n",
99 | " 0.0 | \n",
100 | " 0.0 | \n",
101 | " 0.0 | \n",
102 | " 0.0 | \n",
103 | " 0.0 | \n",
104 | " ... | \n",
105 | " 0.0 | \n",
106 | " 0.0 | \n",
107 | " 0.0 | \n",
108 | " 0.0 | \n",
109 | " 0.0 | \n",
110 | " 0.0 | \n",
111 | " 0.0 | \n",
112 | " 0.0 | \n",
113 | " 0.0 | \n",
114 | " 0.0 | \n",
115 | "
\n",
116 | " \n",
117 | " 1 | \n",
118 | " 0.999 | \n",
119 | " 0.0 | \n",
120 | " 0.0 | \n",
121 | " 0.0 | \n",
122 | " 0.0 | \n",
123 | " 0.0 | \n",
124 | " 0.0 | \n",
125 | " 0.0 | \n",
126 | " 0.0 | \n",
127 | " 0.0 | \n",
128 | " ... | \n",
129 | " 0.0 | \n",
130 | " 0.0 | \n",
131 | " 0.0 | \n",
132 | " 0.0 | \n",
133 | " 0.0 | \n",
134 | " 0.0 | \n",
135 | " 0.0 | \n",
136 | " 0.0 | \n",
137 | " 0.0 | \n",
138 | " 0.0 | \n",
139 | "
\n",
140 | " \n",
141 | " 2 | \n",
142 | " 0.999 | \n",
143 | " 0.0 | \n",
144 | " 0.0 | \n",
145 | " 0.0 | \n",
146 | " 0.0 | \n",
147 | " 0.0 | \n",
148 | " 0.0 | \n",
149 | " 0.0 | \n",
150 | " 0.0 | \n",
151 | " 0.0 | \n",
152 | " ... | \n",
153 | " 0.0 | \n",
154 | " 0.0 | \n",
155 | " 0.0 | \n",
156 | " 0.0 | \n",
157 | " 0.0 | \n",
158 | " 0.0 | \n",
159 | " 0.0 | \n",
160 | " 0.0 | \n",
161 | " 0.0 | \n",
162 | " 0.0 | \n",
163 | "
\n",
164 | " \n",
165 | " 3 | \n",
166 | " 0.999 | \n",
167 | " 0.0 | \n",
168 | " 0.0 | \n",
169 | " 0.0 | \n",
170 | " 0.0 | \n",
171 | " 0.0 | \n",
172 | " 0.0 | \n",
173 | " 0.0 | \n",
174 | " 0.0 | \n",
175 | " 0.0 | \n",
176 | " ... | \n",
177 | " 0.0 | \n",
178 | " 0.0 | \n",
179 | " 0.0 | \n",
180 | " 0.0 | \n",
181 | " 0.0 | \n",
182 | " 0.0 | \n",
183 | " 0.0 | \n",
184 | " 0.0 | \n",
185 | " 0.0 | \n",
186 | " 0.0 | \n",
187 | "
\n",
188 | " \n",
189 | " 4 | \n",
190 | " 0.999 | \n",
191 | " 0.0 | \n",
192 | " 0.0 | \n",
193 | " 0.0 | \n",
194 | " 0.0 | \n",
195 | " 0.0 | \n",
196 | " 0.0 | \n",
197 | " 0.0 | \n",
198 | " 0.0 | \n",
199 | " 0.0 | \n",
200 | " ... | \n",
201 | " 0.0 | \n",
202 | " 0.0 | \n",
203 | " 0.0 | \n",
204 | " 0.0 | \n",
205 | " 0.0 | \n",
206 | " 0.0 | \n",
207 | " 0.0 | \n",
208 | " 0.0 | \n",
209 | " 0.0 | \n",
210 | " 0.0 | \n",
211 | "
\n",
212 | " \n",
213 | " ... | \n",
214 | " ... | \n",
215 | " ... | \n",
216 | " ... | \n",
217 | " ... | \n",
218 | " ... | \n",
219 | " ... | \n",
220 | " ... | \n",
221 | " ... | \n",
222 | " ... | \n",
223 | " ... | \n",
224 | " ... | \n",
225 | " ... | \n",
226 | " ... | \n",
227 | " ... | \n",
228 | " ... | \n",
229 | " ... | \n",
230 | " ... | \n",
231 | " ... | \n",
232 | " ... | \n",
233 | " ... | \n",
234 | " ... | \n",
235 | "
\n",
236 | " \n",
237 | " 2875 | \n",
238 | " 0.999 | \n",
239 | " 0.0 | \n",
240 | " 0.0 | \n",
241 | " 0.0 | \n",
242 | " 0.0 | \n",
243 | " 0.0 | \n",
244 | " 0.0 | \n",
245 | " 0.0 | \n",
246 | " 0.0 | \n",
247 | " 0.0 | \n",
248 | " ... | \n",
249 | " 0.0 | \n",
250 | " 0.0 | \n",
251 | " 0.0 | \n",
252 | " 0.0 | \n",
253 | " 0.0 | \n",
254 | " 0.0 | \n",
255 | " 0.0 | \n",
256 | " 0.0 | \n",
257 | " 0.0 | \n",
258 | " 0.0 | \n",
259 | "
\n",
260 | " \n",
261 | " 2876 | \n",
262 | " 0.999 | \n",
263 | " 0.0 | \n",
264 | " 0.0 | \n",
265 | " 0.0 | \n",
266 | " 0.0 | \n",
267 | " 0.0 | \n",
268 | " 0.0 | \n",
269 | " 0.0 | \n",
270 | " 0.0 | \n",
271 | " 0.0 | \n",
272 | " ... | \n",
273 | " 0.0 | \n",
274 | " 0.0 | \n",
275 | " 0.0 | \n",
276 | " 0.0 | \n",
277 | " 0.0 | \n",
278 | " 0.0 | \n",
279 | " 0.0 | \n",
280 | " 0.0 | \n",
281 | " 0.0 | \n",
282 | " 0.0 | \n",
283 | "
\n",
284 | " \n",
285 | " 2877 | \n",
286 | " 0.999 | \n",
287 | " 0.0 | \n",
288 | " 0.0 | \n",
289 | " 0.0 | \n",
290 | " 0.0 | \n",
291 | " 0.0 | \n",
292 | " 0.0 | \n",
293 | " 0.0 | \n",
294 | " 0.0 | \n",
295 | " 0.0 | \n",
296 | " ... | \n",
297 | " 0.0 | \n",
298 | " 0.0 | \n",
299 | " 0.0 | \n",
300 | " 0.0 | \n",
301 | " 0.0 | \n",
302 | " 0.0 | \n",
303 | " 0.0 | \n",
304 | " 0.0 | \n",
305 | " 0.0 | \n",
306 | " 0.0 | \n",
307 | "
\n",
308 | " \n",
309 | " 2878 | \n",
310 | " 0.999 | \n",
311 | " 0.0 | \n",
312 | " 0.0 | \n",
313 | " 0.0 | \n",
314 | " 0.0 | \n",
315 | " 0.0 | \n",
316 | " 0.0 | \n",
317 | " 0.0 | \n",
318 | " 0.0 | \n",
319 | " 0.0 | \n",
320 | " ... | \n",
321 | " 0.0 | \n",
322 | " 0.0 | \n",
323 | " 0.0 | \n",
324 | " 0.0 | \n",
325 | " 0.0 | \n",
326 | " 0.0 | \n",
327 | " 0.0 | \n",
328 | " 0.0 | \n",
329 | " 0.0 | \n",
330 | " 0.0 | \n",
331 | "
\n",
332 | " \n",
333 | " 2879 | \n",
334 | " 0.999 | \n",
335 | " 0.0 | \n",
336 | " 0.0 | \n",
337 | " 0.0 | \n",
338 | " 0.0 | \n",
339 | " 0.0 | \n",
340 | " 0.0 | \n",
341 | " 0.0 | \n",
342 | " 0.0 | \n",
343 | " 0.0 | \n",
344 | " ... | \n",
345 | " 0.0 | \n",
346 | " 0.0 | \n",
347 | " 0.0 | \n",
348 | " 0.0 | \n",
349 | " 0.0 | \n",
350 | " 0.0 | \n",
351 | " 0.0 | \n",
352 | " 0.0 | \n",
353 | " 0.0 | \n",
354 | " 0.0 | \n",
355 | "
\n",
356 | " \n",
357 | "
\n",
358 | "
2880 rows × 25 columns
\n",
359 | "
"
360 | ],
361 | "text/plain": [
362 | " 0 1 2 3 4 5 6 7 8 9 ... 15 16 17 \\\n",
363 | "0 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
364 | "1 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
365 | "2 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
366 | "3 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
367 | "4 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
368 | "... ... ... ... ... ... ... ... ... ... ... ... ... ... ... \n",
369 | "2875 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
370 | "2876 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
371 | "2877 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
372 | "2878 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
373 | "2879 0.999 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
374 | "\n",
375 | " 18 19 20 21 22 23 24 \n",
376 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
377 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
378 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
379 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
380 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
381 | "... ... ... ... ... ... ... ... \n",
382 | "2875 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
383 | "2876 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
384 | "2877 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
385 | "2878 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
386 | "2879 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
387 | "\n",
388 | "[2880 rows x 25 columns]"
389 | ]
390 | },
391 | "execution_count": 3,
392 | "metadata": {},
393 | "output_type": "execute_result"
394 | }
395 | ],
396 | "source": [
397 | "#Train data\n",
398 | "train_arr = np.load('D:/DLR/telemanom/data/train/A-1.npy')\n",
399 | "data_train = pd.DataFrame(train_arr)\n",
400 | "print(data_train.shape)\n",
401 | "\n",
402 | "data_train"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": 4,
408 | "metadata": {},
409 | "outputs": [
410 | {
411 | "name": "stdout",
412 | "output_type": "stream",
413 | "text": [
414 | "(8640, 25)\n"
415 | ]
416 | },
417 | {
418 | "data": {
419 | "text/html": [
420 | "\n",
421 | "\n",
434 | "
\n",
435 | " \n",
436 | " \n",
437 | " | \n",
438 | " 0 | \n",
439 | " 1 | \n",
440 | " 2 | \n",
441 | " 3 | \n",
442 | " 4 | \n",
443 | " 5 | \n",
444 | " 6 | \n",
445 | " 7 | \n",
446 | " 8 | \n",
447 | " 9 | \n",
448 | " ... | \n",
449 | " 15 | \n",
450 | " 16 | \n",
451 | " 17 | \n",
452 | " 18 | \n",
453 | " 19 | \n",
454 | " 20 | \n",
455 | " 21 | \n",
456 | " 22 | \n",
457 | " 23 | \n",
458 | " 24 | \n",
459 | "
\n",
460 | " \n",
461 | " \n",
462 | " \n",
463 | " 0 | \n",
464 | " 1.0 | \n",
465 | " 0.0 | \n",
466 | " 0.0 | \n",
467 | " 0.0 | \n",
468 | " 0.0 | \n",
469 | " 0.0 | \n",
470 | " 0.0 | \n",
471 | " 0.0 | \n",
472 | " 0.0 | \n",
473 | " 0.0 | \n",
474 | " ... | \n",
475 | " 0.0 | \n",
476 | " 0.0 | \n",
477 | " 0.0 | \n",
478 | " 0.0 | \n",
479 | " 0.0 | \n",
480 | " 0.0 | \n",
481 | " 0.0 | \n",
482 | " 0.0 | \n",
483 | " 0.0 | \n",
484 | " 0.0 | \n",
485 | "
\n",
486 | " \n",
487 | " 1 | \n",
488 | " 1.0 | \n",
489 | " 0.0 | \n",
490 | " 0.0 | \n",
491 | " 0.0 | \n",
492 | " 0.0 | \n",
493 | " 0.0 | \n",
494 | " 0.0 | \n",
495 | " 0.0 | \n",
496 | " 0.0 | \n",
497 | " 0.0 | \n",
498 | " ... | \n",
499 | " 0.0 | \n",
500 | " 0.0 | \n",
501 | " 0.0 | \n",
502 | " 0.0 | \n",
503 | " 0.0 | \n",
504 | " 0.0 | \n",
505 | " 0.0 | \n",
506 | " 0.0 | \n",
507 | " 0.0 | \n",
508 | " 0.0 | \n",
509 | "
\n",
510 | " \n",
511 | " 2 | \n",
512 | " 1.0 | \n",
513 | " 0.0 | \n",
514 | " 0.0 | \n",
515 | " 0.0 | \n",
516 | " 0.0 | \n",
517 | " 0.0 | \n",
518 | " 0.0 | \n",
519 | " 0.0 | \n",
520 | " 0.0 | \n",
521 | " 0.0 | \n",
522 | " ... | \n",
523 | " 0.0 | \n",
524 | " 0.0 | \n",
525 | " 0.0 | \n",
526 | " 0.0 | \n",
527 | " 0.0 | \n",
528 | " 0.0 | \n",
529 | " 0.0 | \n",
530 | " 0.0 | \n",
531 | " 0.0 | \n",
532 | " 0.0 | \n",
533 | "
\n",
534 | " \n",
535 | " 3 | \n",
536 | " 1.0 | \n",
537 | " 0.0 | \n",
538 | " 0.0 | \n",
539 | " 0.0 | \n",
540 | " 0.0 | \n",
541 | " 0.0 | \n",
542 | " 0.0 | \n",
543 | " 0.0 | \n",
544 | " 0.0 | \n",
545 | " 0.0 | \n",
546 | " ... | \n",
547 | " 0.0 | \n",
548 | " 0.0 | \n",
549 | " 0.0 | \n",
550 | " 0.0 | \n",
551 | " 0.0 | \n",
552 | " 0.0 | \n",
553 | " 0.0 | \n",
554 | " 0.0 | \n",
555 | " 0.0 | \n",
556 | " 0.0 | \n",
557 | "
\n",
558 | " \n",
559 | " 4 | \n",
560 | " 1.0 | \n",
561 | " 0.0 | \n",
562 | " 0.0 | \n",
563 | " 0.0 | \n",
564 | " 0.0 | \n",
565 | " 0.0 | \n",
566 | " 0.0 | \n",
567 | " 0.0 | \n",
568 | " 0.0 | \n",
569 | " 0.0 | \n",
570 | " ... | \n",
571 | " 0.0 | \n",
572 | " 0.0 | \n",
573 | " 0.0 | \n",
574 | " 0.0 | \n",
575 | " 0.0 | \n",
576 | " 0.0 | \n",
577 | " 1.0 | \n",
578 | " 1.0 | \n",
579 | " 0.0 | \n",
580 | " 0.0 | \n",
581 | "
\n",
582 | " \n",
583 | " ... | \n",
584 | " ... | \n",
585 | " ... | \n",
586 | " ... | \n",
587 | " ... | \n",
588 | " ... | \n",
589 | " ... | \n",
590 | " ... | \n",
591 | " ... | \n",
592 | " ... | \n",
593 | " ... | \n",
594 | " ... | \n",
595 | " ... | \n",
596 | " ... | \n",
597 | " ... | \n",
598 | " ... | \n",
599 | " ... | \n",
600 | " ... | \n",
601 | " ... | \n",
602 | " ... | \n",
603 | " ... | \n",
604 | " ... | \n",
605 | "
\n",
606 | " \n",
607 | " 8635 | \n",
608 | " 1.0 | \n",
609 | " 0.0 | \n",
610 | " 0.0 | \n",
611 | " 0.0 | \n",
612 | " 0.0 | \n",
613 | " 0.0 | \n",
614 | " 0.0 | \n",
615 | " 0.0 | \n",
616 | " 0.0 | \n",
617 | " 0.0 | \n",
618 | " ... | \n",
619 | " 0.0 | \n",
620 | " 0.0 | \n",
621 | " 0.0 | \n",
622 | " 0.0 | \n",
623 | " 0.0 | \n",
624 | " 0.0 | \n",
625 | " 0.0 | \n",
626 | " 0.0 | \n",
627 | " 0.0 | \n",
628 | " 0.0 | \n",
629 | "
\n",
630 | " \n",
631 | " 8636 | \n",
632 | " 1.0 | \n",
633 | " 0.0 | \n",
634 | " 0.0 | \n",
635 | " 0.0 | \n",
636 | " 0.0 | \n",
637 | " 0.0 | \n",
638 | " 0.0 | \n",
639 | " 0.0 | \n",
640 | " 0.0 | \n",
641 | " 0.0 | \n",
642 | " ... | \n",
643 | " 0.0 | \n",
644 | " 0.0 | \n",
645 | " 0.0 | \n",
646 | " 0.0 | \n",
647 | " 0.0 | \n",
648 | " 0.0 | \n",
649 | " 0.0 | \n",
650 | " 0.0 | \n",
651 | " 0.0 | \n",
652 | " 0.0 | \n",
653 | "
\n",
654 | " \n",
655 | " 8637 | \n",
656 | " 1.0 | \n",
657 | " 0.0 | \n",
658 | " 0.0 | \n",
659 | " 0.0 | \n",
660 | " 0.0 | \n",
661 | " 0.0 | \n",
662 | " 0.0 | \n",
663 | " 0.0 | \n",
664 | " 0.0 | \n",
665 | " 0.0 | \n",
666 | " ... | \n",
667 | " 0.0 | \n",
668 | " 0.0 | \n",
669 | " 0.0 | \n",
670 | " 0.0 | \n",
671 | " 0.0 | \n",
672 | " 0.0 | \n",
673 | " 0.0 | \n",
674 | " 0.0 | \n",
675 | " 0.0 | \n",
676 | " 0.0 | \n",
677 | "
\n",
678 | " \n",
679 | " 8638 | \n",
680 | " 1.0 | \n",
681 | " 0.0 | \n",
682 | " 0.0 | \n",
683 | " 0.0 | \n",
684 | " 0.0 | \n",
685 | " 0.0 | \n",
686 | " 0.0 | \n",
687 | " 0.0 | \n",
688 | " 0.0 | \n",
689 | " 0.0 | \n",
690 | " ... | \n",
691 | " 0.0 | \n",
692 | " 0.0 | \n",
693 | " 0.0 | \n",
694 | " 0.0 | \n",
695 | " 0.0 | \n",
696 | " 0.0 | \n",
697 | " 0.0 | \n",
698 | " 0.0 | \n",
699 | " 0.0 | \n",
700 | " 0.0 | \n",
701 | "
\n",
702 | " \n",
703 | " 8639 | \n",
704 | " 1.0 | \n",
705 | " 0.0 | \n",
706 | " 0.0 | \n",
707 | " 0.0 | \n",
708 | " 0.0 | \n",
709 | " 0.0 | \n",
710 | " 0.0 | \n",
711 | " 0.0 | \n",
712 | " 0.0 | \n",
713 | " 0.0 | \n",
714 | " ... | \n",
715 | " 0.0 | \n",
716 | " 0.0 | \n",
717 | " 0.0 | \n",
718 | " 0.0 | \n",
719 | " 0.0 | \n",
720 | " 0.0 | \n",
721 | " 0.0 | \n",
722 | " 0.0 | \n",
723 | " 0.0 | \n",
724 | " 0.0 | \n",
725 | "
\n",
726 | " \n",
727 | "
\n",
728 | "
8640 rows × 25 columns
\n",
729 | "
"
730 | ],
731 | "text/plain": [
732 | " 0 1 2 3 4 5 6 7 8 9 ... 15 16 17 \\\n",
733 | "0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
734 | "1 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
735 | "2 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
736 | "3 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
737 | "4 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
738 | "... ... ... ... ... ... ... ... ... ... ... ... ... ... ... \n",
739 | "8635 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
740 | "8636 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
741 | "8637 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
742 | "8638 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
743 | "8639 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 \n",
744 | "\n",
745 | " 18 19 20 21 22 23 24 \n",
746 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
747 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
748 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
749 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
750 | "4 0.0 0.0 0.0 1.0 1.0 0.0 0.0 \n",
751 | "... ... ... ... ... ... ... ... \n",
752 | "8635 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
753 | "8636 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
754 | "8637 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
755 | "8638 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
756 | "8639 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
757 | "\n",
758 | "[8640 rows x 25 columns]"
759 | ]
760 | },
761 | "execution_count": 4,
762 | "metadata": {},
763 | "output_type": "execute_result"
764 | }
765 | ],
766 | "source": [
767 | "#Test data\n",
768 | "test_arr = np.load('D:/DLR/telemanom/data/test/A-1.npy')\n",
769 | "test = pd.DataFrame(test_arr)\n",
770 | "print(test.shape)\n",
771 | "\n",
772 | "test"
773 | ]
774 | },
775 | {
776 | "cell_type": "code",
777 | "execution_count": 34,
778 | "metadata": {},
779 | "outputs": [
780 | {
781 | "data": {
782 | "text/html": [
783 | "\n",
784 | "\n",
797 | "
\n",
798 | " \n",
799 | " \n",
800 | " | \n",
801 | " 0 | \n",
802 | "
\n",
803 | " \n",
804 | " \n",
805 | " \n",
806 | " 0 | \n",
807 | " 0.976049 | \n",
808 | "
\n",
809 | " \n",
810 | " 1 | \n",
811 | " 0.976191 | \n",
812 | "
\n",
813 | " \n",
814 | " 2 | \n",
815 | " 0.976298 | \n",
816 | "
\n",
817 | " \n",
818 | " 3 | \n",
819 | " 0.976378 | \n",
820 | "
\n",
821 | " \n",
822 | " 4 | \n",
823 | " 0.976436 | \n",
824 | "
\n",
825 | " \n",
826 | " ... | \n",
827 | " ... | \n",
828 | "
\n",
829 | " \n",
830 | " 8375 | \n",
831 | " 0.976578 | \n",
832 | "
\n",
833 | " \n",
834 | " 8376 | \n",
835 | " 0.976569 | \n",
836 | "
\n",
837 | " \n",
838 | " 8377 | \n",
839 | " 0.976559 | \n",
840 | "
\n",
841 | " \n",
842 | " 8378 | \n",
843 | " 0.976551 | \n",
844 | "
\n",
845 | " \n",
846 | " 8379 | \n",
847 | " 0.976543 | \n",
848 | "
\n",
849 | " \n",
850 | "
\n",
851 | "
8380 rows × 1 columns
\n",
852 | "
"
853 | ],
854 | "text/plain": [
855 | " 0\n",
856 | "0 0.976049\n",
857 | "1 0.976191\n",
858 | "2 0.976298\n",
859 | "3 0.976378\n",
860 | "4 0.976436\n",
861 | "... ...\n",
862 | "8375 0.976578\n",
863 | "8376 0.976569\n",
864 | "8377 0.976559\n",
865 | "8378 0.976551\n",
866 | "8379 0.976543\n",
867 | "\n",
868 | "[8380 rows x 1 columns]"
869 | ]
870 | },
871 | "execution_count": 34,
872 | "metadata": {},
873 | "output_type": "execute_result"
874 | }
875 | ],
876 | "source": [
877 | "y_hat = np.load('D:/DLR/telemanom/data/2018-05-19_15.00.10/y_hat/A-1.npy')\n",
878 | "y_hat_df = pd.DataFrame(y_hat)\n",
879 | "y_hat_df"
880 | ]
881 | },
882 | {
883 | "cell_type": "markdown",
884 | "metadata": {},
885 | "source": [
886 | "## Loading pre-trained model (NASA)"
887 | ]
888 | },
889 | {
890 | "cell_type": "code",
891 | "execution_count": 6,
892 | "metadata": {},
893 | "outputs": [
894 | {
895 | "name": "stdout",
896 | "output_type": "stream",
897 | "text": [
898 | "2.0.0\n"
899 | ]
900 | }
901 | ],
902 | "source": [
903 | "print(tf.__version__)\n",
904 | "model = tf.keras.models.load_model('D:/DLR/telemanom/data/2018-05-19_15.00.10/models/A-1.h5')"
905 | ]
906 | },
907 | {
908 | "cell_type": "code",
909 | "execution_count": 8,
910 | "metadata": {},
911 | "outputs": [
912 | {
913 | "name": "stdout",
914 | "output_type": "stream",
915 | "text": [
916 | "\n",
917 | "Model: \"sequential\"\n",
918 | "_________________________________________________________________\n",
919 | "Layer (type) Output Shape Param # \n",
920 | "=================================================================\n",
921 | "lstm_5 (LSTM) (None, None, 80) 33920 \n",
922 | "_________________________________________________________________\n",
923 | "dropout_5 (Dropout) (None, None, 80) 0 \n",
924 | "_________________________________________________________________\n",
925 | "lstm_6 (LSTM) (None, 80) 51520 \n",
926 | "_________________________________________________________________\n",
927 | "dropout_6 (Dropout) (None, 80) 0 \n",
928 | "_________________________________________________________________\n",
929 | "dense_3 (Dense) (None, 10) 810 \n",
930 | "_________________________________________________________________\n",
931 | "activation_3 (Activation) (None, 10) 0 \n",
932 | "=================================================================\n",
933 | "Total params: 86,250\n",
934 | "Trainable params: 86,250\n",
935 | "Non-trainable params: 0\n",
936 | "_________________________________________________________________\n"
937 | ]
938 | }
939 | ],
940 | "source": [
941 | "print(model)\n",
 942 | "#Displaying model configurations and properties\n",
943 | "model.summary()"
944 | ]
945 | },
946 | {
947 | "cell_type": "markdown",
948 | "metadata": {},
949 | "source": [
950 | "## Data Pre-processing\n",
 951 | "1. The data is pre-processed with respect to its time steps as batches. As an LSTM accepts only 3D data as input, the actual data is transformed into a three-dimensional array in the format **(data_samples, n_timesteps, n_features) => (2620, 250, 25)**
\n",
 952 | "2. During pre-processing, in order to restructure the data to resemble a supervised learning problem, previous time steps are used as input variables and the next time step is used as the output variable, which is called the **sliding window technique for time-series**
\n",
 953 | "3. Here, the number of previous timesteps is **250** as per config.yaml. Thus our data will be trained and the model makes its predictions by processing each 250-timestep window of data with 25 features, **ten times**. \n",
 954 | "4. Thus the number of predictions will be 10. The output shape will be **(2620, 10) => (data_samples, n_predictions).**\n",
955 | "\n",
956 | "
\n",
957 | "The same is demonstrated in the below code.\n",
958 | " "
959 | ]
960 | },
961 | {
962 | "cell_type": "code",
963 | "execution_count": 9,
964 | "metadata": {},
965 | "outputs": [
966 | {
967 | "name": "stdout",
968 | "output_type": "stream",
969 | "text": [
970 | "X: (2620, 250, 25)\n",
971 | "Y: (2620, 10)\n"
972 | ]
973 | }
974 | ],
975 | "source": [
976 | "#train data\n",
977 | "train=False\n",
978 | "data = []\n",
979 | "for i in range(len(data_train) - 250 - 10):\n",
980 | " data.append(data_train[i:i + 250 + 10])\n",
981 | "data = np.array(data)\n",
982 | "\n",
983 | "assert len(data.shape) == 3\n",
984 | "\n",
985 | "if not train:\n",
986 | " X_train = data[:, :-10, :]\n",
987 | " y_train = data[:, -10:, 0]\n",
988 | "\n",
989 | "print(\"X:\",X_train.shape)\n",
990 | "print(\"Y:\",y_train.shape)"
991 | ]
992 | },
993 | {
994 | "cell_type": "code",
995 | "execution_count": 15,
996 | "metadata": {},
997 | "outputs": [],
998 | "source": [
999 | "#For test data\n",
1000 | "train=False\n",
1001 | "data = []\n",
1002 | "for i in range(len(test_arr) - 250 - 10):\n",
1003 | " data.append(test_arr[i:i + 250 + 10])\n",
1004 | "data = np.array(data)\n",
1005 | "\n",
1006 | "assert len(data.shape) == 3\n",
1007 | "\n",
1008 | "if not train:\n",
1009 | " X_test = data[:, :-10, :]\n",
1010 | " y_test = data[:, -10:, 0]\n",
1011 | "\n"
1012 | ]
1013 | },
1014 | {
1015 | "cell_type": "markdown",
1016 | "metadata": {},
1017 | "source": [
1018 | "## Model Prediction\n",
1019 | "1. As explained in the above cell, the model makes predictions with our pre-processed test data.\n",
 1020 | "2. Before predicting the value, the actual data is converted into **batches** and each batch is fed into the model separately, where the model uses previous batch information for current batch processing and then predicts the output value.
\n",
1021 | "3. Thus each batch contains **260 data samples, 250 timesteps and 25 features** => **(260, 250, 25)**.\n",
1022 | "4. Our model outputs 10 predictions for 260 data samples as => **(260,10) predictions (y_hat)** for each batch given."
1023 | ]
1024 | },
1025 | {
1026 | "cell_type": "code",
1027 | "execution_count": 12,
1028 | "metadata": {},
1029 | "outputs": [],
1030 | "source": [
1031 | "num_batches = int((y_test.shape[0] - 250)/ 70)\n",
1032 | "for i in range(0, num_batches + 1):\n",
1033 | " prior_idx = i * 70\n",
1034 | " idx = (i + 1) * 70\n",
1035 | " \n",
1036 | " if i + 1 == num_batches + 1:\n",
1037 | " idx = y_test.shape[0]\n",
1038 | " \n",
1039 | " X_test_batch = X_test[prior_idx:idx]\n",
1040 | " y_test_batch = y_test[prior_idx:idx]\n",
1041 | " y_hat_batch = model.predict(X_test_batch)"
1042 | ]
1043 | },
1044 | {
1045 | "cell_type": "code",
1046 | "execution_count": 14,
1047 | "metadata": {},
1048 | "outputs": [
1049 | {
1050 | "name": "stdout",
1051 | "output_type": "stream",
1052 | "text": [
1053 | "X batch in test data: (260, 250, 25)\n",
1054 | "Y batch in test data: (260, 10)\n",
1055 | "y_hat predictions from test data: (260, 10)\n"
1056 | ]
1057 | }
1058 | ],
1059 | "source": [
1060 | "print(\"X batch in test data:\",X_test_batch.shape)\n",
1061 | "print(\"Y batch in test data:\",y_test_batch.shape)\n",
1062 | "print(\"y_hat predictions from test data:\",y_hat_batch.shape)"
1063 | ]
1064 | },
1065 | {
1066 | "cell_type": "markdown",
1067 | "metadata": {},
1068 | "source": [
1069 | "## SHAP Implementation\n",
1070 | "\n",
1071 | "**Reference:**
\n",
1072 | "https://arxiv.org/pdf/1903.02407.pdf
\n",
1073 | "https://github.com/liuyilin950623/SHAP_on_Autoencoder"
1074 | ]
1075 | },
1076 | {
1077 | "cell_type": "code",
1078 | "execution_count": 16,
1079 | "metadata": {},
1080 | "outputs": [
1081 | {
1082 | "data": {
1083 | "text/html": [
1084 | "\n",
1085 | "\n",
1098 | "
\n",
1099 | " \n",
1100 | " \n",
1101 | " | \n",
1102 | " 0 | \n",
1103 | " 1 | \n",
1104 | " 2 | \n",
1105 | " 3 | \n",
1106 | " 4 | \n",
1107 | " 5 | \n",
1108 | " 6 | \n",
1109 | " 7 | \n",
1110 | " 8 | \n",
1111 | " 9 | \n",
1112 | "
\n",
1113 | " \n",
1114 | " \n",
1115 | " \n",
1116 | " reconstruction_loss | \n",
1117 | " 0.96608 | \n",
1118 | " 0.967489 | \n",
1119 | " 0.956203 | \n",
1120 | " 0.964819 | \n",
1121 | " 0.960538 | \n",
1122 | " 0.966676 | \n",
1123 | " 0.952093 | \n",
1124 | " 0.96608 | \n",
1125 | " 0.966304 | \n",
1126 | " 0.968249 | \n",
1127 | "
\n",
1128 | " \n",
1129 | "
\n",
1130 | "
"
1131 | ],
1132 | "text/plain": [
1133 | " 0 1 2 3 4 \\\n",
1134 | "reconstruction_loss 0.96608 0.967489 0.956203 0.964819 0.960538 \n",
1135 | "\n",
1136 | " 5 6 7 8 9 \n",
1137 | "reconstruction_loss 0.966676 0.952093 0.96608 0.966304 0.968249 "
1138 | ]
1139 | },
1140 | "execution_count": 16,
1141 | "metadata": {},
1142 | "output_type": "execute_result"
1143 | }
1144 | ],
1145 | "source": [
1146 | "#Picking a data point with largest reconstruction or prediction error\n",
1147 | "\n",
1148 | "X_reconstruction_standard = model.predict(X_test_batch)\n",
1149 | "rec_err = np.linalg.norm(y_test_batch - X_reconstruction_standard, axis = 1)\n",
1150 | "idx = list(rec_err).index(max(rec_err))\n",
1151 | "df = pd.DataFrame(data = X_reconstruction_standard[idx], columns = ['reconstruction_loss'])\n",
1152 | "df.T"
1153 | ]
1154 | },
1155 | {
1156 | "cell_type": "code",
1157 | "execution_count": 17,
1158 | "metadata": {},
1159 | "outputs": [
1160 | {
1161 | "name": "stdout",
1162 | "output_type": "stream",
1163 | "text": [
1164 | "Model Prediction output shape for test batch: (260, 10)\n",
1165 | "\n",
1166 | "Actual Value y:\n",
1167 | " [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
1168 | "Predicted Value y_hat:\n",
1169 | " [0.9739516 0.9887368 0.9738245 0.99259615 0.9800745 0.98507845\n",
1170 | " 0.9748681 0.99089825 0.9805185 0.9821451 ]\n",
1171 | "\n",
1172 | "First five predictions:\n",
1173 | " [[0.9733607 0.9879205 0.9732397 0.9919983 0.9796158 0.9846846\n",
1174 | " 0.9742988 0.990171 0.9798947 0.9817468 ]\n",
1175 | " [0.9739516 0.9887368 0.9738245 0.99259615 0.9800745 0.98507845\n",
1176 | " 0.9748681 0.99089825 0.9805185 0.9821451 ]\n",
1177 | " [0.97446465 0.9894272 0.97432685 0.99309766 0.98047465 0.9854017\n",
1178 | " 0.9753439 0.99151456 0.98105055 0.9824883 ]\n",
1179 | " [0.9749013 0.9900011 0.9747505 0.9935121 0.9808161 0.9856611\n",
1180 | " 0.9757351 0.99203014 0.981495 0.9827764 ]\n",
1181 | " [0.97526526 0.99046975 0.9751019 0.9938499 0.9811022 0.9858655\n",
1182 | " 0.97605145 0.9924544 0.9818594 0.9830123 ]]\n",
1183 | "\n",
1184 | "First 5 Prediction (or) Re-construction error:\n",
1185 | " [0.06148452 0.0598148 0.05840042 0.05722382 0.05626203]\n"
1186 | ]
1187 | }
1188 | ],
1189 | "source": [
1190 | "print(\"Model Prediction output shape for test batch:\",X_reconstruction_standard.shape)\n",
1191 | "print(\"\\nActual Value y:\\n\",y_test_batch[1])\n",
1192 | "print(\"Predicted Value y_hat:\\n\",X_reconstruction_standard[1])\n",
1193 | "print(\"\\nFirst five predictions:\\n\",X_reconstruction_standard[0:5])\n",
1194 | "print(\"\\nFirst 5 Prediction (or) Re-construction error:\\n\",rec_err[0:5])\n"
1195 | ]
1196 | },
1197 | {
1198 | "cell_type": "code",
1199 | "execution_count": 18,
1200 | "metadata": {},
1201 | "outputs": [
1202 | {
1203 | "data": {
1204 | "text/html": [
1205 | "\n",
1206 | "\n",
1219 | "
\n",
1220 | " \n",
1221 | " \n",
1222 | " | \n",
1223 | " 9 | \n",
1224 | " 1 | \n",
1225 | " 5 | \n",
1226 | " 8 | \n",
1227 | " 7 | \n",
1228 | " 0 | \n",
1229 | " 3 | \n",
1230 | " 4 | \n",
1231 | " 2 | \n",
1232 | " 6 | \n",
1233 | "
\n",
1234 | " \n",
1235 | " \n",
1236 | " \n",
1237 | " reconstruction_loss | \n",
1238 | " 0.968249 | \n",
1239 | " 0.967489 | \n",
1240 | " 0.966676 | \n",
1241 | " 0.966304 | \n",
1242 | " 0.96608 | \n",
1243 | " 0.96608 | \n",
1244 | " 0.964819 | \n",
1245 | " 0.960538 | \n",
1246 | " 0.956203 | \n",
1247 | " 0.952093 | \n",
1248 | "
\n",
1249 | " \n",
1250 | "
\n",
1251 | "
"
1252 | ],
1253 | "text/plain": [
1254 | " 9 1 5 8 7 0 \\\n",
1255 | "reconstruction_loss 0.968249 0.967489 0.966676 0.966304 0.96608 0.96608 \n",
1256 | "\n",
1257 | " 3 4 2 6 \n",
1258 | "reconstruction_loss 0.964819 0.960538 0.956203 0.952093 "
1259 | ]
1260 | },
1261 | "execution_count": 18,
1262 | "metadata": {},
1263 | "output_type": "execute_result"
1264 | }
1265 | ],
1266 | "source": [
1267 | "# Selecting top 5 features from the error list\n",
1268 | "def sort_by_absolute(df, index):\n",
1269 | " df_abs = df.apply(lambda x: abs(x))\n",
1270 | " df_abs = df_abs.sort_values('reconstruction_loss', ascending = False)\n",
1271 | " df = df.loc[df_abs.index,:]\n",
1272 | " return df\n",
1273 | "sort_by_absolute(df, idx).T"
1274 | ]
1275 | },
1276 | {
1277 | "cell_type": "code",
1278 | "execution_count": 19,
1279 | "metadata": {},
1280 | "outputs": [
1281 | {
1282 | "data": {
1283 | "text/html": [
1284 | "\n",
1285 | "\n",
1298 | "
\n",
1299 | " \n",
1300 | " \n",
1301 | " | \n",
1302 | " 9 | \n",
1303 | " 1 | \n",
1304 | " 5 | \n",
1305 | " 8 | \n",
1306 | " 7 | \n",
1307 | "
\n",
1308 | " \n",
1309 | " \n",
1310 | " \n",
1311 | " reconstruction_loss | \n",
1312 | " 0.968249 | \n",
1313 | " 0.967489 | \n",
1314 | " 0.966676 | \n",
1315 | " 0.966304 | \n",
1316 | " 0.96608 | \n",
1317 | "
\n",
1318 | " \n",
1319 | "
\n",
1320 | "
"
1321 | ],
1322 | "text/plain": [
1323 | " 9 1 5 8 7\n",
1324 | "reconstruction_loss 0.968249 0.967489 0.966676 0.966304 0.96608"
1325 | ]
1326 | },
1327 | "execution_count": 19,
1328 | "metadata": {},
1329 | "output_type": "execute_result"
1330 | }
1331 | ],
1332 | "source": [
1333 | "top_5_features = sort_by_absolute(df, idx).iloc[:5,:]\n",
1334 | "top_5_features.T"
1335 | ]
1336 | },
1337 | {
1338 | "cell_type": "code",
1339 | "execution_count": 20,
1340 | "metadata": {},
1341 | "outputs": [
1342 | {
1343 | "name": "stderr",
1344 | "output_type": "stream",
1345 | "text": [
1346 | "Using 260 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.\n"
1347 | ]
1348 | },
1349 | {
1350 | "data": {
1351 | "application/vnd.jupyter.widget-view+json": {
1352 | "model_id": "39122d36bbf5412bbff93857bd0135d0",
1353 | "version_major": 2,
1354 | "version_minor": 0
1355 | },
1356 | "text/plain": [
1357 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=250.0), HTML(value='')))"
1358 | ]
1359 | },
1360 | "metadata": {},
1361 | "output_type": "display_data"
1362 | },
1363 | {
1364 | "name": "stdout",
1365 | "output_type": "stream",
1366 | "text": [
1367 | "\n"
1368 | ]
1369 | },
1370 | {
1371 | "ename": "IndexError",
1372 | "evalue": "index 25 is out of bounds for axis 1 with size 25",
1373 | "output_type": "error",
1374 | "traceback": [
1375 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
1376 | "\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)",
1377 | "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtop_5_features\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mexplainer_autoencoder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mshap\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mKernelExplainer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mX_test_batch\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mshap_values\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mexplainer_autoencoder\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshap_values\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test_batch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
1378 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\shap\\explainers\\_kernel.py\u001b[0m in \u001b[0;36mshap_values\u001b[1;34m(self, X, **kwargs)\u001b[0m\n\u001b[0;32m 176\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mkeep_index\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 177\u001b[0m \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconvert_to_instance_with_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcolumn_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mindex_value\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mindex_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 178\u001b[1;33m \u001b[0mexplanations\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mexplain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 179\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 180\u001b[0m \u001b[1;31m# vector-output\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
1379 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\shap\\explainers\\_kernel.py\u001b[0m in \u001b[0;36mexplain\u001b[1;34m(self, incoming_instance, **kwargs)\u001b[0m\n\u001b[0;32m 197\u001b[0m \u001b[1;31m# convert incoming input to a standardized iml object\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 198\u001b[0m \u001b[0minstance\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mconvert_to_instance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mincoming_instance\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 199\u001b[1;33m \u001b[0mmatch_instance_to_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minstance\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 200\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 201\u001b[0m \u001b[1;31m# find the feature groups we will test. If a feature does not change from its\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
1380 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\shap\\utils\\_legacy.py\u001b[0m in \u001b[0;36mmatch_instance_to_data\u001b[1;34m(instance, data)\u001b[0m\n\u001b[0;32m 85\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mDenseData\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 86\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroup_display_values\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 87\u001b[1;33m \u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroup_display_values\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgroup\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;34m\"\"\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 88\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroup_display_values\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 
89\u001b[0m \u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
1381 | "\u001b[1;32mC:\\ProgramData\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\shap\\utils\\_legacy.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 85\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mDenseData\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 86\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroup_display_values\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 87\u001b[1;33m \u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroup_display_values\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgroup\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;34m\"\"\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 88\u001b[0m \u001b[1;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroup_display_values\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 89\u001b[0m 
\u001b[0minstance\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgroups\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
1382 | "\u001b[1;31mIndexError\u001b[0m: index 25 is out of bounds for axis 1 with size 25"
1383 | ]
1384 | }
1385 | ],
1386 | "source": [
1387 | "# using SHAP\n",
1388 | "for i in top_5_features.index:\n",
1389 | " explainer_autoencoder = shap.KernelExplainer(model.predict, X_test_batch)\n",
1390 | " shap_values = explainer_autoencoder.shap_values(X_test_batch[idx,:,:])"
1391 | ]
1392 | },
1393 | {
1394 | "cell_type": "code",
1395 | "execution_count": 22,
1396 | "metadata": {},
1397 | "outputs": [
1398 | {
1399 | "name": "stdout",
1400 | "output_type": "stream",
1401 | "text": [
1402 | "Model: \"sequential\"\n",
1403 | "_________________________________________________________________\n",
1404 | "Layer (type) Output Shape Param # \n",
1405 | "=================================================================\n",
1406 | "lstm_5 (LSTM) (None, None, 80) 33920 \n",
1407 | "_________________________________________________________________\n",
1408 | "dropout_5 (Dropout) (None, None, 80) 0 \n",
1409 | "_________________________________________________________________\n",
1410 | "lstm_6 (LSTM) (None, 80) 51520 \n",
1411 | "_________________________________________________________________\n",
1412 | "dropout_6 (Dropout) (None, 80) 0 \n",
1413 | "_________________________________________________________________\n",
1414 | "dense_3 (Dense) (None, 10) 810 \n",
1415 | "_________________________________________________________________\n",
1416 | "activation_3 (Activation) (None, 10) 0 \n",
1417 | "=================================================================\n",
1418 | "Total params: 86,250\n",
1419 | "Trainable params: 86,250\n",
1420 | "Non-trainable params: 0\n",
1421 | "_________________________________________________________________\n",
1422 | "\n",
1423 | "X test batch- (260, 250, 25)\n",
1424 | "\n",
1425 | "y test batch (260, 10)\n",
1426 | "\n",
1427 | "y hat batch (260, 10)\n"
1428 | ]
1429 | }
1430 | ],
1431 | "source": [
1432 | "model.summary()\n",
1433 | "\n",
1434 | "print(\"\\nX test batch-\",X_test_batch.shape)\n",
1435 | "\n",
1436 | "print(\"\\ny test batch\",y_test_batch.shape)\n",
1437 | "\n",
1438 | "print(\"\\ny hat batch\",y_hat_batch.shape)"
1439 | ]
1440 | },
1441 | {
1442 | "cell_type": "markdown",
1443 | "metadata": {},
1444 | "source": [
1445 | "### Issues with Feature Attribution Methods (SHAP & LIME):\n",
1446 | "\n",
1447 | "**My Observations:**\n",
1448 | "\n",
1449 | "1. As our LSTM network is designed to predict the values in 10 predictions for entire batches, its input layer accepts data of shape (260, 250, 25) and output layer outputs predictions in shape (260, 10).
\n",
1450 | "2. Here, those 10 predictions are called **lag observations (or) lag features.** as we use previous timestep information.
\n",
1451 | "3. When we apply our model and data to any feature attribution methods like **SHAP, LIME** etc. they expect the features to be the same, only then these methods can assign importance scores to every feature and return the most contributing features for the prediction [In our case - features that contributed for anomaly]. \n",
1452 | "4. Thus, the ground rule is input and output features should be equal (i.e.) If input layer has **25 features**, then the output layer also should have 25 features. But there is no actual issue with our model as it uses 25 features to make predictions and the output **10 predictions** are just lag features and not actual features which creates a conflict.
\n",
1453 | "5. Thus SHAP, LIME assumes the features to be inconsistent as they don't match and throws the above error.
\n",
1454 | "6. To illustrate in simpler terms, the input shape of the network should be equal to its output shape.
\n",
1455 | "\n",
1456 | "The above error applies to LIME as well."
1457 | ]
1458 | },
1459 | {
1460 | "cell_type": "markdown",
1461 | "metadata": {},
1462 | "source": [
1463 | "## Comparison\n",
1464 | "To justify the observations, I have compared my implementation with some other implementations as below."
1465 | ]
1466 | },
1467 | {
1468 | "cell_type": "markdown",
1469 | "metadata": {},
1470 | "source": [
1471 | "### Our dataset (NASA)"
1472 | ]
1473 | },
1474 | {
1475 | "cell_type": "code",
1476 | "execution_count": 24,
1477 | "metadata": {},
1478 | "outputs": [
1479 | {
1480 | "name": "stdout",
1481 | "output_type": "stream",
1482 | "text": [
1483 | "X train - (2620, 250, 25)\n",
1484 | "y train - (2620, 10)\n",
1485 | "\n",
1486 | "X test batch- (260, 250, 25)\n",
1487 | "\n",
1488 | "y test batch (260, 10)\n",
1489 | "\n",
1490 | "y hat batch (260, 10)\n"
1491 | ]
1492 | }
1493 | ],
1494 | "source": [
1495 | "print(\"X train - \",X_train.shape)\n",
1496 | "print(\"y train -\",y_train.shape)\n",
1497 | "\n",
1498 | "print(\"\\nX test batch-\",X_test_batch.shape)\n",
1499 | "\n",
1500 | "print(\"\\ny test batch\",y_test_batch.shape)\n",
1501 | "\n",
1502 | "print(\"\\ny hat batch\",y_hat_batch.shape)\n"
1503 | ]
1504 | },
1505 | {
1506 | "cell_type": "markdown",
1507 | "metadata": {},
1508 | "source": [
1509 | "### 1. SHAP implementation using Auto encoders on generic numerical dataset\n",
1510 | "\n",
1511 | "1. The dataset used is **Boston Housing dataset**. It is not a time series dataset but the problem is anomaly detection using an auto encoder.
\n",
1512 | "2. For comparison, I have printed the dataset's shape along with their model summary.
\n",
1513 | "3. It has **13 features**.
\n",
1514 | "4. As it is a normal auto encoder, the input data is two dimensional. \n",
1515 | "\n",
1516 | "**Reference**
\n",
1517 | "https://github.com/liuyilin950623/SHAP_on_Autoencoder"
1518 | ]
1519 | },
1520 | {
1521 | "cell_type": "code",
1522 | "execution_count": 31,
1523 | "metadata": {},
1524 | "outputs": [
1525 | {
1526 | "name": "stdout",
1527 | "output_type": "stream",
1528 | "text": [
1529 | "Dataset shape: (506, 13)\n",
1530 | "X train: (379, 13)\n"
1531 | ]
1532 | },
1533 | {
1534 | "data": {
1535 | "text/html": [
1536 | "\n",
1537 | "\n",
1550 | "
\n",
1551 | " \n",
1552 | " \n",
1553 | " | \n",
1554 | " CRIM | \n",
1555 | " ZN | \n",
1556 | " INDUS | \n",
1557 | " CHAS | \n",
1558 | " NOX | \n",
1559 | " RM | \n",
1560 | " AGE | \n",
1561 | " DIS | \n",
1562 | " RAD | \n",
1563 | " TAX | \n",
1564 | " PTRATIO | \n",
1565 | " B | \n",
1566 | " LSTAT | \n",
1567 | "
\n",
1568 | " \n",
1569 | " \n",
1570 | " \n",
1571 | " 0 | \n",
1572 | " 0.00632 | \n",
1573 | " 18.0 | \n",
1574 | " 2.31 | \n",
1575 | " 0.0 | \n",
1576 | " 0.538 | \n",
1577 | " 6.575 | \n",
1578 | " 65.2 | \n",
1579 | " 4.0900 | \n",
1580 | " 1.0 | \n",
1581 | " 296.0 | \n",
1582 | " 15.3 | \n",
1583 | " 396.90 | \n",
1584 | " 4.98 | \n",
1585 | "
\n",
1586 | " \n",
1587 | " 1 | \n",
1588 | " 0.02731 | \n",
1589 | " 0.0 | \n",
1590 | " 7.07 | \n",
1591 | " 0.0 | \n",
1592 | " 0.469 | \n",
1593 | " 6.421 | \n",
1594 | " 78.9 | \n",
1595 | " 4.9671 | \n",
1596 | " 2.0 | \n",
1597 | " 242.0 | \n",
1598 | " 17.8 | \n",
1599 | " 396.90 | \n",
1600 | " 9.14 | \n",
1601 | "
\n",
1602 | " \n",
1603 | " 2 | \n",
1604 | " 0.02729 | \n",
1605 | " 0.0 | \n",
1606 | " 7.07 | \n",
1607 | " 0.0 | \n",
1608 | " 0.469 | \n",
1609 | " 7.185 | \n",
1610 | " 61.1 | \n",
1611 | " 4.9671 | \n",
1612 | " 2.0 | \n",
1613 | " 242.0 | \n",
1614 | " 17.8 | \n",
1615 | " 392.83 | \n",
1616 | " 4.03 | \n",
1617 | "
\n",
1618 | " \n",
1619 | " 3 | \n",
1620 | " 0.03237 | \n",
1621 | " 0.0 | \n",
1622 | " 2.18 | \n",
1623 | " 0.0 | \n",
1624 | " 0.458 | \n",
1625 | " 6.998 | \n",
1626 | " 45.8 | \n",
1627 | " 6.0622 | \n",
1628 | " 3.0 | \n",
1629 | " 222.0 | \n",
1630 | " 18.7 | \n",
1631 | " 394.63 | \n",
1632 | " 2.94 | \n",
1633 | "
\n",
1634 | " \n",
1635 | " 4 | \n",
1636 | " 0.06905 | \n",
1637 | " 0.0 | \n",
1638 | " 2.18 | \n",
1639 | " 0.0 | \n",
1640 | " 0.458 | \n",
1641 | " 7.147 | \n",
1642 | " 54.2 | \n",
1643 | " 6.0622 | \n",
1644 | " 3.0 | \n",
1645 | " 222.0 | \n",
1646 | " 18.7 | \n",
1647 | " 396.90 | \n",
1648 | " 5.33 | \n",
1649 | "
\n",
1650 | " \n",
1651 | " ... | \n",
1652 | " ... | \n",
1653 | " ... | \n",
1654 | " ... | \n",
1655 | " ... | \n",
1656 | " ... | \n",
1657 | " ... | \n",
1658 | " ... | \n",
1659 | " ... | \n",
1660 | " ... | \n",
1661 | " ... | \n",
1662 | " ... | \n",
1663 | " ... | \n",
1664 | " ... | \n",
1665 | "
\n",
1666 | " \n",
1667 | " 501 | \n",
1668 | " 0.06263 | \n",
1669 | " 0.0 | \n",
1670 | " 11.93 | \n",
1671 | " 0.0 | \n",
1672 | " 0.573 | \n",
1673 | " 6.593 | \n",
1674 | " 69.1 | \n",
1675 | " 2.4786 | \n",
1676 | " 1.0 | \n",
1677 | " 273.0 | \n",
1678 | " 21.0 | \n",
1679 | " 391.99 | \n",
1680 | " 9.67 | \n",
1681 | "
\n",
1682 | " \n",
1683 | " 502 | \n",
1684 | " 0.04527 | \n",
1685 | " 0.0 | \n",
1686 | " 11.93 | \n",
1687 | " 0.0 | \n",
1688 | " 0.573 | \n",
1689 | " 6.120 | \n",
1690 | " 76.7 | \n",
1691 | " 2.2875 | \n",
1692 | " 1.0 | \n",
1693 | " 273.0 | \n",
1694 | " 21.0 | \n",
1695 | " 396.90 | \n",
1696 | " 9.08 | \n",
1697 | "
\n",
1698 | " \n",
1699 | " 503 | \n",
1700 | " 0.06076 | \n",
1701 | " 0.0 | \n",
1702 | " 11.93 | \n",
1703 | " 0.0 | \n",
1704 | " 0.573 | \n",
1705 | " 6.976 | \n",
1706 | " 91.0 | \n",
1707 | " 2.1675 | \n",
1708 | " 1.0 | \n",
1709 | " 273.0 | \n",
1710 | " 21.0 | \n",
1711 | " 396.90 | \n",
1712 | " 5.64 | \n",
1713 | "
\n",
1714 | " \n",
1715 | " 504 | \n",
1716 | " 0.10959 | \n",
1717 | " 0.0 | \n",
1718 | " 11.93 | \n",
1719 | " 0.0 | \n",
1720 | " 0.573 | \n",
1721 | " 6.794 | \n",
1722 | " 89.3 | \n",
1723 | " 2.3889 | \n",
1724 | " 1.0 | \n",
1725 | " 273.0 | \n",
1726 | " 21.0 | \n",
1727 | " 393.45 | \n",
1728 | " 6.48 | \n",
1729 | "
\n",
1730 | " \n",
1731 | " 505 | \n",
1732 | " 0.04741 | \n",
1733 | " 0.0 | \n",
1734 | " 11.93 | \n",
1735 | " 0.0 | \n",
1736 | " 0.573 | \n",
1737 | " 6.030 | \n",
1738 | " 80.8 | \n",
1739 | " 2.5050 | \n",
1740 | " 1.0 | \n",
1741 | " 273.0 | \n",
1742 | " 21.0 | \n",
1743 | " 396.90 | \n",
1744 | " 7.88 | \n",
1745 | "
\n",
1746 | " \n",
1747 | "
\n",
1748 | "
506 rows × 13 columns
\n",
1749 | "
"
1750 | ],
1751 | "text/plain": [
1752 | " CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX \\\n",
1753 | "0 0.00632 18.0 2.31 0.0 0.538 6.575 65.2 4.0900 1.0 296.0 \n",
1754 | "1 0.02731 0.0 7.07 0.0 0.469 6.421 78.9 4.9671 2.0 242.0 \n",
1755 | "2 0.02729 0.0 7.07 0.0 0.469 7.185 61.1 4.9671 2.0 242.0 \n",
1756 | "3 0.03237 0.0 2.18 0.0 0.458 6.998 45.8 6.0622 3.0 222.0 \n",
1757 | "4 0.06905 0.0 2.18 0.0 0.458 7.147 54.2 6.0622 3.0 222.0 \n",
1758 | ".. ... ... ... ... ... ... ... ... ... ... \n",
1759 | "501 0.06263 0.0 11.93 0.0 0.573 6.593 69.1 2.4786 1.0 273.0 \n",
1760 | "502 0.04527 0.0 11.93 0.0 0.573 6.120 76.7 2.2875 1.0 273.0 \n",
1761 | "503 0.06076 0.0 11.93 0.0 0.573 6.976 91.0 2.1675 1.0 273.0 \n",
1762 | "504 0.10959 0.0 11.93 0.0 0.573 6.794 89.3 2.3889 1.0 273.0 \n",
1763 | "505 0.04741 0.0 11.93 0.0 0.573 6.030 80.8 2.5050 1.0 273.0 \n",
1764 | "\n",
1765 | " PTRATIO B LSTAT \n",
1766 | "0 15.3 396.90 4.98 \n",
1767 | "1 17.8 396.90 9.14 \n",
1768 | "2 17.8 392.83 4.03 \n",
1769 | "3 18.7 394.63 2.94 \n",
1770 | "4 18.7 396.90 5.33 \n",
1771 | ".. ... ... ... \n",
1772 | "501 21.0 391.99 9.67 \n",
1773 | "502 21.0 396.90 9.08 \n",
1774 | "503 21.0 396.90 5.64 \n",
1775 | "504 21.0 393.45 6.48 \n",
1776 | "505 21.0 396.90 7.88 \n",
1777 | "\n",
1778 | "[506 rows x 13 columns]"
1779 | ]
1780 | },
1781 | "execution_count": 31,
1782 | "metadata": {},
1783 | "output_type": "execute_result"
1784 | }
1785 | ],
1786 | "source": [
1787 | "from sklearn.preprocessing import StandardScaler\n",
1788 | "from sklearn.model_selection import train_test_split\n",
1789 | "\n",
1790 | "Xb, yb = shap.datasets.boston()\n",
1791 | "\n",
1792 | "std = StandardScaler()\n",
1793 | "X_standard = std.fit_transform(Xb)\n",
1794 | "X_trainn, X_testt = train_test_split(X_standard)\n",
1795 | "\n",
1796 | "print(\"Dataset shape:\",Xb.shape)\n",
1797 | "print(\"X train:\",X_trainn.shape)\n",
1798 | "Xb"
1799 | ]
1800 | },
1801 | {
1802 | "cell_type": "markdown",
1803 | "metadata": {},
1804 | "source": [
1805 | "**Model summary for this implementation:**
\n",
1806 | ""
1807 | ]
1808 | },
1809 | {
1810 | "cell_type": "markdown",
1811 | "metadata": {},
1812 | "source": [
1813 | "### 2. LIME Implementation using LSTM in Co2 dataset\n",
1814 | "\n",
1815 | "1. CO2 dataset has 2 features. This is a classification problem but data is time series.
\n",
1816 | "2. They have pre-processed the data with respect to 12 timesteps.
\n",
1817 | "3. Thus their input shape is **(2270, 12, 2)** and output shape is **(2270, 2)**.\n",
1818 | "\n",
1819 | "**Reference**
\n",
1820 | "https://github.com/marcotcr/lime/blob/master/doc/notebooks/Lime%20with%20Recurrent%20Neural%20Networks.ipynb"
1821 | ]
1822 | },
1823 | {
1824 | "cell_type": "markdown",
1825 | "metadata": {},
1826 | "source": [
1827 | ""
1828 | ]
1829 | },
1830 | {
1831 | "cell_type": "markdown",
1832 | "metadata": {},
1833 | "source": [
1834 | "### Wrap-up:\n",
1835 | "\n",
1836 | "If we compare our model with the two above model summary and data shape, it is pretty evident that the input and output feature shapes should be the same.
\n",
1837 | "1. First implementation has 13 input features and output prediction also has 13 features.\n",
1838 | "2. Second implementation has 2 input features and output also has 2 features.\n",
1839 | "\n",
1840 | "From these observations, we can also see that SHAP and LIME worked as expected with both datasets and models without any errors or exceptions as the input and output shapes are equal. "
1841 | ]
1842 | },
1843 | {
1844 | "cell_type": "code",
1845 | "execution_count": null,
1846 | "metadata": {},
1847 | "outputs": [],
1848 | "source": []
1849 | }
1850 | ],
1851 | "metadata": {
1852 | "kernelspec": {
1853 | "display_name": "Python 3",
1854 | "language": "python",
1855 | "name": "python3"
1856 | },
1857 | "language_info": {
1858 | "codemirror_mode": {
1859 | "name": "ipython",
1860 | "version": 3
1861 | },
1862 | "file_extension": ".py",
1863 | "mimetype": "text/x-python",
1864 | "name": "python",
1865 | "nbconvert_exporter": "python",
1866 | "pygments_lexer": "ipython3",
1867 | "version": "3.7.9"
1868 | }
1869 | },
1870 | "nbformat": 4,
1871 | "nbformat_minor": 4
1872 | }
1873 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Explaining deep LSTM models for detecting anomalies in time-series data (R&D Project)
2 |
3 | This research work focuses on comparing the existing approaches to explain the decisions of models trained using time-series data and proposing the best-fit method that generates `explanations` for a deep neural network. The proposed approach is used specifically for explaining `LSTM` networks for `anomaly detection` task in `time-series data (satellite telemetry data)`.
4 |
5 | ## Author
6 |
7 | - [@Jithin Sasikumar](https://www.github.com/Jithsaavvy)
8 |
9 | ## Acknowledgements
10 |
11 | - [@German Aerospace Center (DLR)](https://www.dlr.de/DE/Home/home_node.html)
12 | - [@Hochschule Bonn-Rhein-Sieg](https://www.h-brs.de/en)
13 |
14 | ## Languages and Tools
15 | 
16 |
17 | `Explainable AI (XAI)` methods are applied for `LSTM regression-based` anomaly detection models with different configurations `(simple/vanilla LSTM, encoder-decoder based LSTM, stacked LSTM and bi-directional LSTM)` for `satellite telemetry` data (time-series). These methods are compared based on the explanations obtained as a result of implementation as plots. A comparative analysis is performed by comparing the implemented methods based on generated results differentiating the methods fitting between differentiable and non-differentiable models.
18 |
19 | ## Why XAI?
20 |
21 | It serves to answer the `why` questions such as - _Why a specific input data is classified as anomaly?_ _What features made the model to classify a specific data as anomalous ones?_ and so on. The model should be clear and evident in detecting anomalies. Any failure will ruin the complete system, in our case - it can imperil the entire satellite if the anomaly detection mechanism is not properly done, hence it is necessary to explain the models. Explanations are generated for the users to better understand the model's predictions. To facilitate that, these `explainable AI (XAI)` methods are used.
22 |
23 | ## Problem Statement
24 |
25 | - The lack of `transparency` and `interpretability` in deep learning based models is a major drawback for understanding its decisions (or) predictions, making them as `black box` that restricts AI systems to be implemented in safety-critical applications in real-time with less supervision. Explainable AI is proposed to solve the black box problem.
26 | - Efficient anomaly detection systems should be employed to monitor the satellite telemetry channels that produces high volumes of `telemetry data`. Any failure in detecting anomaly could endanger the whole satellite.
27 | - Various methods (or) approaches are available for explaining and interpreting deep learning models, specifically for image and text data but a suitable method for a specific deep network architecture is not yet proposed for anomaly detection task using time series data.
28 |
29 | ## LSTM models used
30 |
31 | 
32 |
33 | All LSTM model architectures are illustrated in [images](./images/) folder. For example, the architecture of `encoder-decoder LSTM` with $16$ timesteps is depicted as follows:
34 |
35 | 
36 |
37 | ## XAI methods used
38 | - [LRP](http://www.heatmapping.org/)
39 | - [LIME](https://arxiv.org/abs/1602.04938)
40 | - [SHAP](https://shap.readthedocs.io/en/latest/index.html#)
41 |
42 | ## Notebooks
43 |
44 | [LRP implementation](./notebooks/LRP_Implementation.ipynb) describes the implementation of `Layer-wise Relevance Propogation (LRP)` which is a `gradient-based attribution method` for LSTM regression model trained on time-series data along with their issues and limitations. It contains data pre-processing, model summary and implementation of LRP method. My observations are documented.
45 |
46 | [LIME and SHAP implementation](./notebooks/LIME_and_SHAP_Implementation.ipynb) describes the implementation of `perturbation-based methods (SHAP and LIME)` for LSTM models with different architectures trained on time-series data along with their issues and limitations. It contains data pre-processing, model summary, model prediction, anomaly detection, SHAP and LIME implementation. All the models are developed from scratch using `tensorflow`, `keras` and trained except NASA model which is downloaded from [1] and based on [2]. The models are compared and my observations are documented.
47 |
48 | ## Results
49 |
50 | ### SHAP
51 |
52 | SHAP force plot for vanilla LSTM
53 | 
54 |
55 | SHAP force plot for bi-directional LSTM
56 | 
57 |
58 | SHAP force plot for stacked LSTM
59 | 
60 |
61 | SHAP force plot for encoder-decoder LSTM
62 | 
63 |
64 | _In the above plot, it doesn't plot anything for encoder-decoder based LSTM due to the addition of flatten layers which resulted in the computation of zero SHAP values_.
65 |
66 | ### LRP
67 |
68 | LRP is implemented for the `bi-directional LSTM` model as the current LRP implementation [3] supports only bi-directional LSTMs, which is also a deficit.
69 |
70 | 
71 | 
72 |
73 | ### LIME
74 |
75 | `Throws an exception!!!` LIME is proposed to generate explanations for any classifier. So, LIME API expects the input model to be a classifier (i.e.) the model should output probabilities with sigmoid (or) softmax output layer as activations. As our model is regression based, it has `ReLU` activation output layer, and does not output any probabilities.
76 | 
77 |
78 | ## Comparative analysis
79 |
80 | 
81 |
82 | ## Conclusions and Take aways
83 |
84 | - Based on the comparative analysis and implementation, gradient-based attribution methods serve as the `best-fit approach` for LSTM regression based models.
85 | - Out of the perturbation based methods, `SHAP` fit for our models, as it generates `global explanations` and works for any models other than classifiers whereas `LIME` does not work for regression models.
86 | - But SHAP is not a best-fit approach for neural networks, because all the perturbation based methods are proposed to work for `non-differentiable models`. They do work for neural networks as well but they are not optimized.
87 | - SHAP performs much `slower` than LRP for the LSTM network based on the [execution
88 | time](./result%20plots/comp%20table%20new.PNG).
89 | - Thus, `LSTM` network being a `differentiable model` (model with gradients) works best with gradient based methods. It can be concluded that `gradient-based methods` are the best fit methods exclusively for LSTM regression based neural networks.
90 |
91 | ## Future works
92 | - The research work can be extended to implement the other explainable methods like `integrated gradients`, `DeepLIFT` and so on, as they were not implemented in this work due to time constraints.
93 | - Current LRP implementation is supported only for bi-directional LSTM models. It can be extended to support multiple LSTM model architectures with additional functionalities.
94 | - `Optimal evaluation metrics` for evaluating the explanations generated by XAI methods for LSTM networks using time series data for anomaly detection is not proposed.
95 |
96 | ### _Note !!!_
97 | 1. _As this is a research project, the implementation and other details are demonstrated using `jupyter notebooks`. Also, the code is `not refactored`._
98 | 2. _This work was carried out between **June 2020 to March 2021**. All the tools and libraries used, implementation details were based on the ones available during that time frame._
99 |
100 | ## Citation
101 |
102 | If you want to use this work, please cite:
103 |
104 | ```
105 | title = {Explaining deep learning models for detecting anomalies in time-series data},
106 | author = {Jithin Sasikumar},
107 | month = "02",
108 | year = "2021",
109 | }
110 | ```
111 |
112 | ## References
113 |
114 | [1] https://github.com/khundman/telemanom
115 | [2] Kyle Hundman, Valentino Constantinou, Christopher Laporte, Ian Colwell, and Tom Soderstrom.Detecting spacecraft anomalies using lstms and nonparametric dynamic thresholding. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery amp; Data Mining, KDD ’18, page 387–395, New York, NY, USA, 2018. Association for Computing Machinery. ISBN 9781450355520. doi: 10.1145/3219819.3219845. URL https://doi.org/10.1145/3219819.3219845
116 | [3] Alexander Warnecke. Layerwise Relevance Propagation for LSTMs, 2021. URL https://github.com/alewarne/Layerwise-Relevance-Propagation-for-LSTMs. Accessed on: 2021-02-06. [Online].
117 | [4] https://www.kaggle.com/datasets/vinayak123tyagi/bearing-dataset
118 |
--------------------------------------------------------------------------------
/images/LSTM Models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/images/LSTM Models.png
--------------------------------------------------------------------------------
/images/bidi.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/images/bidi.PNG
--------------------------------------------------------------------------------
/images/encdects_16_ts.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/images/encdects_16_ts.PNG
--------------------------------------------------------------------------------
/images/encdects_1_ts.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/images/encdects_1_ts.PNG
--------------------------------------------------------------------------------
/images/stackedlstmts_16_ts.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/images/stackedlstmts_16_ts.PNG
--------------------------------------------------------------------------------
/images/vanilla lstmts_1_ts.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/images/vanilla lstmts_1_ts.PNG
--------------------------------------------------------------------------------
/models/bidi_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/models/bidi_model.h5
--------------------------------------------------------------------------------
/models/enc_dec_model_ts1.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/models/enc_dec_model_ts1.h5
--------------------------------------------------------------------------------
/models/enc_dec_model_ts16.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/models/enc_dec_model_ts16.h5
--------------------------------------------------------------------------------
/models/stacked_lstm_model_ts16_1.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/models/stacked_lstm_model_ts16_1.h5
--------------------------------------------------------------------------------
/models/stacked_lstm_model_ts16_2.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/models/stacked_lstm_model_ts16_2.h5
--------------------------------------------------------------------------------
/models/vanilla_lstm_model_ts1.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/models/vanilla_lstm_model_ts1.h5
--------------------------------------------------------------------------------
/result plots/Bar plot - LRP relevance score for 1 instance.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/result plots/Bar plot - LRP relevance score for 1 instance.PNG
--------------------------------------------------------------------------------
/result plots/LRP relevance score for all instances.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/result plots/LRP relevance score for all instances.PNG
--------------------------------------------------------------------------------
/result plots/comp table new.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/result plots/comp table new.PNG
--------------------------------------------------------------------------------
/result plots/shap_bidi_lstm.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/result plots/shap_bidi_lstm.PNG
--------------------------------------------------------------------------------
/result plots/shap_enc_dec_ts1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/result plots/shap_enc_dec_ts1.png
--------------------------------------------------------------------------------
/result plots/shap_stacked_ lstm_ts16_1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/result plots/shap_stacked_ lstm_ts16_1.PNG
--------------------------------------------------------------------------------
/result plots/shap_vanilla_ts1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jithsaavvy/Explaining-deep-learning-models-for-detecting-anomalies-in-time-series-data-RnD-project/d4362c1d6c6ec98412e872020ab26aa7d1283e04/result plots/shap_vanilla_ts1.png
--------------------------------------------------------------------------------