├── README.md ├── LICENSE ├── .gitignore └── lstm_autoencoder_classifier.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # lstm_autoencoder_classifier 2 | An LSTM Autoencoder for rare event classification 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 cran2367 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | 3 | data/ 4 | *.DS_Store 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 
-------------------------------------------------------------------------------- /lstm_autoencoder_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# LSTM Autoencoder for Rare Event Binary Classification\n", 8 | "\n", 9 | "This is a continuation of the regular autoencoder for rare event classification presented in\n", 10 | "https://towardsdatascience.com/extreme-rare-event-classification-using-autoencoders-in-keras-a565b386f098\n", 11 | "with the code available at\n", 12 | "https://github.com/cran2367/autoencoder_classifier/blob/master/autoencoder_classifier.ipynb\n", 13 | "Here we will show an implementation of building a binary classifier using LSTM Autoencoders. \n", 14 | "Similar to the previous post, the purpose is to show the implementation steps. The autoencoder can be tuned further to improve performance.\n", 15 | "\n", 16 | "LSTM requires closer attention to preparing the data. Here we have all the steps, and a few tests to validate the data preparation.\n", 17 | "\n", 18 | "The dataset used here is taken from:\n", 19 | "\n", 20 | "**Dataset: Rare Event Classification in Multivariate Time Series** https://arxiv.org/abs/1809.10717 (please cite this article, if using the dataset)."
21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 51, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "%matplotlib inline\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import seaborn as sns\n", 32 | "\n", 33 | "import pandas as pd\n", 34 | "import numpy as np\n", 35 | "from pylab import rcParams\n", 36 | "\n", 37 | "import tensorflow as tf\n", 38 | "from keras import optimizers, Sequential\n", 39 | "from keras.models import Model\n", 40 | "from keras.utils import plot_model\n", 41 | "from keras.layers import Dense, LSTM, RepeatVector, TimeDistributed\n", 42 | "from keras.callbacks import ModelCheckpoint, TensorBoard\n", 43 | "\n", 44 | "from sklearn.preprocessing import StandardScaler\n", 45 | "from sklearn.model_selection import train_test_split\n", 46 | "from sklearn.metrics import confusion_matrix, precision_recall_curve\n", 47 | "from sklearn.metrics import recall_score, classification_report, auc, roc_curve\n", 48 | "from sklearn.metrics import precision_recall_fscore_support, f1_score\n", 49 | "\n", 50 | "from numpy.random import seed\n", 51 | "seed(7)\n", 52 | "from tensorflow import set_random_seed\n", 53 | "set_random_seed(11)\n", 54 | "\n", 55 | "from sklearn.model_selection import train_test_split\n", 56 | "\n", 57 | "SEED = 123 #used to help randomly select the data points\n", 58 | "DATA_SPLIT_PCT = 0.2\n", 59 | "\n", 60 | "rcParams['figure.figsize'] = 8, 6\n", 61 | "LABELS = [\"Normal\",\"Break\"]" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Reading and preparing data" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "The data is taken from https://arxiv.org/abs/1809.10717. Please use this source for any citation." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 52, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/html": [ 86 | "
\n", 87 | "\n", 100 | "\n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | "
timeyx1x2x3x4x5x6x7x8...x52x53x54x55x56x57x58x59x60x61
05/1/99 0:0000.376665-4.596435-4.09575613.497687-0.118830-20.6698830.000732-0.061114...10.0917210.053279-4.936434-24.59014618.5154363.4734000.0334440.9532190.0060760
15/1/99 0:0200.475720-4.542502-4.01835916.230659-0.128733-18.7580790.000732-0.061114...10.0958710.062801-4.937179-32.41326622.7600652.6829330.0335361.0905020.0060830
25/1/99 0:0400.363848-4.681394-4.35314714.127998-0.138636-17.8366320.010803-0.061114...10.1002650.072322-4.937924-34.18377427.0046633.5374870.0336291.8405400.0060900
35/1/99 0:0600.301590-4.758934-4.02361213.161567-0.148142-18.5176010.002075-0.061114...10.1046600.081600-4.938669-35.95428121.6724493.9860950.0337212.5548800.0060970
45/1/99 0:0800.265578-4.749928-4.33315015.267340-0.155314-17.5059130.000732-0.061114...10.1090540.091121-4.939414-37.72478921.9072513.6015730.0337771.4104940.0061050
\n", 250 | "

5 rows × 63 columns

\n", 251 | "
" 252 | ], 253 | "text/plain": [ 254 | " time y x1 x2 x3 x4 x5 \\\n", 255 | "0 5/1/99 0:00 0 0.376665 -4.596435 -4.095756 13.497687 -0.118830 \n", 256 | "1 5/1/99 0:02 0 0.475720 -4.542502 -4.018359 16.230659 -0.128733 \n", 257 | "2 5/1/99 0:04 0 0.363848 -4.681394 -4.353147 14.127998 -0.138636 \n", 258 | "3 5/1/99 0:06 0 0.301590 -4.758934 -4.023612 13.161567 -0.148142 \n", 259 | "4 5/1/99 0:08 0 0.265578 -4.749928 -4.333150 15.267340 -0.155314 \n", 260 | "\n", 261 | " x6 x7 x8 ... x52 x53 x54 \\\n", 262 | "0 -20.669883 0.000732 -0.061114 ... 10.091721 0.053279 -4.936434 \n", 263 | "1 -18.758079 0.000732 -0.061114 ... 10.095871 0.062801 -4.937179 \n", 264 | "2 -17.836632 0.010803 -0.061114 ... 10.100265 0.072322 -4.937924 \n", 265 | "3 -18.517601 0.002075 -0.061114 ... 10.104660 0.081600 -4.938669 \n", 266 | "4 -17.505913 0.000732 -0.061114 ... 10.109054 0.091121 -4.939414 \n", 267 | "\n", 268 | " x55 x56 x57 x58 x59 x60 x61 \n", 269 | "0 -24.590146 18.515436 3.473400 0.033444 0.953219 0.006076 0 \n", 270 | "1 -32.413266 22.760065 2.682933 0.033536 1.090502 0.006083 0 \n", 271 | "2 -34.183774 27.004663 3.537487 0.033629 1.840540 0.006090 0 \n", 272 | "3 -35.954281 21.672449 3.986095 0.033721 2.554880 0.006097 0 \n", 273 | "4 -37.724789 21.907251 3.601573 0.033777 1.410494 0.006105 0 \n", 274 | "\n", 275 | "[5 rows x 63 columns]" 276 | ] 277 | }, 278 | "execution_count": 52, 279 | "metadata": {}, 280 | "output_type": "execute_result" 281 | } 282 | ], 283 | "source": [ 284 | "'''\n", 285 | "Download data here:\n", 286 | "https://docs.google.com/forms/d/e/1FAIpQLSdyUk3lfDl7I5KYK_pw285LCApc-_RcoC0Tf9cnDnZ_TWzPAw/viewform\n", 287 | "'''\n", 288 | "df = pd.read_csv(\"data/processminer-rare-event-mts - data.csv\") \n", 289 | "df.head(n=5) # visualize the data." 
290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "### Shift the data\n", 297 | "\n", 298 | "This is a timeseries data in which we have to predict the event (y = 1) ahead in time. In this data, consecutive rows are 2 minutes apart. We will shift the labels in column `y` by 2 rows to do a 4 minute ahead prediction." 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 53, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "sign = lambda x: (1, -1)[x < 0]\n", 308 | "\n", 309 | "def curve_shift(df, shift_by):\n", 310 | " '''\n", 311 | " This function will shift the binary labels in a dataframe.\n", 312 | " The curve shift will be with respect to the 1s. \n", 313 | " For example, if shift is -2, the following process\n", 314 | " will happen: if row n is labeled as 1, then\n", 315 | " - Make row (n+shift_by):(n+shift_by-1) = 1.\n", 316 | " - Remove row n.\n", 317 | " i.e. the labels will be shifted up to 2 rows up.\n", 318 | " \n", 319 | " Inputs:\n", 320 | " df A pandas dataframe with a binary labeled column. 
\n", 321 | " This labeled column should be named as 'y'.\n", 322 | " shift_by An integer denoting the number of rows to shift.\n", 323 | " \n", 324 | " Output\n", 325 | " df A dataframe with the binary labels shifted by shift.\n", 326 | " '''\n", 327 | "\n", 328 | " vector = df['y'].copy()\n", 329 | " for s in range(abs(shift_by)):\n", 330 | " tmp = vector.shift(sign(shift_by))\n", 331 | " tmp = tmp.fillna(0)\n", 332 | " vector += tmp\n", 333 | " labelcol = 'y'\n", 334 | " # Add vector to the df\n", 335 | " df.insert(loc=0, column=labelcol+'tmp', value=vector)\n", 336 | " # Remove the rows with labelcol == 1.\n", 337 | " df = df.drop(df[df[labelcol] == 1].index)\n", 338 | " # Drop labelcol and rename the tmp col as labelcol\n", 339 | " df = df.drop(labelcol, axis=1)\n", 340 | " df = df.rename(columns={labelcol+'tmp': labelcol})\n", 341 | " # Make the labelcol binary\n", 342 | " df.loc[df[labelcol] > 0, labelcol] = 1\n", 343 | "\n", 344 | " return df" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 54, 350 | "metadata": {}, 351 | "outputs": [ 352 | { 353 | "name": "stdout", 354 | "output_type": "stream", 355 | "text": [ 356 | "Before shifting\n" 357 | ] 358 | }, 359 | { 360 | "data": { 361 | "text/html": [ 362 | "
\n", 363 | "\n", 376 | "\n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | "
timeyx1x2x3
2565/1/99 8:3201.016235-4.058394-1.097158
2575/1/99 8:3401.005602-3.876199-1.074373
2585/1/99 8:3600.933933-3.868467-1.249954
2595/1/99 8:3810.892311-13.332664-10.006578
2605/1/99 10:5000.020062-3.987897-1.248529
\n", 430 | "
" 431 | ], 432 | "text/plain": [ 433 | " time y x1 x2 x3\n", 434 | "256 5/1/99 8:32 0 1.016235 -4.058394 -1.097158\n", 435 | "257 5/1/99 8:34 0 1.005602 -3.876199 -1.074373\n", 436 | "258 5/1/99 8:36 0 0.933933 -3.868467 -1.249954\n", 437 | "259 5/1/99 8:38 1 0.892311 -13.332664 -10.006578\n", 438 | "260 5/1/99 10:50 0 0.020062 -3.987897 -1.248529" 439 | ] 440 | }, 441 | "metadata": {}, 442 | "output_type": "display_data" 443 | }, 444 | { 445 | "name": "stdout", 446 | "output_type": "stream", 447 | "text": [ 448 | "After shifting\n" 449 | ] 450 | }, 451 | { 452 | "data": { 453 | "text/html": [ 454 | "
\n", 455 | "\n", 468 | "\n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | "
ytimex1x2x3
2550.05/1/99 8:300.997107-3.865720-1.133779
2560.05/1/99 8:321.016235-4.058394-1.097158
2571.05/1/99 8:341.005602-3.876199-1.074373
2581.05/1/99 8:360.933933-3.868467-1.249954
2600.05/1/99 10:500.020062-3.987897-1.248529
\n", 522 | "
" 523 | ], 524 | "text/plain": [ 525 | " y time x1 x2 x3\n", 526 | "255 0.0 5/1/99 8:30 0.997107 -3.865720 -1.133779\n", 527 | "256 0.0 5/1/99 8:32 1.016235 -4.058394 -1.097158\n", 528 | "257 1.0 5/1/99 8:34 1.005602 -3.876199 -1.074373\n", 529 | "258 1.0 5/1/99 8:36 0.933933 -3.868467 -1.249954\n", 530 | "260 0.0 5/1/99 10:50 0.020062 -3.987897 -1.248529" 531 | ] 532 | }, 533 | "metadata": {}, 534 | "output_type": "display_data" 535 | } 536 | ], 537 | "source": [ 538 | "'''\n", 539 | "Shift the data by 2 units, equal to 4 minutes.\n", 540 | "\n", 541 | "Test: Testing whether the shift happened correctly.\n", 542 | "'''\n", 543 | "print('Before shifting') # Positive labeled rows before shifting.\n", 544 | "one_indexes = df.index[df['y'] == 1]\n", 545 | "display(df.iloc[(one_indexes[0]-3):(one_indexes[0]+2), 0:5].head(n=5))\n", 546 | "\n", 547 | "# Shift the response column y by 2 rows to do a 4-min ahead prediction.\n", 548 | "df = curve_shift(df, shift_by = -2)\n", 549 | "\n", 550 | "print('After shifting') # Validating if the shift happened correctly.\n", 551 | "display(df.iloc[(one_indexes[0]-4):(one_indexes[0]+1), 0:5].head(n=5)) " 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": {}, 557 | "source": [ 558 | "Note that we moved the positive label at 5/1/99 8:38 to the t-1 and t-2 timestamps, and dropped row t. There is a time difference of more than 2 minutes between a break row and the next row because consecutive break rows were deleted from the data. This was done to prevent a classification model from learning to predict a break after it has already happened. Refer to https://arxiv.org/abs/1809.10717 for details."
559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 55, 564 | "metadata": {}, 565 | "outputs": [], 566 | "source": [ 567 | "# Remove time column, and the categorical columns\n", 568 | "df = df.drop(['time', 'x28', 'x61'], axis=1)" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "metadata": {}, 574 | "source": [ 575 | "# Prepare data for LSTM models" 576 | ] 577 | }, 578 | { 579 | "cell_type": "markdown", 580 | "metadata": {}, 581 | "source": [ 582 | "LSTM is a bit more demanding than other models. A significant amount of time and attention goes into preparing data in the form an LSTM expects.\n", 583 | "\n", 584 | "First, we will create the 3-dimensional arrays of shape: (samples x timesteps x features). Samples is the number of data points. Timesteps is the number of time steps we look back at any time t to make a prediction. This is also referred to as the lookback period. Features is the number of features the data has, in other words, the number of predictors in multivariate data."
585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 56, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "input_X = df.loc[:, df.columns != 'y'].values # converts the df to a numpy array\n", 594 | "input_y = df['y'].values\n", 595 | "\n", 596 | "n_features = input_X.shape[1] # number of features" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 57, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [ 605 | "def temporalize(X, y, lookback):\n", 606 | " output_X = []\n", 607 | " output_y = []\n", 608 | " for i in range(len(X)-lookback-1):\n", 609 | " t = []\n", 610 | " for j in range(1,lookback+1):\n", 611 | " # Gather past records upto the lookback period\n", 612 | " t.append(X[[(i+j+1)], :])\n", 613 | " output_X.append(t)\n", 614 | " output_y.append(y[i+lookback+1])\n", 615 | " return output_X, output_y" 616 | ] 617 | }, 618 | { 619 | "cell_type": "markdown", 620 | "metadata": {}, 621 | "source": [ 622 | "With an LSTM, to make a prediction at any time t, we look at the `lookback` most recent rows, i.e. rows (t-lookback+1) through t inclusive. The following example shows how the input data are transformed by the `temporalize` function with `lookback=5`. For the modeling, we may use a longer lookback." 623 | ] 624 | }, 625 | { 626 | "cell_type": "code", 627 | "execution_count": 58, 628 | "metadata": {}, 629 | "outputs": [ 630 | { 631 | "name": "stdout", 632 | "output_type": "stream", 633 | "text": [ 634 | "First instance of y = 1 in the original data\n" 635 | ] 636 | }, 637 | { 638 | "data": { 639 | "text/html": [ 640 | "
\n", 641 | "\n", 654 | "\n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " 
\n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | "
yx1x2x3x4x5x6x7x8x9...x51x52x53x54x55x56x57x58x59x60
2520.00.987078-4.025989-1.2102050.8996030.45033814.0988540.000732-0.051043-0.059966...29.98462411.248703-0.752385-5.014893-67.45403766.2325684.1142690.0337264.8450870.007776
2530.00.921726-3.728572-1.230373-1.5987180.22717814.5946120.000061-0.051043-0.040129...29.98462411.253342-0.752385-5.014987-58.02947766.3100223.5374870.0325184.9695000.007760
2540.00.975947-3.913736-1.3046820.5619870.00403414.6305320.000732-0.051043-0.040129...29.98462411.257736-0.752385-5.015081-61.78374971.9173523.4734000.0313102.9814320.007743
2550.00.997107-3.865720-1.1337790.377295-0.21912614.6664200.000732-0.061114-0.040129...29.98462411.262375-0.752385-5.015176-70.15179173.8769773.4734000.0307762.5635930.007727
2560.01.016235-4.058394-1.0971582.327307-0.44228614.7023090.000732-0.061114-0.040129...29.98462411.267013-0.752385-5.015270-60.88470172.1889284.1142690.0311862.9824540.007711
2571.01.005602-3.876199-1.0743730.844397-0.55305014.7382280.000732-0.061114-0.030057...29.98462411.271652-0.752385-5.015364-69.55389170.5008794.0501820.0315963.7467140.007695
\n", 828 | "

6 rows × 60 columns

\n", 829 | "
" 830 | ], 831 | "text/plain": [ 832 | " y x1 x2 x3 x4 x5 x6 \\\n", 833 | "252 0.0 0.987078 -4.025989 -1.210205 0.899603 0.450338 14.098854 \n", 834 | "253 0.0 0.921726 -3.728572 -1.230373 -1.598718 0.227178 14.594612 \n", 835 | "254 0.0 0.975947 -3.913736 -1.304682 0.561987 0.004034 14.630532 \n", 836 | "255 0.0 0.997107 -3.865720 -1.133779 0.377295 -0.219126 14.666420 \n", 837 | "256 0.0 1.016235 -4.058394 -1.097158 2.327307 -0.442286 14.702309 \n", 838 | "257 1.0 1.005602 -3.876199 -1.074373 0.844397 -0.553050 14.738228 \n", 839 | "\n", 840 | " x7 x8 x9 ... x51 x52 x53 \\\n", 841 | "252 0.000732 -0.051043 -0.059966 ... 29.984624 11.248703 -0.752385 \n", 842 | "253 0.000061 -0.051043 -0.040129 ... 29.984624 11.253342 -0.752385 \n", 843 | "254 0.000732 -0.051043 -0.040129 ... 29.984624 11.257736 -0.752385 \n", 844 | "255 0.000732 -0.061114 -0.040129 ... 29.984624 11.262375 -0.752385 \n", 845 | "256 0.000732 -0.061114 -0.040129 ... 29.984624 11.267013 -0.752385 \n", 846 | "257 0.000732 -0.061114 -0.030057 ... 29.984624 11.271652 -0.752385 \n", 847 | "\n", 848 | " x54 x55 x56 x57 x58 x59 x60 \n", 849 | "252 -5.014893 -67.454037 66.232568 4.114269 0.033726 4.845087 0.007776 \n", 850 | "253 -5.014987 -58.029477 66.310022 3.537487 0.032518 4.969500 0.007760 \n", 851 | "254 -5.015081 -61.783749 71.917352 3.473400 0.031310 2.981432 0.007743 \n", 852 | "255 -5.015176 -70.151791 73.876977 3.473400 0.030776 2.563593 0.007727 \n", 853 | "256 -5.015270 -60.884701 72.188928 4.114269 0.031186 2.982454 0.007711 \n", 854 | "257 -5.015364 -69.553891 70.500879 4.050182 0.031596 3.746714 0.007695 \n", 855 | "\n", 856 | "[6 rows x 60 columns]" 857 | ] 858 | }, 859 | "metadata": {}, 860 | "output_type": "display_data" 861 | }, 862 | { 863 | "name": "stdout", 864 | "output_type": "stream", 865 | "text": [ 866 | "For the same instance of y = 1, we are keeping past 5 samples in the 3D predictor array, X.\n" 867 | ] 868 | }, 869 | { 870 | "data": { 871 | "text/html": [ 872 | "
\n", 873 | "\n", 886 | "\n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | "
0123456789...49505152535455565758
00.921726-3.728572-1.230373-1.5987180.22717814.5946120.000061-0.051043-0.0401290.001791...29.98462411.253342-0.752385-5.014987-58.02947766.3100223.5374870.0325184.9695000.007760
10.975947-3.913736-1.3046820.5619870.00403414.6305320.000732-0.051043-0.0401290.001791...29.98462411.257736-0.752385-5.015081-61.78374971.9173523.4734000.0313102.9814320.007743
20.997107-3.865720-1.1337790.377295-0.21912614.6664200.000732-0.061114-0.0401290.001791...29.98462411.262375-0.752385-5.015176-70.15179173.8769773.4734000.0307762.5635930.007727
31.016235-4.058394-1.0971582.327307-0.44228614.7023090.000732-0.061114-0.0401290.001791...29.98462411.267013-0.752385-5.015270-60.88470172.1889284.1142690.0311862.9824540.007711
41.005602-3.876199-1.0743730.844397-0.55305014.7382280.000732-0.061114-0.0300570.001791...29.98462411.271652-0.752385-5.015364-69.55389170.5008794.0501820.0315963.7467140.007695
\n", 1036 | "

5 rows × 59 columns

\n", 1037 | "
" 1038 | ], 1039 | "text/plain": [ 1040 | " 0 1 2 3 4 5 6 \\\n", 1041 | "0 0.921726 -3.728572 -1.230373 -1.598718 0.227178 14.594612 0.000061 \n", 1042 | "1 0.975947 -3.913736 -1.304682 0.561987 0.004034 14.630532 0.000732 \n", 1043 | "2 0.997107 -3.865720 -1.133779 0.377295 -0.219126 14.666420 0.000732 \n", 1044 | "3 1.016235 -4.058394 -1.097158 2.327307 -0.442286 14.702309 0.000732 \n", 1045 | "4 1.005602 -3.876199 -1.074373 0.844397 -0.553050 14.738228 0.000732 \n", 1046 | "\n", 1047 | " 7 8 9 ... 49 50 51 \\\n", 1048 | "0 -0.051043 -0.040129 0.001791 ... 29.984624 11.253342 -0.752385 \n", 1049 | "1 -0.051043 -0.040129 0.001791 ... 29.984624 11.257736 -0.752385 \n", 1050 | "2 -0.061114 -0.040129 0.001791 ... 29.984624 11.262375 -0.752385 \n", 1051 | "3 -0.061114 -0.040129 0.001791 ... 29.984624 11.267013 -0.752385 \n", 1052 | "4 -0.061114 -0.030057 0.001791 ... 29.984624 11.271652 -0.752385 \n", 1053 | "\n", 1054 | " 52 53 54 55 56 57 58 \n", 1055 | "0 -5.014987 -58.029477 66.310022 3.537487 0.032518 4.969500 0.007760 \n", 1056 | "1 -5.015081 -61.783749 71.917352 3.473400 0.031310 2.981432 0.007743 \n", 1057 | "2 -5.015176 -70.151791 73.876977 3.473400 0.030776 2.563593 0.007727 \n", 1058 | "3 -5.015270 -60.884701 72.188928 4.114269 0.031186 2.982454 0.007711 \n", 1059 | "4 -5.015364 -69.553891 70.500879 4.050182 0.031596 3.746714 0.007695 \n", 1060 | "\n", 1061 | "[5 rows x 59 columns]" 1062 | ] 1063 | }, 1064 | "metadata": {}, 1065 | "output_type": "display_data" 1066 | } 1067 | ], 1068 | "source": [ 1069 | "'''\n", 1070 | "Test: The 3D tensors (arrays) for LSTM are forming correctly.\n", 1071 | "'''\n", 1072 | "print('First instance of y = 1 in the original data')\n", 1073 | "display(df.iloc[(np.where(np.array(input_y) == 1)[0][0]-5):(np.where(np.array(input_y) == 1)[0][0]+1), ])\n", 1074 | "\n", 1075 | "lookback = 5 # Equivalent to 10 min of past data.\n", 1076 | "# Temporalize the data\n", 1077 | "X, y = temporalize(X = input_X, y = input_y, lookback = 
lookback)\n", 1078 | "\n", 1079 | "print('For the same instance of y = 1, we are keeping past 5 samples in the 3D predictor array, X.')\n", 1080 | "display(pd.DataFrame(np.concatenate(X[np.where(np.array(y) == 1)[0][0]], axis=0 ))) " 1081 | ] 1082 | }, 1083 | { 1084 | "cell_type": "markdown", 1085 | "metadata": {}, 1086 | "source": [ 1087 | "The two tables are the same. This testifies that we are correctly taking 5 samples (= lookback), X(t):X(t-5) to predict y(t)." 1088 | ] 1089 | }, 1090 | { 1091 | "cell_type": "markdown", 1092 | "metadata": {}, 1093 | "source": [ 1094 | "### Divide the data into train, valid, and test" 1095 | ] 1096 | }, 1097 | { 1098 | "cell_type": "code", 1099 | "execution_count": 59, 1100 | "metadata": {}, 1101 | "outputs": [], 1102 | "source": [ 1103 | "X_train, X_test, y_train, y_test = train_test_split(np.array(X), np.array(y), test_size=DATA_SPLIT_PCT, random_state=SEED)\n", 1104 | "X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=DATA_SPLIT_PCT, random_state=SEED)" 1105 | ] 1106 | }, 1107 | { 1108 | "cell_type": "code", 1109 | "execution_count": 60, 1110 | "metadata": {}, 1111 | "outputs": [ 1112 | { 1113 | "data": { 1114 | "text/plain": [ 1115 | "(11691, 5, 1, 59)" 1116 | ] 1117 | }, 1118 | "execution_count": 60, 1119 | "metadata": {}, 1120 | "output_type": "execute_result" 1121 | } 1122 | ], 1123 | "source": [ 1124 | "X_train.shape" 1125 | ] 1126 | }, 1127 | { 1128 | "cell_type": "code", 1129 | "execution_count": 61, 1130 | "metadata": {}, 1131 | "outputs": [], 1132 | "source": [ 1133 | "X_train_y0 = X_train[y_train==0]\n", 1134 | "X_train_y1 = X_train[y_train==1]\n", 1135 | "\n", 1136 | "X_valid_y0 = X_valid[y_valid==0]\n", 1137 | "X_valid_y1 = X_valid[y_valid==1]" 1138 | ] 1139 | }, 1140 | { 1141 | "cell_type": "code", 1142 | "execution_count": 62, 1143 | "metadata": {}, 1144 | "outputs": [ 1145 | { 1146 | "data": { 1147 | "text/plain": [ 1148 | "(11536, 5, 1, 59)" 1149 | ] 1150 | }, 1151 | 
"execution_count": 62, 1152 | "metadata": {}, 1153 | "output_type": "execute_result" 1154 | } 1155 | ], 1156 | "source": [ 1157 | "X_train_y0.shape" 1158 | ] 1159 | }, 1160 | { 1161 | "cell_type": "markdown", 1162 | "metadata": {}, 1163 | "source": [ 1164 | "#### Reshaping the data\n", 1165 | "The tensors we have here are 4-dimensional. We will reshape them into the desired 3 dimensions: sample x lookback x features." 1166 | ] 1167 | }, 1168 | { 1169 | "cell_type": "code", 1170 | "execution_count": 63, 1171 | "metadata": {}, 1172 | "outputs": [], 1173 | "source": [ 1174 | "X_train = X_train.reshape(X_train.shape[0], lookback, n_features)\n", 1175 | "X_train_y0 = X_train_y0.reshape(X_train_y0.shape[0], lookback, n_features)\n", 1176 | "X_train_y1 = X_train_y1.reshape(X_train_y1.shape[0], lookback, n_features)\n", 1177 | "\n", 1178 | "X_test = X_test.reshape(X_test.shape[0], lookback, n_features)\n", 1179 | "\n", 1180 | "X_valid = X_valid.reshape(X_valid.shape[0], lookback, n_features)\n", 1181 | "X_valid_y0 = X_valid_y0.reshape(X_valid_y0.shape[0], lookback, n_features)\n", 1182 | "X_valid_y1 = X_valid_y1.reshape(X_valid_y1.shape[0], lookback, n_features)" 1183 | ] 1184 | }, 1185 | { 1186 | "cell_type": "markdown", 1187 | "metadata": {}, 1188 | "source": [ 1189 | "### Standardize the data\n", 1190 | "It is usually better to use standardized data (each feature rescaled to mean 0 and sd 1) for autoencoders.\n", 1191 | "\n", 1192 | "One common mistake is to normalize the entire data set and then split it into train and test sets. This is not correct: test data should remain completely unseen during modeling. We should normalize the test data using the feature summary statistics computed from the training data. For normalization, these statistics are the mean and variance of each feature. \n", 1193 | "\n", 1194 | "The same logic should be used for the validation set. 
This makes the model more stable for a test data.\n", 1195 | "\n", 1196 | "To do this, we will require two UDFs.\n", 1197 | "\n", 1198 | "- `flatten`: This function will re-create the original 2D array from which the 3D arrays were created. This function is the inverse of `temporalize`, meaning `X = flatten(temporalize(X))`.\n", 1199 | "- `scale`: This function will scale a 3D array that we created as inputs to the LSTM." 1200 | ] 1201 | }, 1202 | { 1203 | "cell_type": "code", 1204 | "execution_count": 64, 1205 | "metadata": {}, 1206 | "outputs": [], 1207 | "source": [ 1208 | "def flatten(X):\n", 1209 | " '''\n", 1210 | " Flatten a 3D array.\n", 1211 | " \n", 1212 | " Input\n", 1213 | " X A 3D array for lstm, where the array is sample x timesteps x features.\n", 1214 | " \n", 1215 | " Output\n", 1216 | " flattened_X A 2D array, sample x features.\n", 1217 | " '''\n", 1218 | " flattened_X = np.empty((X.shape[0], X.shape[2])) # sample x features array.\n", 1219 | " for i in range(X.shape[0]):\n", 1220 | " flattened_X[i] = X[i, (X.shape[1]-1), :]\n", 1221 | " return(flattened_X)\n", 1222 | "\n", 1223 | "def scale(X, scaler):\n", 1224 | " '''\n", 1225 | " Scale 3D array.\n", 1226 | "\n", 1227 | " Inputs\n", 1228 | " X A 3D array for lstm, where the array is sample x timesteps x features.\n", 1229 | " scaler A scaler object, e.g., sklearn.preprocessing.StandardScaler, sklearn.preprocessing.normalize\n", 1230 | " \n", 1231 | " Output\n", 1232 | " X Scaled 3D array.\n", 1233 | " '''\n", 1234 | " for i in range(X.shape[0]):\n", 1235 | " X[i, :, :] = scaler.transform(X[i, :, :])\n", 1236 | " \n", 1237 | " return X" 1238 | ] 1239 | }, 1240 | { 1241 | "cell_type": "code", 1242 | "execution_count": 65, 1243 | "metadata": {}, 1244 | "outputs": [], 1245 | "source": [ 1246 | "# Initialize a scaler using the training data.\n", 1247 | "scaler = StandardScaler().fit(flatten(X_train_y0))" 1248 | ] 1249 | }, 1250 | { 1251 | "cell_type": "code", 1252 | "execution_count": 66, 1253 | 
"metadata": {}, 1254 | "outputs": [], 1255 | "source": [ 1256 | "X_train_y0_scaled = scale(X_train_y0, scaler)\n", 1257 | "X_train_y1_scaled = scale(X_train_y1, scaler)\n", 1258 | "X_train_scaled = scale(X_train, scaler)" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "code", 1263 | "execution_count": 67, 1264 | "metadata": {}, 1265 | "outputs": [ 1266 | { 1267 | "name": "stdout", 1268 | "output_type": "stream", 1269 | "text": [ 1270 | "colwise mean [ 0. -0. 0. 0. 0. -0. 0. 0. 0. 0. -0. -0. -0. -0. -0. 0. -0. -0.\n", 1271 | " -0. -0. 0. 0. -0. -0. 0. 0. -0. 0. -0. 0. 0. 0. -0. -0. -0. 0.\n", 1272 | " 0. 0. -0. 0. 0. 0. -0. -0. 0. 0. 0. 0. 0. 0. 0. -0. -0. 0.\n", 1273 | " 0. -0. -0. 0. 0.]\n", 1274 | "colwise variance [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", 1275 | " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", 1276 | " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n" 1277 | ] 1278 | } 1279 | ], 1280 | "source": [ 1281 | "'''\n", 1282 | "Test: Check if the scaling is correct.\n", 1283 | "\n", 1284 | "The test succeeds if all the column means \n", 1285 | "and variances are 0 and 1, respectively, after\n", 1286 | "flattening.\n", 1287 | "'''\n", 1288 | "a = flatten(X_train_y0_scaled)\n", 1289 | "print('colwise mean', np.mean(a, axis=0).round(6))\n", 1290 | "print('colwise variance', np.var(a, axis=0))" 1291 | ] 1292 | }, 1293 | { 1294 | "cell_type": "markdown", 1295 | "metadata": {}, 1296 | "source": [ 1297 | "The test succeeded. Now we will _scale_ the validation and test sets." 
1298 | ] 1299 | }, 1300 | { 1301 | "cell_type": "code", 1302 | "execution_count": 68, 1303 | "metadata": {}, 1304 | "outputs": [], 1305 | "source": [ 1306 | "X_valid_scaled = scale(X_valid, scaler)\n", 1307 | "X_valid_y0_scaled = scale(X_valid_y0, scaler)\n", 1308 | "\n", 1309 | "X_test_scaled = scale(X_test, scaler)" 1310 | ] 1311 | }, 1312 | { 1313 | "cell_type": "markdown", 1314 | "metadata": {}, 1315 | "source": [ 1316 | "## LSTM Autoencoder training" 1317 | ] 1318 | }, 1319 | { 1320 | "cell_type": "markdown", 1321 | "metadata": {}, 1322 | "source": [ 1323 | "First we will initialize the Autoencoder architecture. We are building a simple autoencoder. More complex architectures and other configurations should be explored." 1324 | ] 1325 | }, 1326 | { 1327 | "cell_type": "code", 1328 | "execution_count": 69, 1329 | "metadata": {}, 1330 | "outputs": [], 1331 | "source": [ 1332 | "timesteps = X_train_y0_scaled.shape[1] # equal to the lookback\n", 1333 | "n_features = X_train_y0_scaled.shape[2] # 59\n", 1334 | "\n", 1335 | "epochs = 200\n", 1336 | "batch = 64\n", 1337 | "lr = 0.0001" 1338 | ] 1339 | }, 1340 | { 1341 | "cell_type": "code", 1342 | "execution_count": 93, 1343 | "metadata": {}, 1344 | "outputs": [ 1345 | { 1346 | "name": "stdout", 1347 | "output_type": "stream", 1348 | "text": [ 1349 | "_________________________________________________________________\n", 1350 | "Layer (type) Output Shape Param # \n", 1351 | "=================================================================\n", 1352 | "lstm_23 (LSTM) (None, 5, 32) 11776 \n", 1353 | "_________________________________________________________________\n", 1354 | "lstm_24 (LSTM) (None, 16) 3136 \n", 1355 | "_________________________________________________________________\n", 1356 | "repeat_vector_7 (RepeatVecto (None, 5, 16) 0 \n", 1357 | "_________________________________________________________________\n", 1358 | "lstm_25 (LSTM) (None, 5, 16) 2112 \n", 1359 | 
"_________________________________________________________________\n", 1360 | "lstm_26 (LSTM) (None, 5, 32) 6272 \n", 1361 | "_________________________________________________________________\n", 1362 | "time_distributed_6 (TimeDist (None, 5, 59) 1947 \n", 1363 | "=================================================================\n", 1364 | "Total params: 25,243\n", 1365 | "Trainable params: 25,243\n", 1366 | "Non-trainable params: 0\n", 1367 | "_________________________________________________________________\n" 1368 | ] 1369 | } 1370 | ], 1371 | "source": [ 1372 | "lstm_autoencoder = Sequential()\n", 1373 | "# Encoder\n", 1374 | "lstm_autoencoder.add(LSTM(32, activation='relu', input_shape=(timesteps, n_features), return_sequences=True))\n", 1375 | "lstm_autoencoder.add(LSTM(16, activation='relu', return_sequences=False))\n", 1376 | "lstm_autoencoder.add(RepeatVector(timesteps))\n", 1377 | "# Decoder\n", 1378 | "lstm_autoencoder.add(LSTM(16, activation='relu', return_sequences=True))\n", 1379 | "lstm_autoencoder.add(LSTM(32, activation='relu', return_sequences=True))\n", 1380 | "lstm_autoencoder.add(TimeDistributed(Dense(n_features)))\n", 1381 | "\n", 1382 | "lstm_autoencoder.summary()" 1383 | ] 1384 | }, 1385 | { 1386 | "cell_type": "markdown", 1387 | "metadata": {}, 1388 | "source": [ 1389 | "As a rule-of-thumb, look at the number of parameters. If not using any regularization, keep this less than the number of samples. If using regularization, depending on the degree of regularization you can let more parameters in the model that is greater than the sample size. For example, if using dropout with 0.5, you can have up to double the sample size (loosely speaking)." 
1390 | ] 1391 | }, 1392 | { 1393 | "cell_type": "code", 1394 | "execution_count": 94, 1395 | "metadata": {}, 1396 | "outputs": [ 1397 | { 1398 | "name": "stdout", 1399 | "output_type": "stream", 1400 | "text": [ 1401 | "Train on 11536 samples, validate on 2880 samples\n", 1402 | "Epoch 1/200\n", 1403 | " - 6s - loss: 0.9916 - val_loss: 1.0049\n", 1404 | "Epoch 2/200\n", 1405 | " - 2s - loss: 0.9459 - val_loss: 0.9127\n", 1406 | "Epoch 3/200\n", 1407 | " - 2s - loss: 0.8240 - val_loss: 0.7706\n", 1408 | "Epoch 4/200\n", 1409 | " - 2s - loss: 0.7041 - val_loss: 0.6567\n", 1410 | "Epoch 5/200\n", 1411 | " - 2s - loss: 0.6128 - val_loss: 0.5740\n", 1412 | "Epoch 6/200\n", 1413 | " - 2s - loss: 0.5494 - val_loss: 0.5183\n", 1414 | "Epoch 7/200\n", 1415 | " - 2s - loss: 0.5029 - val_loss: 0.4769\n", 1416 | "Epoch 8/200\n", 1417 | " - 2s - loss: 0.4679 - val_loss: 0.4460\n", 1418 | "Epoch 9/200\n", 1419 | " - 2s - loss: 0.4420 - val_loss: 0.4227\n", 1420 | "Epoch 10/200\n", 1421 | " - 2s - loss: 0.4210 - val_loss: 0.4055\n", 1422 | "Epoch 11/200\n", 1423 | " - 2s - loss: 0.4037 - val_loss: 0.3896\n", 1424 | "Epoch 12/200\n", 1425 | " - 2s - loss: 0.3890 - val_loss: 0.3752\n", 1426 | "Epoch 13/200\n", 1427 | " - 2s - loss: 0.3759 - val_loss: 0.3637\n", 1428 | "Epoch 14/200\n", 1429 | " - 2s - loss: 0.3644 - val_loss: 0.3539\n", 1430 | "Epoch 15/200\n", 1431 | " - 2s - loss: 0.3546 - val_loss: 0.3450\n", 1432 | "Epoch 16/200\n", 1433 | " - 2s - loss: 0.3462 - val_loss: 0.3374\n", 1434 | "Epoch 17/200\n", 1435 | " - 2s - loss: 0.3388 - val_loss: 0.3309\n", 1436 | "Epoch 18/200\n", 1437 | " - 2s - loss: 0.3320 - val_loss: 0.3249\n", 1438 | "Epoch 19/200\n", 1439 | " - 2s - loss: 0.3261 - val_loss: 0.3194\n", 1440 | "Epoch 20/200\n", 1441 | " - 2s - loss: 0.3206 - val_loss: 0.3146\n", 1442 | "Epoch 21/200\n", 1443 | " - 2s - loss: 0.3154 - val_loss: 0.3095\n", 1444 | "Epoch 22/200\n", 1445 | " - 2s - loss: 0.3105 - val_loss: 0.3051\n", 1446 | "Epoch 23/200\n", 1447 | " - 2s - 
loss: 0.3060 - val_loss: 0.3012\n", 1448 | "Epoch 24/200\n", 1449 | " - 2s - loss: 0.3012 - val_loss: 0.2966\n", 1450 | "Epoch 25/200\n", 1451 | " - 2s - loss: 0.2970 - val_loss: 0.2926\n", 1452 | "Epoch 26/200\n", 1453 | " - 2s - loss: 0.2928 - val_loss: 0.2895\n", 1454 | "Epoch 27/200\n", 1455 | " - 2s - loss: 0.2893 - val_loss: 0.2852\n", 1456 | "Epoch 28/200\n", 1457 | " - 2s - loss: 0.2858 - val_loss: 0.2818\n", 1458 | "Epoch 29/200\n", 1459 | " - 2s - loss: 0.2817 - val_loss: 0.2784\n", 1460 | "Epoch 30/200\n", 1461 | " - 2s - loss: 0.2782 - val_loss: 0.2753\n", 1462 | "Epoch 31/200\n", 1463 | " - 2s - loss: 0.2749 - val_loss: 0.2723\n", 1464 | "Epoch 32/200\n", 1465 | " - 2s - loss: 0.2720 - val_loss: 0.2694\n", 1466 | "Epoch 33/200\n", 1467 | " - 2s - loss: 0.2686 - val_loss: 0.2674\n", 1468 | "Epoch 34/200\n", 1469 | " - 2s - loss: 0.2658 - val_loss: 0.2643\n", 1470 | "Epoch 35/200\n", 1471 | " - 2s - loss: 0.2644 - val_loss: 0.2622\n", 1472 | "Epoch 36/200\n", 1473 | " - 2s - loss: 0.2620 - val_loss: 0.2600\n", 1474 | "Epoch 37/200\n", 1475 | " - 2s - loss: 0.2586 - val_loss: 0.2583\n", 1476 | "Epoch 38/200\n", 1477 | " - 2s - loss: 0.2565 - val_loss: 0.2559\n", 1478 | "Epoch 39/200\n", 1479 | " - 2s - loss: 0.2539 - val_loss: 0.2532\n", 1480 | "Epoch 40/200\n", 1481 | " - 2s - loss: 0.2514 - val_loss: 0.2514\n", 1482 | "Epoch 41/200\n", 1483 | " - 2s - loss: 0.2494 - val_loss: 0.2496\n", 1484 | "Epoch 42/200\n", 1485 | " - 2s - loss: 0.2472 - val_loss: 0.2468\n", 1486 | "Epoch 43/200\n", 1487 | " - 2s - loss: 0.2452 - val_loss: 0.2453\n", 1488 | "Epoch 44/200\n", 1489 | " - 2s - loss: 0.2434 - val_loss: 0.2437\n", 1490 | "Epoch 45/200\n", 1491 | " - 2s - loss: 0.2414 - val_loss: 0.2413\n", 1492 | "Epoch 46/200\n", 1493 | " - 2s - loss: 0.2395 - val_loss: 0.2396\n", 1494 | "Epoch 47/200\n", 1495 | " - 2s - loss: 0.2380 - val_loss: 0.2383\n", 1496 | "Epoch 48/200\n", 1497 | " - 2s - loss: 0.2362 - val_loss: 0.2362\n", 1498 | "Epoch 49/200\n", 1499 | " - 2s 
- loss: 0.2344 - val_loss: 0.2347\n", 1500 | "Epoch 50/200\n", 1501 | " - 2s - loss: 0.2328 - val_loss: 0.2333\n", 1502 | "Epoch 51/200\n", 1503 | " - 2s - loss: 0.2312 - val_loss: 0.2318\n", 1504 | "Epoch 52/200\n", 1505 | " - 2s - loss: 0.2297 - val_loss: 0.2305\n", 1506 | "Epoch 53/200\n", 1507 | " - 2s - loss: 0.2283 - val_loss: 0.2293\n", 1508 | "Epoch 54/200\n", 1509 | " - 2s - loss: 0.2269 - val_loss: 0.2278\n", 1510 | "Epoch 55/200\n", 1511 | " - 2s - loss: 0.2255 - val_loss: 0.2264\n", 1512 | "Epoch 56/200\n", 1513 | " - 2s - loss: 0.2245 - val_loss: 0.2250\n", 1514 | "Epoch 57/200\n", 1515 | " - 2s - loss: 0.2231 - val_loss: 0.2237\n", 1516 | "Epoch 58/200\n", 1517 | " - 2s - loss: 0.2216 - val_loss: 0.2227\n", 1518 | "Epoch 59/200\n", 1519 | " - 2s - loss: 0.2203 - val_loss: 0.2212\n", 1520 | "Epoch 60/200\n", 1521 | " - 2s - loss: 0.2188 - val_loss: 0.2203\n", 1522 | "Epoch 61/200\n", 1523 | " - 2s - loss: 0.2177 - val_loss: 0.2192\n", 1524 | "Epoch 62/200\n", 1525 | " - 2s - loss: 0.2166 - val_loss: 0.2186\n", 1526 | "Epoch 63/200\n", 1527 | " - 2s - loss: 0.2158 - val_loss: 0.2165\n", 1528 | "Epoch 64/200\n", 1529 | " - 2s - loss: 0.2143 - val_loss: 0.2154\n", 1530 | "Epoch 65/200\n", 1531 | " - 2s - loss: 0.2134 - val_loss: 0.2142\n", 1532 | "Epoch 66/200\n", 1533 | " - 2s - loss: 0.2122 - val_loss: 0.2136\n", 1534 | "Epoch 67/200\n", 1535 | " - 2s - loss: 0.2110 - val_loss: 0.2128\n", 1536 | "Epoch 68/200\n", 1537 | " - 2s - loss: 0.2098 - val_loss: 0.2121\n", 1538 | "Epoch 69/200\n", 1539 | " - 2s - loss: 0.2088 - val_loss: 0.2108\n", 1540 | "Epoch 70/200\n", 1541 | " - 2s - loss: 0.2080 - val_loss: 0.2102\n", 1542 | "Epoch 71/200\n", 1543 | " - 2s - loss: 0.2066 - val_loss: 0.2083\n", 1544 | "Epoch 72/200\n", 1545 | " - 2s - loss: 0.2058 - val_loss: 0.2080\n", 1546 | "Epoch 73/200\n", 1547 | " - 2s - loss: 0.2056 - val_loss: 0.2067\n", 1548 | "Epoch 74/200\n", 1549 | " - 2s - loss: 0.2039 - val_loss: 0.2059\n", 1550 | "Epoch 75/200\n", 1551 | " - 
2s - loss: 0.2028 - val_loss: 0.2051\n", 1552 | "Epoch 76/200\n", 1553 | " - 2s - loss: 0.2018 - val_loss: 0.2041\n", 1554 | "Epoch 77/200\n", 1555 | " - 2s - loss: 0.2009 - val_loss: 0.2033\n", 1556 | "Epoch 78/200\n", 1557 | " - 2s - loss: 0.2001 - val_loss: 0.2025\n", 1558 | "Epoch 79/200\n", 1559 | " - 2s - loss: 0.1991 - val_loss: 0.2020\n", 1560 | "Epoch 80/200\n", 1561 | " - 2s - loss: 0.1986 - val_loss: 0.2007\n", 1562 | "Epoch 81/200\n", 1563 | " - 2s - loss: 0.1976 - val_loss: 0.2001\n", 1564 | "Epoch 82/200\n", 1565 | " - 2s - loss: 0.1967 - val_loss: 0.1992\n", 1566 | "Epoch 83/200\n", 1567 | " - 2s - loss: 0.1963 - val_loss: 0.1989\n", 1568 | "Epoch 84/200\n", 1569 | " - 2s - loss: 0.1952 - val_loss: 0.1977\n", 1570 | "Epoch 85/200\n", 1571 | " - 2s - loss: 0.1947 - val_loss: 0.1969\n", 1572 | "Epoch 86/200\n", 1573 | " - 2s - loss: 0.1935 - val_loss: 0.1963\n", 1574 | "Epoch 87/200\n", 1575 | " - 2s - loss: 0.1927 - val_loss: 0.1956\n", 1576 | "Epoch 88/200\n", 1577 | " - 2s - loss: 0.1922 - val_loss: 0.1952\n", 1578 | "Epoch 89/200\n", 1579 | " - 2s - loss: 0.1915 - val_loss: 0.1944\n", 1580 | "Epoch 90/200\n", 1581 | " - 2s - loss: 0.1914 - val_loss: 0.1938\n", 1582 | "Epoch 91/200\n", 1583 | " - 2s - loss: 0.1903 - val_loss: 0.1931\n", 1584 | "Epoch 92/200\n", 1585 | " - 2s - loss: 0.1894 - val_loss: 0.1926\n", 1586 | "Epoch 93/200\n", 1587 | " - 2s - loss: 0.1887 - val_loss: 0.1924\n", 1588 | "Epoch 94/200\n", 1589 | " - 2s - loss: 0.1881 - val_loss: 0.1910\n", 1590 | "Epoch 95/200\n", 1591 | " - 2s - loss: 0.1874 - val_loss: 0.1904\n", 1592 | "Epoch 96/200\n", 1593 | " - 2s - loss: 0.1871 - val_loss: 0.1906\n", 1594 | "Epoch 97/200\n", 1595 | " - 2s - loss: 0.1867 - val_loss: 0.1897\n", 1596 | "Epoch 98/200\n", 1597 | " - 2s - loss: 0.1860 - val_loss: 0.1891\n", 1598 | "Epoch 99/200\n", 1599 | " - 2s - loss: 0.1852 - val_loss: 0.1890\n", 1600 | "Epoch 100/200\n", 1601 | " - 2s - loss: 0.1847 - val_loss: 0.1882\n", 1602 | "Epoch 101/200\n", 1603 | 
" - 2s - loss: 0.1840 - val_loss: 0.1879\n", 1604 | "Epoch 102/200\n", 1605 | " - 2s - loss: 0.1836 - val_loss: 0.1869\n", 1606 | "Epoch 103/200\n", 1607 | " - 2s - loss: 0.1829 - val_loss: 0.1864\n", 1608 | "Epoch 104/200\n", 1609 | " - 2s - loss: 0.1825 - val_loss: 0.1864\n", 1610 | "Epoch 105/200\n", 1611 | " - 2s - loss: 0.1822 - val_loss: 0.1853\n", 1612 | "Epoch 106/200\n", 1613 | " - 2s - loss: 0.1818 - val_loss: 0.1851\n", 1614 | "Epoch 107/200\n", 1615 | " - 2s - loss: 0.1810 - val_loss: 0.1846\n", 1616 | "Epoch 108/200\n", 1617 | " - 2s - loss: 0.1804 - val_loss: 0.1847\n", 1618 | "Epoch 109/200\n", 1619 | " - 2s - loss: 0.1803 - val_loss: 0.1838\n", 1620 | "Epoch 110/200\n", 1621 | " - 2s - loss: 0.1799 - val_loss: 0.1837\n", 1622 | "Epoch 111/200\n", 1623 | " - 2s - loss: 0.1793 - val_loss: 0.1829\n", 1624 | "Epoch 112/200\n", 1625 | " - 2s - loss: 0.1790 - val_loss: 0.1819\n", 1626 | "Epoch 113/200\n", 1627 | " - 2s - loss: 0.1781 - val_loss: 0.1816\n", 1628 | "Epoch 114/200\n", 1629 | " - 2s - loss: 0.1784 - val_loss: 0.1811\n", 1630 | "Epoch 115/200\n", 1631 | " - 2s - loss: 0.1771 - val_loss: 0.1813\n", 1632 | "Epoch 116/200\n", 1633 | " - 2s - loss: 0.1768 - val_loss: 0.1804\n", 1634 | "Epoch 117/200\n", 1635 | " - 2s - loss: 0.1764 - val_loss: 0.1800\n", 1636 | "Epoch 118/200\n", 1637 | " - 2s - loss: 0.1761 - val_loss: 0.1798\n", 1638 | "Epoch 119/200\n", 1639 | " - 2s - loss: 0.1756 - val_loss: 0.1795\n", 1640 | "Epoch 120/200\n", 1641 | " - 2s - loss: 0.1751 - val_loss: 0.1790\n", 1642 | "Epoch 121/200\n", 1643 | " - 2s - loss: 0.1746 - val_loss: 0.1789\n", 1644 | "Epoch 122/200\n", 1645 | " - 2s - loss: 0.1743 - val_loss: 0.1785\n", 1646 | "Epoch 123/200\n", 1647 | " - 2s - loss: 0.1742 - val_loss: 0.1783\n", 1648 | "Epoch 124/200\n", 1649 | " - 2s - loss: 0.1737 - val_loss: 0.1778\n", 1650 | "Epoch 125/200\n", 1651 | " - 2s - loss: 0.1734 - val_loss: 0.1775\n", 1652 | "Epoch 126/200\n", 1653 | " - 2s - loss: 0.1739 - val_loss: 0.1774\n", 1654 
| "Epoch 127/200\n", 1655 | " - 2s - loss: 0.1732 - val_loss: 0.1768\n", 1656 | "Epoch 128/200\n", 1657 | " - 2s - loss: 0.1722 - val_loss: 0.1768\n", 1658 | "Epoch 129/200\n", 1659 | " - 2s - loss: 0.1717 - val_loss: 0.1763\n", 1660 | "Epoch 130/200\n", 1661 | " - 2s - loss: 0.1713 - val_loss: 0.1757\n", 1662 | "Epoch 131/200\n", 1663 | " - 2s - loss: 0.1711 - val_loss: 0.1751\n", 1664 | "Epoch 132/200\n", 1665 | " - 2s - loss: 0.1710 - val_loss: 0.1750\n", 1666 | "Epoch 133/200\n", 1667 | " - 2s - loss: 0.1706 - val_loss: 0.1746\n", 1668 | "Epoch 134/200\n", 1669 | " - 2s - loss: 0.1703 - val_loss: 0.1745\n", 1670 | "Epoch 135/200\n", 1671 | " - 2s - loss: 0.1699 - val_loss: 0.1745\n", 1672 | "Epoch 136/200\n", 1673 | " - 2s - loss: 0.1695 - val_loss: 0.1742\n", 1674 | "Epoch 137/200\n", 1675 | " - 2s - loss: 0.1691 - val_loss: 0.1740\n", 1676 | "Epoch 138/200\n", 1677 | " - 2s - loss: 0.1688 - val_loss: 0.1734\n", 1678 | "Epoch 139/200\n", 1679 | " - 2s - loss: 0.1687 - val_loss: 0.1733\n", 1680 | "Epoch 140/200\n", 1681 | " - 2s - loss: 0.1686 - val_loss: 0.1730\n", 1682 | "Epoch 141/200\n", 1683 | " - 2s - loss: 0.1687 - val_loss: 0.1733\n", 1684 | "Epoch 142/200\n", 1685 | " - 2s - loss: 0.1682 - val_loss: 0.1729\n", 1686 | "Epoch 143/200\n", 1687 | " - 2s - loss: 0.1675 - val_loss: 0.1723\n", 1688 | "Epoch 144/200\n", 1689 | " - 2s - loss: 0.1671 - val_loss: 0.1719\n", 1690 | "Epoch 145/200\n", 1691 | " - 2s - loss: 0.1670 - val_loss: 0.1716\n", 1692 | "Epoch 146/200\n", 1693 | " - 2s - loss: 0.1668 - val_loss: 0.1714\n", 1694 | "Epoch 147/200\n", 1695 | " - 2s - loss: 0.1664 - val_loss: 0.1709\n", 1696 | "Epoch 148/200\n", 1697 | " - 2s - loss: 0.1660 - val_loss: 0.1709\n", 1698 | "Epoch 149/200\n", 1699 | " - 2s - loss: 0.1661 - val_loss: 0.1703\n", 1700 | "Epoch 150/200\n", 1701 | " - 2s - loss: 0.1656 - val_loss: 0.1704\n", 1702 | "Epoch 151/200\n", 1703 | " - 2s - loss: 0.1659 - val_loss: 0.1705\n", 1704 | "Epoch 152/200\n", 1705 | " - 2s - loss: 0.1655 
- val_loss: 0.1704\n", 1706 | "Epoch 153/200\n", 1707 | " - 2s - loss: 0.1651 - val_loss: 0.1701\n" 1708 | ] 1709 | }, 1710 | { 1711 | "name": "stdout", 1712 | "output_type": "stream", 1713 | "text": [ 1714 | "Epoch 154/200\n", 1715 | " - 2s - loss: 0.1645 - val_loss: 0.1696\n", 1716 | "Epoch 155/200\n", 1717 | " - 2s - loss: 0.1643 - val_loss: 0.1694\n", 1718 | "Epoch 156/200\n", 1719 | " - 2s - loss: 0.1641 - val_loss: 0.1690\n", 1720 | "Epoch 157/200\n", 1721 | " - 2s - loss: 0.1638 - val_loss: 0.1690\n", 1722 | "Epoch 158/200\n", 1723 | " - 2s - loss: 0.1636 - val_loss: 0.1690\n", 1724 | "Epoch 159/200\n", 1725 | " - 2s - loss: 0.1640 - val_loss: 0.1684\n", 1726 | "Epoch 160/200\n", 1727 | " - 2s - loss: 0.1651 - val_loss: 0.1688\n", 1728 | "Epoch 161/200\n", 1729 | " - 2s - loss: 0.1633 - val_loss: 0.1681\n", 1730 | "Epoch 162/200\n", 1731 | " - 2s - loss: 0.1629 - val_loss: 0.1683\n", 1732 | "Epoch 163/200\n", 1733 | " - 2s - loss: 0.1625 - val_loss: 0.1681\n", 1734 | "Epoch 164/200\n", 1735 | " - 2s - loss: 0.1623 - val_loss: 0.1679\n", 1736 | "Epoch 165/200\n", 1737 | " - 2s - loss: 0.1622 - val_loss: 0.1673\n", 1738 | "Epoch 166/200\n", 1739 | " - 2s - loss: 0.1619 - val_loss: 0.1675\n", 1740 | "Epoch 167/200\n", 1741 | " - 2s - loss: 0.1617 - val_loss: 0.1669\n", 1742 | "Epoch 168/200\n", 1743 | " - 2s - loss: 0.1615 - val_loss: 0.1669\n", 1744 | "Epoch 169/200\n", 1745 | " - 2s - loss: 0.1614 - val_loss: 0.1667\n", 1746 | "Epoch 170/200\n", 1747 | " - 2s - loss: 0.1612 - val_loss: 0.1664\n", 1748 | "Epoch 171/200\n", 1749 | " - 2s - loss: 0.1616 - val_loss: 0.1669\n", 1750 | "Epoch 172/200\n", 1751 | " - 2s - loss: 0.1613 - val_loss: 0.1670\n", 1752 | "Epoch 173/200\n", 1753 | " - 2s - loss: 0.1609 - val_loss: 0.1658\n", 1754 | "Epoch 174/200\n", 1755 | " - 2s - loss: 0.1603 - val_loss: 0.1661\n", 1756 | "Epoch 175/200\n", 1757 | " - 2s - loss: 0.1601 - val_loss: 0.1658\n", 1758 | "Epoch 176/200\n", 1759 | " - 2s - loss: 0.1600 - val_loss: 0.1661\n", 
1760 | "Epoch 177/200\n", 1761 | " - 2s - loss: 0.1596 - val_loss: 0.1653\n", 1762 | "Epoch 178/200\n", 1763 | " - 2s - loss: 0.1603 - val_loss: 0.1654\n", 1764 | "Epoch 179/200\n", 1765 | " - 2s - loss: 0.1598 - val_loss: 0.1656\n", 1766 | "Epoch 180/200\n", 1767 | " - 2s - loss: 0.1593 - val_loss: 0.1655\n", 1768 | "Epoch 181/200\n", 1769 | " - 2s - loss: 0.1590 - val_loss: 0.1648\n", 1770 | "Epoch 182/200\n", 1771 | " - 2s - loss: 0.1587 - val_loss: 0.1647\n", 1772 | "Epoch 183/200\n", 1773 | " - 2s - loss: 0.1586 - val_loss: 0.1650\n", 1774 | "Epoch 184/200\n", 1775 | " - 2s - loss: 0.1588 - val_loss: 0.1642\n", 1776 | "Epoch 185/200\n", 1777 | " - 2s - loss: 0.1585 - val_loss: 0.1643\n", 1778 | "Epoch 186/200\n", 1779 | " - 2s - loss: 0.1584 - val_loss: 0.1642\n", 1780 | "Epoch 187/200\n", 1781 | " - 2s - loss: 0.1579 - val_loss: 0.1644\n", 1782 | "Epoch 188/200\n", 1783 | " - 2s - loss: 0.1576 - val_loss: 0.1642\n", 1784 | "Epoch 189/200\n", 1785 | " - 2s - loss: 0.1575 - val_loss: 0.1644\n", 1786 | "Epoch 190/200\n", 1787 | " - 2s - loss: 0.1577 - val_loss: 0.1642\n", 1788 | "Epoch 191/200\n", 1789 | " - 2s - loss: 0.1627 - val_loss: 0.1636\n", 1790 | "Epoch 192/200\n", 1791 | " - 2s - loss: 0.1592 - val_loss: 0.1646\n", 1792 | "Epoch 193/200\n", 1793 | " - 2s - loss: 0.1578 - val_loss: 0.1635\n", 1794 | "Epoch 194/200\n", 1795 | " - 2s - loss: 0.1573 - val_loss: 0.1631\n", 1796 | "Epoch 195/200\n", 1797 | " - 2s - loss: 0.1569 - val_loss: 0.1632\n", 1798 | "Epoch 196/200\n", 1799 | " - 2s - loss: 0.1568 - val_loss: 0.1631\n", 1800 | "Epoch 197/200\n", 1801 | " - 2s - loss: 0.1564 - val_loss: 0.1627\n", 1802 | "Epoch 198/200\n", 1803 | " - 2s - loss: 0.1561 - val_loss: 0.1622\n", 1804 | "Epoch 199/200\n", 1805 | " - 2s - loss: 0.1560 - val_loss: 0.1626\n", 1806 | "Epoch 200/200\n", 1807 | " - 2s - loss: 0.1559 - val_loss: 0.1625\n" 1808 | ] 1809 | } 1810 | ], 1811 | "source": [ 1812 | "adam = optimizers.Adam(lr)\n", 1813 | 
"lstm_autoencoder.compile(loss='mse', optimizer=adam)\n", 1814 | "\n", 1815 | "cp = ModelCheckpoint(filepath=\"lstm_autoencoder_classifier.h5\",\n", 1816 | " save_best_only=True,\n", 1817 | " verbose=0)\n", 1818 | "\n", 1819 | "tb = TensorBoard(log_dir='./logs',\n", 1820 | " histogram_freq=0,\n", 1821 | " write_graph=True,\n", 1822 | " write_images=True)\n", 1823 | "\n", 1824 | "lstm_autoencoder_history = lstm_autoencoder.fit(X_train_y0_scaled, X_train_y0_scaled, \n", 1825 | " epochs=epochs, \n", 1826 | " batch_size=batch, \n", 1827 | " validation_data=(X_valid_y0_scaled, X_valid_y0_scaled),\n", 1828 | " verbose=2).history" 1829 | ] 1830 | }, 1831 | { 1832 | "cell_type": "code", 1833 | "execution_count": 95, 1834 | "metadata": {}, 1835 | "outputs": [ 1836 | { 1837 | "data": { 1838 | "image/png": "\n", 1839 | "text/plain": [ 1840 | "
" 1841 | ] 1842 | }, 1843 | "metadata": { 1844 | "needs_background": "light" 1845 | }, 1846 | "output_type": "display_data" 1847 | } 1848 | ], 1849 | "source": [ 1850 | "plt.plot(lstm_autoencoder_history['loss'], linewidth=2, label='Train')\n", 1851 | "plt.plot(lstm_autoencoder_history['val_loss'], linewidth=2, label='Valid')\n", 1852 | "plt.legend(loc='upper right')\n", 1853 | "plt.title('Model loss')\n", 1854 | "plt.ylabel('Loss')\n", 1855 | "plt.xlabel('Epoch')\n", 1856 | "plt.show()" 1857 | ] 1858 | }, 1859 | { 1860 | "cell_type": "markdown", 1861 | "metadata": {}, 1862 | "source": [ 1863 | "### Sanity check\n", 1864 | "Doing a sanity check by validating the reconstruction error \n", 1865 | "on the train data. Here we will reconstruct the entire train \n", 1866 | "data with both 0 and 1 labels.\n", 1867 | "\n", 1868 | "**Expectation**: the reconstruction error of the 0-labeled data should\n", 1869 | "be smaller than that of the 1-labeled data.\n", 1870 | "\n", 1871 | "**Caution**: do not use this result for model evaluation. It may\n", 1872 | "result in overfitting issues." 1873 | ] 1874 | }, 1875 | { 1876 | "cell_type": "code", 1877 | "execution_count": 96, 1878 | "metadata": {}, 1879 | "outputs": [ 1880 | { 1881 | "data": { 1882 | "image/png": "\n", 1883 | "text/plain": [ 1884 | "
" 1885 | ] 1886 | }, 1887 | "metadata": { 1888 | "needs_background": "light" 1889 | }, 1890 | "output_type": "display_data" 1891 | } 1892 | ], 1893 | "source": [ 1894 | "train_x_predictions = lstm_autoencoder.predict(X_train_scaled)\n", 1895 | "mse = np.mean(np.power(flatten(X_train_scaled) - flatten(train_x_predictions), 2), axis=1)\n", 1896 | "\n", 1897 | "error_df = pd.DataFrame({'Reconstruction_error': mse,\n", 1898 | " 'True_class': y_train.tolist()})\n", 1899 | "\n", 1900 | "groups = error_df.groupby('True_class')\n", 1901 | "fig, ax = plt.subplots()\n", 1902 | "\n", 1903 | "for name, group in groups:\n", 1904 | " ax.plot(group.index, group.Reconstruction_error, marker='o', ms=3.5, linestyle='',\n", 1905 | " label= \"Break\" if name == 1 else \"Normal\")\n", 1906 | "ax.legend()\n", 1907 | "plt.title(\"Reconstruction error for different classes\")\n", 1908 | "plt.ylabel(\"Reconstruction error\")\n", 1909 | "plt.xlabel(\"Data point index\")\n", 1910 | "plt.show();" 1911 | ] 1912 | }, 1913 | { 1914 | "cell_type": "markdown", 1915 | "metadata": {}, 1916 | "source": [ 1917 | "## Predictions using the Autoencoder" 1918 | ] 1919 | }, 1920 | { 1921 | "cell_type": "code", 1922 | "execution_count": 97, 1923 | "metadata": {}, 1924 | "outputs": [ 1925 | { 1926 | "data": { 1927 | "image/png": "\n", 1928 | "text/plain": [ 1929 | "
" 1930 | ] 1931 | }, 1932 | "metadata": { 1933 | "needs_background": "light" 1934 | }, 1935 | "output_type": "display_data" 1936 | } 1937 | ], 1938 | "source": [ 1939 | "valid_x_predictions = lstm_autoencoder.predict(X_valid_scaled)\n", 1940 | "mse = np.mean(np.power(flatten(X_valid_scaled) - flatten(valid_x_predictions), 2), axis=1)\n", 1941 | "\n", 1942 | "error_df = pd.DataFrame({'Reconstruction_error': mse,\n", 1943 | " 'True_class': y_valid.tolist()})\n", 1944 | "\n", 1945 | "precision_rt, recall_rt, threshold_rt = precision_recall_curve(error_df.True_class, error_df.Reconstruction_error)\n", 1946 | "plt.plot(threshold_rt, precision_rt[1:], label=\"Precision\",linewidth=5)\n", 1947 | "plt.plot(threshold_rt, recall_rt[1:], label=\"Recall\",linewidth=5)\n", 1948 | "plt.title('Precision and recall for different threshold values')\n", 1949 | "plt.xlabel('Threshold')\n", 1950 | "plt.ylabel('Precision/Recall')\n", 1951 | "plt.legend()\n", 1952 | "plt.show()" 1953 | ] 1954 | }, 1955 | { 1956 | "cell_type": "code", 1957 | "execution_count": 107, 1958 | "metadata": {}, 1959 | "outputs": [ 1960 | { 1961 | "data": { 1962 | "image/png": "\n", 1963 | "text/plain": [ 1964 | "
" 1965 | ] 1966 | }, 1967 | "metadata": { 1968 | "needs_background": "light" 1969 | }, 1970 | "output_type": "display_data" 1971 | } 1972 | ], 1973 | "source": [ 1974 | "test_x_predictions = lstm_autoencoder.predict(X_test_scaled)\n", 1975 | "mse = np.mean(np.power(flatten(X_test_scaled) - flatten(test_x_predictions), 2), axis=1)\n", 1976 | "\n", 1977 | "error_df = pd.DataFrame({'Reconstruction_error': mse,\n", 1978 | " 'True_class': y_test.tolist()})\n", 1979 | "\n", 1980 | "threshold_fixed = 0.3\n", 1981 | "groups = error_df.groupby('True_class')\n", 1982 | "fig, ax = plt.subplots()\n", 1983 | "\n", 1984 | "for name, group in groups:\n", 1985 | " ax.plot(group.index, group.Reconstruction_error, marker='o', ms=3.5, linestyle='',\n", 1986 | " label= \"Break\" if name == 1 else \"Normal\")\n", 1987 | "ax.hlines(threshold_fixed, ax.get_xlim()[0], ax.get_xlim()[1], colors=\"r\", zorder=100, label='Threshold')\n", 1988 | "ax.legend()\n", 1989 | "plt.title(\"Reconstruction error for different classes\")\n", 1990 | "plt.ylabel(\"Reconstruction error\")\n", 1991 | "plt.xlabel(\"Data point index\")\n", 1992 | "plt.show();" 1993 | ] 1994 | }, 1995 | { 1996 | "cell_type": "code", 1997 | "execution_count": 108, 1998 | "metadata": {}, 1999 | "outputs": [], 2000 | "source": [ 2001 | "pred_y = [1 if e > threshold_fixed else 0 for e in error_df.Reconstruction_error.values]" 2002 | ] 2003 | }, 2004 | { 2005 | "cell_type": "code", 2006 | "execution_count": 109, 2007 | "metadata": {}, 2008 | "outputs": [ 2009 | { 2010 | "data": { 2011 | "image/png": "\n", 2012 | "text/plain": [ 2013 | "
" 2014 | ] 2015 | }, 2016 | "metadata": { 2017 | "needs_background": "light" 2018 | }, 2019 | "output_type": "display_data" 2020 | } 2021 | ], 2022 | "source": [ 2023 | "conf_matrix = confusion_matrix(error_df.True_class, pred_y)\n", 2024 | "\n", 2025 | "plt.figure(figsize=(6, 6))\n", 2026 | "sns.heatmap(conf_matrix, xticklabels=LABELS, yticklabels=LABELS, annot=True, fmt=\"d\");\n", 2027 | "plt.title(\"Confusion matrix\")\n", 2028 | "plt.ylabel('True class')\n", 2029 | "plt.xlabel('Predicted class')\n", 2030 | "plt.show()" 2031 | ] 2032 | }, 2033 | { 2034 | "cell_type": "code", 2035 | "execution_count": 78, 2036 | "metadata": {}, 2037 | "outputs": [ 2038 | { 2039 | "data": { 2040 | "image/png": "\n", 2041 | "text/plain": [ 2042 | "
" 2043 | ] 2044 | }, 2045 | "metadata": { 2046 | "needs_background": "light" 2047 | }, 2048 | "output_type": "display_data" 2049 | } 2050 | ], 2051 | "source": [ 2052 | "false_pos_rate, true_pos_rate, thresholds = roc_curve(error_df.True_class, error_df.Reconstruction_error)\n", 2053 | "roc_auc = auc(false_pos_rate, true_pos_rate,)\n", 2054 | "\n", 2055 | "plt.plot(false_pos_rate, true_pos_rate, linewidth=5, label='AUC = %0.3f'% roc_auc)\n", 2056 | "plt.plot([0,1],[0,1], linewidth=5)\n", 2057 | "\n", 2058 | "plt.xlim([-0.01, 1])\n", 2059 | "plt.ylim([0, 1.01])\n", 2060 | "plt.legend(loc='lower right')\n", 2061 | "plt.title('Receiver operating characteristic curve (ROC)')\n", 2062 | "plt.ylabel('True Positive Rate')\n", 2063 | "plt.xlabel('False Positive Rate')\n", 2064 | "plt.show()" 2065 | ] 2066 | } 2067 | ], 2068 | "metadata": { 2069 | "kernelspec": { 2070 | "display_name": "Python 3", 2071 | "language": "python", 2072 | "name": "python3" 2073 | }, 2074 | "language_info": { 2075 | "codemirror_mode": { 2076 | "name": "ipython", 2077 | "version": 3 2078 | }, 2079 | "file_extension": ".py", 2080 | "mimetype": "text/x-python", 2081 | "name": "python", 2082 | "nbconvert_exporter": "python", 2083 | "pygments_lexer": "ipython3", 2084 | "version": "3.7.1" 2085 | } 2086 | }, 2087 | "nbformat": 4, 2088 | "nbformat_minor": 2 2089 | } 2090 | --------------------------------------------------------------------------------