├── .gitignore ├── README.md ├── autoencoder_classifier.h5 └── autoencoder_classifier.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | logs/ 3 | *.DS_Store 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # autoencoder_classifier 2 | Autoencoder model for rare event classification 3 | -------------------------------------------------------------------------------- /autoencoder_classifier.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cran2367/autoencoder_classifier/6377ebe747592066a7dbbd5bec300de8ec7249c4/autoencoder_classifier.h5 -------------------------------------------------------------------------------- /autoencoder_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Rare Event Binary Classification using Autoencoder\n", 8 | "\n", 9 | "Here we show how to build a binary classifier using an autoencoder. The purpose is to show the implementation steps. The autoencoder can be tuned further to improve performance.\n", 10 | "\n", 11 | "The dataset used here is taken from the following article:\n", 12 | "\n", 13 | "**Dataset: Rare Event Classification in Multivariate Time Series** https://arxiv.org/abs/1809.10717 (please cite this article, if using the dataset)." 
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stderr", 23 | "output_type": "stream", 24 | "text": [ 25 | "Using TensorFlow backend.\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "%matplotlib inline\n", 31 | "import matplotlib.pyplot as plt\n", 32 | "import seaborn as sns\n", 33 | "\n", 34 | "import pandas as pd\n", 35 | "import numpy as np\n", 36 | "from pylab import rcParams\n", 37 | "\n", 38 | "import tensorflow as tf\n", 39 | "from keras.models import Model, load_model\n", 40 | "from keras.layers import Input, Dense\n", 41 | "from keras.callbacks import ModelCheckpoint, TensorBoard\n", 42 | "from keras import regularizers\n", 43 | "\n", 44 | "from sklearn.preprocessing import StandardScaler\n", 45 | "from sklearn.model_selection import train_test_split\n", 46 | "from sklearn.metrics import confusion_matrix, precision_recall_curve\n", 47 | "from sklearn.metrics import recall_score, classification_report, auc, roc_curve\n", 48 | "from sklearn.metrics import precision_recall_fscore_support, f1_score\n", 49 | "\n", 50 | "from numpy.random import seed\n", 51 | "seed(1)\n", 52 | "from tensorflow import set_random_seed\n", 53 | "set_random_seed(2)\n", 54 | "\n", 55 | "SEED = 123 #used to help randomly select the data points\n", 56 | "DATA_SPLIT_PCT = 0.2\n", 57 | "\n", 58 | "rcParams['figure.figsize'] = 8, 6\n", 59 | "LABELS = [\"Normal\",\"Break\"]" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Reading and preparing data" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 2, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/html": [ 77 | "
\n", 78 | "\n", 91 | "\n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | "
timeyx1x2x3x4x5x6x7x8...x52x53x54x55x56x57x58x59x60x61
05/1/99 0:0000.376665-4.596435-4.09575613.497687-0.118830-20.6698830.000732-0.061114...10.0917210.053279-4.936434-24.59014618.5154363.4734000.0334440.9532190.0060760
15/1/99 0:0200.475720-4.542502-4.01835916.230659-0.128733-18.7580790.000732-0.061114...10.0958710.062801-4.937179-32.41326622.7600652.6829330.0335361.0905020.0060830
25/1/99 0:0400.363848-4.681394-4.35314714.127998-0.138636-17.8366320.010803-0.061114...10.1002650.072322-4.937924-34.18377427.0046633.5374870.0336291.8405400.0060900
35/1/99 0:0600.301590-4.758934-4.02361213.161567-0.148142-18.5176010.002075-0.061114...10.1046600.081600-4.938669-35.95428121.6724493.9860950.0337212.5548800.0060970
45/1/99 0:0800.265578-4.749928-4.33315015.267340-0.155314-17.5059130.000732-0.061114...10.1090540.091121-4.939414-37.72478921.9072513.6015730.0337771.4104940.0061050
\n", 241 | "

5 rows × 63 columns

\n", 242 | "
" 243 | ], 244 | "text/plain": [ 245 | " time y x1 x2 x3 x4 x5 \\\n", 246 | "0 5/1/99 0:00 0 0.376665 -4.596435 -4.095756 13.497687 -0.118830 \n", 247 | "1 5/1/99 0:02 0 0.475720 -4.542502 -4.018359 16.230659 -0.128733 \n", 248 | "2 5/1/99 0:04 0 0.363848 -4.681394 -4.353147 14.127998 -0.138636 \n", 249 | "3 5/1/99 0:06 0 0.301590 -4.758934 -4.023612 13.161567 -0.148142 \n", 250 | "4 5/1/99 0:08 0 0.265578 -4.749928 -4.333150 15.267340 -0.155314 \n", 251 | "\n", 252 | " x6 x7 x8 ... x52 x53 x54 \\\n", 253 | "0 -20.669883 0.000732 -0.061114 ... 10.091721 0.053279 -4.936434 \n", 254 | "1 -18.758079 0.000732 -0.061114 ... 10.095871 0.062801 -4.937179 \n", 255 | "2 -17.836632 0.010803 -0.061114 ... 10.100265 0.072322 -4.937924 \n", 256 | "3 -18.517601 0.002075 -0.061114 ... 10.104660 0.081600 -4.938669 \n", 257 | "4 -17.505913 0.000732 -0.061114 ... 10.109054 0.091121 -4.939414 \n", 258 | "\n", 259 | " x55 x56 x57 x58 x59 x60 x61 \n", 260 | "0 -24.590146 18.515436 3.473400 0.033444 0.953219 0.006076 0 \n", 261 | "1 -32.413266 22.760065 2.682933 0.033536 1.090502 0.006083 0 \n", 262 | "2 -34.183774 27.004663 3.537487 0.033629 1.840540 0.006090 0 \n", 263 | "3 -35.954281 21.672449 3.986095 0.033721 2.554880 0.006097 0 \n", 264 | "4 -37.724789 21.907251 3.601573 0.033777 1.410494 0.006105 0 \n", 265 | "\n", 266 | "[5 rows x 63 columns]" 267 | ] 268 | }, 269 | "execution_count": 2, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "'''\n", 276 | "Download data here:\n", 277 | "https://docs.google.com/forms/d/e/1FAIpQLSdyUk3lfDl7I5KYK_pw285LCApc-_RcoC0Tf9cnDnZ_TWzPAw/viewform\n", 278 | "'''\n", 279 | "df = pd.read_csv(\"data/processminer-rare-event-mts - data.csv\") \n", 280 | "df.head(n=5) # visualize the data." 
281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "### Shift the data\n", 288 | "\n", 289 | "This is time series data in which we have to predict the event (y = 1) ahead in time. In this data, consecutive rows are 2 minutes apart. We will shift the labels in column `y` by 2 rows to do a 4-minute-ahead prediction." 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 3, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "sign = lambda x: (1, -1)[x < 0]\n", 299 | "\n", 300 | "def curve_shift(df, shift_by):\n", 301 | " '''\n", 302 | " This function will shift the binary labels in a dataframe.\n", 303 | " The curve shift will be with respect to the 1s. \n", 304 | " For example, if shift is -2, the following process\n", 305 | " will happen: if row n is labeled as 1, then\n", 306 | " - Make rows (n+shift_by):(n-1) = 1.\n", 307 | " - Remove row n.\n", 308 | " i.e. the labels will be shifted up to 2 rows up.\n", 309 | " \n", 310 | " Inputs:\n", 311 | " df A pandas dataframe with a binary labeled column. 
\n", 312 | " This labeled column should be named as 'y'.\n", 313 | " shift_by An integer denoting the number of rows to shift.\n", 314 | " \n", 315 | " Output\n", 316 | " df A dataframe with the binary labels shifted by shift.\n", 317 | " '''\n", 318 | "\n", 319 | " vector = df['y'].copy()\n", 320 | " for s in range(abs(shift_by)):\n", 321 | " tmp = vector.shift(sign(shift_by))\n", 322 | " tmp = tmp.fillna(0)\n", 323 | " vector += tmp\n", 324 | " labelcol = 'y'\n", 325 | " # Add vector to the df\n", 326 | " df.insert(loc=0, column=labelcol+'tmp', value=vector)\n", 327 | " # Remove the rows with labelcol == 1.\n", 328 | " df = df.drop(df[df[labelcol] == 1].index)\n", 329 | " # Drop labelcol and rename the tmp col as labelcol\n", 330 | " df = df.drop(labelcol, axis=1)\n", 331 | " df = df.rename(columns={labelcol+'tmp': labelcol})\n", 332 | " # Make the labelcol binary\n", 333 | " df.loc[df[labelcol] > 0, labelcol] = 1\n", 334 | "\n", 335 | " return df" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 4, 341 | "metadata": {}, 342 | "outputs": [ 343 | { 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "Before shifting\n" 348 | ] 349 | }, 350 | { 351 | "data": { 352 | "text/html": [ 353 | "
\n", 354 | "\n", 367 | "\n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | "
timeyx1x2x3
2565/1/99 8:3201.016235-4.058394-1.097158
2575/1/99 8:3401.005602-3.876199-1.074373
2585/1/99 8:3600.933933-3.868467-1.249954
2595/1/99 8:3810.892311-13.332664-10.006578
2605/1/99 10:5000.020062-3.987897-1.248529
\n", 421 | "
" 422 | ], 423 | "text/plain": [ 424 | " time y x1 x2 x3\n", 425 | "256 5/1/99 8:32 0 1.016235 -4.058394 -1.097158\n", 426 | "257 5/1/99 8:34 0 1.005602 -3.876199 -1.074373\n", 427 | "258 5/1/99 8:36 0 0.933933 -3.868467 -1.249954\n", 428 | "259 5/1/99 8:38 1 0.892311 -13.332664 -10.006578\n", 429 | "260 5/1/99 10:50 0 0.020062 -3.987897 -1.248529" 430 | ] 431 | }, 432 | "metadata": {}, 433 | "output_type": "display_data" 434 | }, 435 | { 436 | "name": "stdout", 437 | "output_type": "stream", 438 | "text": [ 439 | "After shifting\n" 440 | ] 441 | }, 442 | { 443 | "data": { 444 | "text/html": [ 445 | "
\n", 446 | "\n", 459 | "\n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | "
ytimex1x2x3
2550.05/1/99 8:300.997107-3.865720-1.133779
2560.05/1/99 8:321.016235-4.058394-1.097158
2571.05/1/99 8:341.005602-3.876199-1.074373
2581.05/1/99 8:360.933933-3.868467-1.249954
2600.05/1/99 10:500.020062-3.987897-1.248529
\n", 513 | "
" 514 | ], 515 | "text/plain": [ 516 | " y time x1 x2 x3\n", 517 | "255 0.0 5/1/99 8:30 0.997107 -3.865720 -1.133779\n", 518 | "256 0.0 5/1/99 8:32 1.016235 -4.058394 -1.097158\n", 519 | "257 1.0 5/1/99 8:34 1.005602 -3.876199 -1.074373\n", 520 | "258 1.0 5/1/99 8:36 0.933933 -3.868467 -1.249954\n", 521 | "260 0.0 5/1/99 10:50 0.020062 -3.987897 -1.248529" 522 | ] 523 | }, 524 | "metadata": {}, 525 | "output_type": "display_data" 526 | } 527 | ], 528 | "source": [ 529 | "'''\n", 530 | "Shift the data by 2 units, equal to 4 minutes.\n", 531 | "\n", 532 | "Test: Testing whether the shift happened correctly.\n", 533 | "'''\n", 534 | "print('Before shifting') # Positive labeled rows before shifting.\n", 535 | "one_indexes = df.index[df['y'] == 1]\n", 536 | "display(df.iloc[(one_indexes[0]-3):(one_indexes[0]+2), 0:5].head(n=5))\n", 537 | "\n", 538 | "# Shift the response column y by 2 rows to do a 4-min ahead prediction.\n", 539 | "df = curve_shift(df, shift_by = -2)\n", 540 | "\n", 541 | "print('After shifting') # Validating if the shift happened correctly.\n", 542 | "display(df.iloc[(one_indexes[0]-4):(one_indexes[0]+1), 0:5].head(n=5)) " 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 5, 548 | "metadata": {}, 549 | "outputs": [], 550 | "source": [ 551 | "# Remove time column, and the categorical columns\n", 552 | "df = df.drop(['time', 'x28', 'x61'], axis=1)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "markdown", 557 | "metadata": {}, 558 | "source": [ 559 | "### Divide the data into train, valid, and test" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": 6, 565 | "metadata": {}, 566 | "outputs": [], 567 | "source": [ 568 | "df_train, df_test = train_test_split(df, test_size=DATA_SPLIT_PCT, random_state=SEED)\n", 569 | "df_train, df_valid = train_test_split(df_train, test_size=DATA_SPLIT_PCT, random_state=SEED)" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "metadata": {}, 575 | "source": [ 576 
| "In the autoencoder, we will be encoding only the negatively labeled data. That is, we will take the part of data for which `y=0` and build an autoencoder. For that, we will divide the datasets as follows." 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 7, 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "df_train_0 = df_train.loc[df['y'] == 0]\n", 586 | "df_train_1 = df_train.loc[df['y'] == 1]\n", 587 | "df_train_0_x = df_train_0.drop(['y'], axis=1)\n", 588 | "df_train_1_x = df_train_1.drop(['y'], axis=1)\n", 589 | "\n", 590 | "df_valid_0 = df_valid.loc[df['y'] == 0]\n", 591 | "df_valid_1 = df_valid.loc[df['y'] == 1]\n", 592 | "df_valid_0_x = df_valid_0.drop(['y'], axis=1)\n", 593 | "df_valid_1_x = df_valid_1.drop(['y'], axis=1)\n", 594 | "\n", 595 | "df_test_0 = df_test.loc[df['y'] == 0]\n", 596 | "df_test_1 = df_test.loc[df['y'] == 1]\n", 597 | "df_test_0_x = df_test_0.drop(['y'], axis=1)\n", 598 | "df_test_1_x = df_test_1.drop(['y'], axis=1)" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "### Standardize the data\n", 606 | "It is usually better to use standardized data (each feature rescaled to mean 0 and standard deviation 1) for autoencoders." 
607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 8, 612 | "metadata": {}, 613 | "outputs": [], 614 | "source": [ 615 | "scaler = StandardScaler().fit(df_train_0_x)\n", 616 | "df_train_0_x_rescaled = scaler.transform(df_train_0_x)\n", 617 | "df_valid_0_x_rescaled = scaler.transform(df_valid_0_x)\n", 618 | "df_valid_x_rescaled = scaler.transform(df_valid.drop(['y'], axis = 1))\n", 619 | "\n", 620 | "df_test_0_x_rescaled = scaler.transform(df_test_0_x)\n", 621 | "df_test_x_rescaled = scaler.transform(df_test.drop(['y'], axis = 1))" 622 | ] 623 | }, 624 | { 625 | "cell_type": "markdown", 626 | "metadata": {}, 627 | "source": [ 628 | "## Autoencoder training" 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "metadata": {}, 634 | "source": [ 635 | "First we will initialize the Autoencoder architecture. We are building a simple autoencoder. More complex architectures and other configurations should be explored." 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": 106, 641 | "metadata": {}, 642 | "outputs": [ 643 | { 644 | "name": "stdout", 645 | "output_type": "stream", 646 | "text": [ 647 | "_________________________________________________________________\n", 648 | "Layer (type) Output Shape Param # \n", 649 | "=================================================================\n", 650 | "input_6 (InputLayer) (None, 59) 0 \n", 651 | "_________________________________________________________________\n", 652 | "dense_23 (Dense) (None, 32) 1920 \n", 653 | "_________________________________________________________________\n", 654 | "dense_24 (Dense) (None, 16) 528 \n", 655 | "_________________________________________________________________\n", 656 | "dense_25 (Dense) (None, 16) 272 \n", 657 | "_________________________________________________________________\n", 658 | "dense_26 (Dense) (None, 32) 544 \n", 659 | "_________________________________________________________________\n", 660 | "dense_27 (Dense) 
(None, 59) 1947 \n", 661 | "=================================================================\n", 662 | "Total params: 5,211\n", 663 | "Trainable params: 5,211\n", 664 | "Non-trainable params: 0\n", 665 | "_________________________________________________________________\n" 666 | ] 667 | } 668 | ], 669 | "source": [ 670 | "nb_epoch = 200\n", 671 | "batch_size = 128\n", 672 | "input_dim = df_train_0_x_rescaled.shape[1] #num of predictor variables, \n", 673 | "encoding_dim = 32\n", 674 | "hidden_dim = int(encoding_dim / 2)\n", 675 | "learning_rate = 1e-3\n", 676 | "\n", 677 | "input_layer = Input(shape=(input_dim, ))\n", 678 | "encoder = Dense(encoding_dim, activation=\"relu\", activity_regularizer=regularizers.l1(learning_rate))(input_layer)\n", 679 | "encoder = Dense(hidden_dim, activation=\"relu\")(encoder)\n", 680 | "decoder = Dense(hidden_dim, activation=\"relu\")(encoder)\n", 681 | "decoder = Dense(encoding_dim, activation=\"relu\")(decoder)\n", 682 | "decoder = Dense(input_dim, activation=\"linear\")(decoder)\n", 683 | "autoencoder = Model(inputs=input_layer, outputs=decoder)\n", 684 | "autoencoder.summary()" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 107, 690 | "metadata": {}, 691 | "outputs": [ 692 | { 693 | "name": "stdout", 694 | "output_type": "stream", 695 | "text": [ 696 | "Train on 11541 samples, validate on 2883 samples\n", 697 | "Epoch 1/200\n", 698 | "11541/11541 [==============================] - 0s 26us/step - loss: 2.0636 - acc: 0.0358 - val_loss: 1.4716 - val_acc: 0.0555\n", 699 | "Epoch 2/200\n", 700 | "11541/11541 [==============================] - 0s 9us/step - loss: 1.1450 - acc: 0.0787 - val_loss: 0.9044 - val_acc: 0.0711\n", 701 | "Epoch 3/200\n", 702 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.7805 - acc: 0.0728 - val_loss: 0.6991 - val_acc: 0.0832\n", 703 | "Epoch 4/200\n", 704 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.6396 - acc: 0.0994 - 
val_loss: 0.6002 - val_acc: 0.1145\n", 705 | "Epoch 5/200\n", 706 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.5665 - acc: 0.1396 - val_loss: 0.5466 - val_acc: 0.1471\n", 707 | "Epoch 6/200\n", 708 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.5214 - acc: 0.1752 - val_loss: 0.5085 - val_acc: 0.1894\n", 709 | "Epoch 7/200\n", 710 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.4869 - acc: 0.2054 - val_loss: 0.4756 - val_acc: 0.2112\n", 711 | "Epoch 8/200\n", 712 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.4584 - acc: 0.2346 - val_loss: 0.4485 - val_acc: 0.2404\n", 713 | "Epoch 9/200\n", 714 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.4348 - acc: 0.2428 - val_loss: 0.4279 - val_acc: 0.2338\n", 715 | "Epoch 10/200\n", 716 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.4140 - acc: 0.2442 - val_loss: 0.4076 - val_acc: 0.2494\n", 717 | "Epoch 11/200\n", 718 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3968 - acc: 0.2470 - val_loss: 0.3919 - val_acc: 0.2366\n", 719 | "Epoch 12/200\n", 720 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3827 - acc: 0.2582 - val_loss: 0.3795 - val_acc: 0.2518\n", 721 | "Epoch 13/200\n", 722 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3715 - acc: 0.2651 - val_loss: 0.3683 - val_acc: 0.2636\n", 723 | "Epoch 14/200\n", 724 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3631 - acc: 0.2761 - val_loss: 0.3591 - val_acc: 0.2782\n", 725 | "Epoch 15/200\n", 726 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3529 - acc: 0.2878 - val_loss: 0.3506 - val_acc: 0.2969\n", 727 | "Epoch 16/200\n", 728 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3451 - acc: 0.2897 - val_loss: 0.3433 - val_acc: 0.2692\n", 729 | "Epoch 17/200\n", 730 | "11541/11541 
[==============================] - 0s 8us/step - loss: 0.3383 - acc: 0.2944 - val_loss: 0.3377 - val_acc: 0.2869\n", 731 | "Epoch 18/200\n", 732 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.3316 - acc: 0.3010 - val_loss: 0.3339 - val_acc: 0.3084\n", 733 | "Epoch 19/200\n", 734 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.3275 - acc: 0.3021 - val_loss: 0.3259 - val_acc: 0.3014\n", 735 | "Epoch 20/200\n", 736 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.3205 - acc: 0.3047 - val_loss: 0.3215 - val_acc: 0.2959\n", 737 | "Epoch 21/200\n", 738 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3155 - acc: 0.3092 - val_loss: 0.3146 - val_acc: 0.3032\n", 739 | "Epoch 22/200\n", 740 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.3106 - acc: 0.3194 - val_loss: 0.3126 - val_acc: 0.3104\n", 741 | "Epoch 23/200\n", 742 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.3064 - acc: 0.3144 - val_loss: 0.3107 - val_acc: 0.2973\n", 743 | "Epoch 24/200\n", 744 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3033 - acc: 0.3191 - val_loss: 0.3038 - val_acc: 0.3160\n", 745 | "Epoch 25/200\n", 746 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.3001 - acc: 0.3191 - val_loss: 0.3012 - val_acc: 0.3091\n", 747 | "Epoch 26/200\n", 748 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2950 - acc: 0.3260 - val_loss: 0.2994 - val_acc: 0.3247\n", 749 | "Epoch 27/200\n", 750 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2937 - acc: 0.3282 - val_loss: 0.2947 - val_acc: 0.3365\n", 751 | "Epoch 28/200\n", 752 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2898 - acc: 0.3379 - val_loss: 0.2938 - val_acc: 0.3313\n", 753 | "Epoch 29/200\n", 754 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2897 - acc: 0.3402 - 
val_loss: 0.2899 - val_acc: 0.3351\n", 755 | "Epoch 30/200\n", 756 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2848 - acc: 0.3399 - val_loss: 0.2901 - val_acc: 0.3441\n", 757 | "Epoch 31/200\n", 758 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2824 - acc: 0.3468 - val_loss: 0.2902 - val_acc: 0.3552\n", 759 | "Epoch 32/200\n", 760 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2814 - acc: 0.3528 - val_loss: 0.2846 - val_acc: 0.3552\n", 761 | "Epoch 33/200\n", 762 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2782 - acc: 0.3573 - val_loss: 0.2828 - val_acc: 0.3569\n", 763 | "Epoch 34/200\n", 764 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2753 - acc: 0.3670 - val_loss: 0.2804 - val_acc: 0.3535\n", 765 | "Epoch 35/200\n", 766 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2740 - acc: 0.3718 - val_loss: 0.2777 - val_acc: 0.3673\n", 767 | "Epoch 36/200\n", 768 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2726 - acc: 0.3707 - val_loss: 0.2742 - val_acc: 0.3711\n", 769 | "Epoch 37/200\n", 770 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2692 - acc: 0.3807 - val_loss: 0.2727 - val_acc: 0.3826\n", 771 | "Epoch 38/200\n", 772 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2746 - acc: 0.3850 - val_loss: 0.2746 - val_acc: 0.3933\n", 773 | "Epoch 39/200\n", 774 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2675 - acc: 0.3925 - val_loss: 0.2698 - val_acc: 0.3909\n", 775 | "Epoch 40/200\n", 776 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2659 - acc: 0.3897 - val_loss: 0.2686 - val_acc: 0.3913\n", 777 | "Epoch 41/200\n", 778 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2656 - acc: 0.3972 - val_loss: 0.2665 - val_acc: 0.3812\n", 779 | "Epoch 42/200\n", 780 | 
"11541/11541 [==============================] - 0s 8us/step - loss: 0.2647 - acc: 0.3988 - val_loss: 0.2690 - val_acc: 0.3871\n", 781 | "Epoch 43/200\n", 782 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2624 - acc: 0.3989 - val_loss: 0.2663 - val_acc: 0.4017\n", 783 | "Epoch 44/200\n", 784 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2600 - acc: 0.4007 - val_loss: 0.2635 - val_acc: 0.4096\n", 785 | "Epoch 45/200\n", 786 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2599 - acc: 0.4053 - val_loss: 0.2631 - val_acc: 0.3895\n", 787 | "Epoch 46/200\n", 788 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2589 - acc: 0.4056 - val_loss: 0.2625 - val_acc: 0.3972\n", 789 | "Epoch 47/200\n", 790 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2567 - acc: 0.4107 - val_loss: 0.2589 - val_acc: 0.4100\n", 791 | "Epoch 48/200\n", 792 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2551 - acc: 0.4125 - val_loss: 0.2598 - val_acc: 0.4062\n", 793 | "Epoch 49/200\n", 794 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2568 - acc: 0.4121 - val_loss: 0.2576 - val_acc: 0.4121\n", 795 | "Epoch 50/200\n", 796 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2537 - acc: 0.4160 - val_loss: 0.2619 - val_acc: 0.4121\n", 797 | "Epoch 51/200\n", 798 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2559 - acc: 0.4185 - val_loss: 0.2612 - val_acc: 0.4152\n", 799 | "Epoch 52/200\n", 800 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2549 - acc: 0.4191 - val_loss: 0.2556 - val_acc: 0.4180\n", 801 | "Epoch 53/200\n", 802 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2503 - acc: 0.4253 - val_loss: 0.2551 - val_acc: 0.4200\n", 803 | "Epoch 54/200\n", 804 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2518 - 
acc: 0.4233 - val_loss: 0.2534 - val_acc: 0.4214\n", 805 | "Epoch 55/200\n", 806 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2490 - acc: 0.4248 - val_loss: 0.2564 - val_acc: 0.4128\n", 807 | "Epoch 56/200\n", 808 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2489 - acc: 0.4239 - val_loss: 0.2508 - val_acc: 0.4273\n", 809 | "Epoch 57/200\n", 810 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2474 - acc: 0.4268 - val_loss: 0.2517 - val_acc: 0.4200\n", 811 | "Epoch 58/200\n", 812 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2495 - acc: 0.4282 - val_loss: 0.2527 - val_acc: 0.4114\n", 813 | "Epoch 59/200\n", 814 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2469 - acc: 0.4296 - val_loss: 0.2493 - val_acc: 0.4318\n", 815 | "Epoch 60/200\n" 816 | ] 817 | }, 818 | { 819 | "name": "stdout", 820 | "output_type": "stream", 821 | "text": [ 822 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2462 - acc: 0.4287 - val_loss: 0.2495 - val_acc: 0.4194\n", 823 | "Epoch 61/200\n", 824 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2464 - acc: 0.4270 - val_loss: 0.2530 - val_acc: 0.4266\n", 825 | "Epoch 62/200\n", 826 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2457 - acc: 0.4301 - val_loss: 0.2476 - val_acc: 0.4239\n", 827 | "Epoch 63/200\n", 828 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2447 - acc: 0.4319 - val_loss: 0.2498 - val_acc: 0.4339\n", 829 | "Epoch 64/200\n", 830 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2435 - acc: 0.4313 - val_loss: 0.2487 - val_acc: 0.4232\n", 831 | "Epoch 65/200\n", 832 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2455 - acc: 0.4287 - val_loss: 0.2452 - val_acc: 0.4391\n", 833 | "Epoch 66/200\n", 834 | "11541/11541 [==============================] - 0s 9us/step - 
loss: 0.2432 - acc: 0.4299 - val_loss: 0.2460 - val_acc: 0.4280\n", 835 | "Epoch 67/200\n", 836 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2415 - acc: 0.4347 - val_loss: 0.2461 - val_acc: 0.4145\n", 837 | "Epoch 68/200\n", 838 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2416 - acc: 0.4336 - val_loss: 0.2469 - val_acc: 0.4332\n", 839 | "Epoch 69/200\n", 840 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2426 - acc: 0.4355 - val_loss: 0.2436 - val_acc: 0.4273\n", 841 | "Epoch 70/200\n", 842 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2399 - acc: 0.4372 - val_loss: 0.2444 - val_acc: 0.4367\n", 843 | "Epoch 71/200\n", 844 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2406 - acc: 0.4367 - val_loss: 0.2422 - val_acc: 0.4336\n", 845 | "Epoch 72/200\n", 846 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2388 - acc: 0.4353 - val_loss: 0.2477 - val_acc: 0.4225\n", 847 | "Epoch 73/200\n", 848 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2405 - acc: 0.4333 - val_loss: 0.2474 - val_acc: 0.4225\n", 849 | "Epoch 74/200\n", 850 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2389 - acc: 0.4396 - val_loss: 0.2402 - val_acc: 0.4266\n", 851 | "Epoch 75/200\n", 852 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2395 - acc: 0.4332 - val_loss: 0.2432 - val_acc: 0.4426\n", 853 | "Epoch 76/200\n", 854 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2371 - acc: 0.4355 - val_loss: 0.2430 - val_acc: 0.4284\n", 855 | "Epoch 77/200\n", 856 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2384 - acc: 0.4349 - val_loss: 0.2409 - val_acc: 0.4374\n", 857 | "Epoch 78/200\n", 858 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2368 - acc: 0.4374 - val_loss: 0.2376 - val_acc: 0.4405\n", 859 | 
"Epoch 79/200\n", 860 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2352 - acc: 0.4366 - val_loss: 0.2415 - val_acc: 0.4429\n", 861 | "Epoch 80/200\n", 862 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2373 - acc: 0.4394 - val_loss: 0.2380 - val_acc: 0.4225\n", 863 | "Epoch 81/200\n", 864 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2343 - acc: 0.4401 - val_loss: 0.2369 - val_acc: 0.4492\n", 865 | "Epoch 82/200\n", 866 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2342 - acc: 0.4431 - val_loss: 0.2380 - val_acc: 0.4339\n", 867 | "Epoch 83/200\n", 868 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2335 - acc: 0.4422 - val_loss: 0.2367 - val_acc: 0.4440\n", 869 | "Epoch 84/200\n", 870 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2320 - acc: 0.4416 - val_loss: 0.2393 - val_acc: 0.4478\n", 871 | "Epoch 85/200\n", 872 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2356 - acc: 0.4349 - val_loss: 0.2351 - val_acc: 0.4429\n", 873 | "Epoch 86/200\n", 874 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2309 - acc: 0.4463 - val_loss: 0.2363 - val_acc: 0.4499\n", 875 | "Epoch 87/200\n", 876 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2333 - acc: 0.4470 - val_loss: 0.2393 - val_acc: 0.4353\n", 877 | "Epoch 88/200\n", 878 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2372 - acc: 0.4382 - val_loss: 0.2402 - val_acc: 0.4353\n", 879 | "Epoch 89/200\n", 880 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2317 - acc: 0.4461 - val_loss: 0.2340 - val_acc: 0.4350\n", 881 | "Epoch 90/200\n", 882 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2305 - acc: 0.4424 - val_loss: 0.2329 - val_acc: 0.4530\n", 883 | "Epoch 91/200\n", 884 | "11541/11541 [==============================] - 0s 
9us/step - loss: 0.2311 - acc: 0.4444 - val_loss: 0.2355 - val_acc: 0.4468\n", 885 | "Epoch 92/200\n", 886 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2283 - acc: 0.4459 - val_loss: 0.2327 - val_acc: 0.4513\n", 887 | "Epoch 93/200\n", 888 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2294 - acc: 0.4426 - val_loss: 0.2406 - val_acc: 0.4325\n", 889 | "Epoch 94/200\n", 890 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2389 - acc: 0.4452 - val_loss: 0.2346 - val_acc: 0.4450\n", 891 | "Epoch 95/200\n", 892 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2296 - acc: 0.4482 - val_loss: 0.2359 - val_acc: 0.4409\n", 893 | "Epoch 96/200\n", 894 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2328 - acc: 0.4436 - val_loss: 0.2326 - val_acc: 0.4325\n", 895 | "Epoch 97/200\n", 896 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2280 - acc: 0.4508 - val_loss: 0.2335 - val_acc: 0.4502\n", 897 | "Epoch 98/200\n", 898 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2264 - acc: 0.4509 - val_loss: 0.2292 - val_acc: 0.4527\n", 899 | "Epoch 99/200\n", 900 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2271 - acc: 0.4486 - val_loss: 0.2321 - val_acc: 0.4527\n", 901 | "Epoch 100/200\n", 902 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2280 - acc: 0.4525 - val_loss: 0.2314 - val_acc: 0.4454\n", 903 | "Epoch 101/200\n", 904 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2265 - acc: 0.4518 - val_loss: 0.2322 - val_acc: 0.4475\n", 905 | "Epoch 102/200\n", 906 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2253 - acc: 0.4571 - val_loss: 0.2316 - val_acc: 0.4461\n", 907 | "Epoch 103/200\n", 908 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2280 - acc: 0.4552 - val_loss: 0.2335 - val_acc: 
0.4471\n", 909 | "Epoch 104/200\n", 910 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2288 - acc: 0.4534 - val_loss: 0.2288 - val_acc: 0.4447\n", 911 | "Epoch 105/200\n", 912 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2263 - acc: 0.4567 - val_loss: 0.2315 - val_acc: 0.4617\n", 913 | "Epoch 106/200\n", 914 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2259 - acc: 0.4550 - val_loss: 0.2295 - val_acc: 0.4509\n", 915 | "Epoch 107/200\n", 916 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2235 - acc: 0.4549 - val_loss: 0.2287 - val_acc: 0.4346\n", 917 | "Epoch 108/200\n", 918 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2244 - acc: 0.4520 - val_loss: 0.2279 - val_acc: 0.4547\n", 919 | "Epoch 109/200\n", 920 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2236 - acc: 0.4565 - val_loss: 0.2261 - val_acc: 0.4533\n", 921 | "Epoch 110/200\n", 922 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2239 - acc: 0.4524 - val_loss: 0.2284 - val_acc: 0.4603\n", 923 | "Epoch 111/200\n", 924 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2279 - acc: 0.4580 - val_loss: 0.2262 - val_acc: 0.4468\n", 925 | "Epoch 112/200\n", 926 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2281 - acc: 0.4533 - val_loss: 0.2319 - val_acc: 0.4488\n", 927 | "Epoch 113/200\n", 928 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2251 - acc: 0.4572 - val_loss: 0.2285 - val_acc: 0.4669\n", 929 | "Epoch 114/200\n", 930 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2235 - acc: 0.4595 - val_loss: 0.2265 - val_acc: 0.4606\n", 931 | "Epoch 115/200\n", 932 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2235 - acc: 0.4621 - val_loss: 0.2275 - val_acc: 0.4506\n", 933 | "Epoch 116/200\n", 934 | "11541/11541 
[==============================] - 0s 9us/step - loss: 0.2210 - acc: 0.4615 - val_loss: 0.2251 - val_acc: 0.4599\n", 935 | "Epoch 117/200\n", 936 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2216 - acc: 0.4634 - val_loss: 0.2268 - val_acc: 0.4575\n", 937 | "Epoch 118/200\n", 938 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2231 - acc: 0.4603 - val_loss: 0.2282 - val_acc: 0.4572\n", 939 | "Epoch 119/200\n" 940 | ] 941 | }, 942 | { 943 | "name": "stdout", 944 | "output_type": "stream", 945 | "text": [ 946 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2209 - acc: 0.4644 - val_loss: 0.2275 - val_acc: 0.4568\n", 947 | "Epoch 120/200\n", 948 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2220 - acc: 0.4630 - val_loss: 0.2244 - val_acc: 0.4655\n", 949 | "Epoch 121/200\n", 950 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2219 - acc: 0.4607 - val_loss: 0.2299 - val_acc: 0.4690\n", 951 | "Epoch 122/200\n", 952 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2270 - acc: 0.4611 - val_loss: 0.2259 - val_acc: 0.4631\n", 953 | "Epoch 123/200\n", 954 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2196 - acc: 0.4659 - val_loss: 0.2222 - val_acc: 0.4703\n", 955 | "Epoch 124/200\n", 956 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2199 - acc: 0.4631 - val_loss: 0.2292 - val_acc: 0.4530\n", 957 | "Epoch 125/200\n", 958 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2204 - acc: 0.4663 - val_loss: 0.2235 - val_acc: 0.4662\n", 959 | "Epoch 126/200\n", 960 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2195 - acc: 0.4658 - val_loss: 0.2254 - val_acc: 0.4475\n", 961 | "Epoch 127/200\n", 962 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2210 - acc: 0.4655 - val_loss: 0.2255 - val_acc: 0.4565\n", 963 | "Epoch 
128/200\n", 964 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2218 - acc: 0.4681 - val_loss: 0.2266 - val_acc: 0.4672\n", 965 | "Epoch 129/200\n", 966 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2186 - acc: 0.4712 - val_loss: 0.2208 - val_acc: 0.4721\n", 967 | "Epoch 130/200\n", 968 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2183 - acc: 0.4686 - val_loss: 0.2193 - val_acc: 0.4679\n", 969 | "Epoch 131/200\n", 970 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2179 - acc: 0.4666 - val_loss: 0.2217 - val_acc: 0.4738\n", 971 | "Epoch 132/200\n", 972 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2188 - acc: 0.4715 - val_loss: 0.2222 - val_acc: 0.4624\n", 973 | "Epoch 133/200\n", 974 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2168 - acc: 0.4753 - val_loss: 0.2229 - val_acc: 0.4579\n", 975 | "Epoch 134/200\n", 976 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2186 - acc: 0.4714 - val_loss: 0.2262 - val_acc: 0.4554\n", 977 | "Epoch 135/200\n", 978 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2174 - acc: 0.4713 - val_loss: 0.2217 - val_acc: 0.4686\n", 979 | "Epoch 136/200\n", 980 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2168 - acc: 0.4696 - val_loss: 0.2213 - val_acc: 0.4599\n", 981 | "Epoch 137/200\n", 982 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2196 - acc: 0.4756 - val_loss: 0.2234 - val_acc: 0.4801\n", 983 | "Epoch 138/200\n", 984 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2177 - acc: 0.4715 - val_loss: 0.2218 - val_acc: 0.4665\n", 985 | "Epoch 139/200\n", 986 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2201 - acc: 0.4750 - val_loss: 0.2272 - val_acc: 0.4735\n", 987 | "Epoch 140/200\n", 988 | "11541/11541 [==============================] - 
0s 8us/step - loss: 0.2165 - acc: 0.4755 - val_loss: 0.2184 - val_acc: 0.4672\n", 989 | "Epoch 141/200\n", 990 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2148 - acc: 0.4766 - val_loss: 0.2190 - val_acc: 0.4700\n", 991 | "Epoch 142/200\n", 992 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2172 - acc: 0.4708 - val_loss: 0.2207 - val_acc: 0.4672\n", 993 | "Epoch 143/200\n", 994 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2176 - acc: 0.4767 - val_loss: 0.2225 - val_acc: 0.4807\n", 995 | "Epoch 144/200\n", 996 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2164 - acc: 0.4766 - val_loss: 0.2193 - val_acc: 0.4821\n", 997 | "Epoch 145/200\n", 998 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2147 - acc: 0.4793 - val_loss: 0.2183 - val_acc: 0.4672\n", 999 | "Epoch 146/200\n", 1000 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2146 - acc: 0.4756 - val_loss: 0.2167 - val_acc: 0.4804\n", 1001 | "Epoch 147/200\n", 1002 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2148 - acc: 0.4791 - val_loss: 0.2216 - val_acc: 0.4835\n", 1003 | "Epoch 148/200\n", 1004 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2147 - acc: 0.4786 - val_loss: 0.2185 - val_acc: 0.4624\n", 1005 | "Epoch 149/200\n", 1006 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2142 - acc: 0.4787 - val_loss: 0.2172 - val_acc: 0.4974\n", 1007 | "Epoch 150/200\n", 1008 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2156 - acc: 0.4779 - val_loss: 0.2177 - val_acc: 0.4603\n", 1009 | "Epoch 151/200\n", 1010 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2162 - acc: 0.4803 - val_loss: 0.2174 - val_acc: 0.4676\n", 1011 | "Epoch 152/200\n", 1012 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2126 - acc: 0.4820 - val_loss: 
0.2172 - val_acc: 0.4790\n", 1013 | "Epoch 153/200\n", 1014 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2129 - acc: 0.4832 - val_loss: 0.2199 - val_acc: 0.4939\n", 1015 | "Epoch 154/200\n", 1016 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2126 - acc: 0.4790 - val_loss: 0.2156 - val_acc: 0.4835\n", 1017 | "Epoch 155/200\n", 1018 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2139 - acc: 0.4848 - val_loss: 0.2180 - val_acc: 0.4665\n", 1019 | "Epoch 156/200\n", 1020 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2134 - acc: 0.4843 - val_loss: 0.2199 - val_acc: 0.4717\n", 1021 | "Epoch 157/200\n", 1022 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2156 - acc: 0.4785 - val_loss: 0.2191 - val_acc: 0.4901\n", 1023 | "Epoch 158/200\n", 1024 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2121 - acc: 0.4812 - val_loss: 0.2170 - val_acc: 0.4832\n", 1025 | "Epoch 159/200\n", 1026 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2150 - acc: 0.4834 - val_loss: 0.2177 - val_acc: 0.4915\n", 1027 | "Epoch 160/200\n", 1028 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2163 - acc: 0.4765 - val_loss: 0.2138 - val_acc: 0.4846\n", 1029 | "Epoch 161/200\n", 1030 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2113 - acc: 0.4810 - val_loss: 0.2168 - val_acc: 0.4710\n", 1031 | "Epoch 162/200\n", 1032 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2115 - acc: 0.4847 - val_loss: 0.2161 - val_acc: 0.4794\n", 1033 | "Epoch 163/200\n", 1034 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2128 - acc: 0.4832 - val_loss: 0.2170 - val_acc: 0.4832\n", 1035 | "Epoch 164/200\n", 1036 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2121 - acc: 0.4842 - val_loss: 0.2191 - val_acc: 0.4894\n", 1037 | "Epoch 
165/200\n", 1038 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2127 - acc: 0.4861 - val_loss: 0.2157 - val_acc: 0.4735\n", 1039 | "Epoch 166/200\n", 1040 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2110 - acc: 0.4910 - val_loss: 0.2152 - val_acc: 0.5071\n", 1041 | "Epoch 167/200\n", 1042 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2097 - acc: 0.4922 - val_loss: 0.2136 - val_acc: 0.4762\n", 1043 | "Epoch 168/200\n", 1044 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2097 - acc: 0.4902 - val_loss: 0.2152 - val_acc: 0.4939\n", 1045 | "Epoch 169/200\n", 1046 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2097 - acc: 0.4890 - val_loss: 0.2152 - val_acc: 0.4984\n", 1047 | "Epoch 170/200\n", 1048 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2130 - acc: 0.4891 - val_loss: 0.2184 - val_acc: 0.4696\n", 1049 | "Epoch 171/200\n", 1050 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2122 - acc: 0.4867 - val_loss: 0.2181 - val_acc: 0.4551\n", 1051 | "Epoch 172/200\n", 1052 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2084 - acc: 0.4968 - val_loss: 0.2100 - val_acc: 0.4853\n", 1053 | "Epoch 173/200\n", 1054 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2097 - acc: 0.4891 - val_loss: 0.2113 - val_acc: 0.4894\n", 1055 | "Epoch 174/200\n", 1056 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2111 - acc: 0.4884 - val_loss: 0.2191 - val_acc: 0.4672\n", 1057 | "Epoch 175/200\n", 1058 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2125 - acc: 0.4869 - val_loss: 0.2144 - val_acc: 0.4860\n", 1059 | "Epoch 176/200\n", 1060 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2098 - acc: 0.4944 - val_loss: 0.2133 - val_acc: 0.4922\n", 1061 | "Epoch 177/200\n", 1062 | "11541/11541 
[==============================] - 0s 9us/step - loss: 0.2187 - acc: 0.4934 - val_loss: 0.2169 - val_acc: 0.4835\n", 1063 | "Epoch 178/200\n" 1064 | ] 1065 | }, 1066 | { 1067 | "name": "stdout", 1068 | "output_type": "stream", 1069 | "text": [ 1070 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2229 - acc: 0.4890 - val_loss: 0.2238 - val_acc: 0.4710\n", 1071 | "Epoch 179/200\n", 1072 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2112 - acc: 0.4916 - val_loss: 0.2121 - val_acc: 0.4717\n", 1073 | "Epoch 180/200\n", 1074 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2125 - acc: 0.4888 - val_loss: 0.2131 - val_acc: 0.4918\n", 1075 | "Epoch 181/200\n", 1076 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2083 - acc: 0.4968 - val_loss: 0.2106 - val_acc: 0.4974\n", 1077 | "Epoch 182/200\n", 1078 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2070 - acc: 0.4958 - val_loss: 0.2091 - val_acc: 0.4925\n", 1079 | "Epoch 183/200\n", 1080 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2064 - acc: 0.4990 - val_loss: 0.2106 - val_acc: 0.4839\n", 1081 | "Epoch 184/200\n", 1082 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2066 - acc: 0.4987 - val_loss: 0.2108 - val_acc: 0.4929\n", 1083 | "Epoch 185/200\n", 1084 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2065 - acc: 0.4980 - val_loss: 0.2141 - val_acc: 0.4853\n", 1085 | "Epoch 186/200\n", 1086 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2076 - acc: 0.4947 - val_loss: 0.2136 - val_acc: 0.4776\n", 1087 | "Epoch 187/200\n", 1088 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2078 - acc: 0.4940 - val_loss: 0.2101 - val_acc: 0.4877\n", 1089 | "Epoch 188/200\n", 1090 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2074 - acc: 0.4951 - val_loss: 0.2130 - val_acc: 
0.4846\n", 1091 | "Epoch 189/200\n", 1092 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2065 - acc: 0.4947 - val_loss: 0.2136 - val_acc: 0.4981\n", 1093 | "Epoch 190/200\n", 1094 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2083 - acc: 0.4997 - val_loss: 0.2113 - val_acc: 0.4880\n", 1095 | "Epoch 191/200\n", 1096 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2077 - acc: 0.4964 - val_loss: 0.2096 - val_acc: 0.4981\n", 1097 | "Epoch 192/200\n", 1098 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2064 - acc: 0.5042 - val_loss: 0.2086 - val_acc: 0.5064\n", 1099 | "Epoch 193/200\n", 1100 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2057 - acc: 0.5029 - val_loss: 0.2113 - val_acc: 0.4977\n", 1101 | "Epoch 194/200\n", 1102 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2064 - acc: 0.4980 - val_loss: 0.2120 - val_acc: 0.4925\n", 1103 | "Epoch 195/200\n", 1104 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2054 - acc: 0.5015 - val_loss: 0.2082 - val_acc: 0.4946\n", 1105 | "Epoch 196/200\n", 1106 | "11541/11541 [==============================] - 0s 8us/step - loss: 0.2077 - acc: 0.5014 - val_loss: 0.2116 - val_acc: 0.4998\n", 1107 | "Epoch 197/200\n", 1108 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2062 - acc: 0.4984 - val_loss: 0.2130 - val_acc: 0.5064\n", 1109 | "Epoch 198/200\n", 1110 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2054 - acc: 0.5000 - val_loss: 0.2143 - val_acc: 0.4887\n", 1111 | "Epoch 199/200\n", 1112 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2093 - acc: 0.4966 - val_loss: 0.2153 - val_acc: 0.5082\n", 1113 | "Epoch 200/200\n", 1114 | "11541/11541 [==============================] - 0s 9us/step - loss: 0.2056 - acc: 0.5007 - val_loss: 0.2107 - val_acc: 0.4964\n" 1115 | ] 1116 | } 1117 | ], 1118 
| "source": [ 1119 | "autoencoder.compile(metrics=['accuracy'],\n", 1120 | " loss='mean_squared_error',\n", 1121 | " optimizer='adam')\n", 1122 | "\n", 1123 | "cp = ModelCheckpoint(filepath=\"autoencoder_classifier.h5\",\n", 1124 | " save_best_only=True,\n", 1125 | " verbose=0)\n", 1126 | "\n", 1127 | "tb = TensorBoard(log_dir='./logs',\n", 1128 | " histogram_freq=0,\n", 1129 | " write_graph=True,\n", 1130 | " write_images=True)\n", 1131 | "\n", 1132 | "history = autoencoder.fit(df_train_0_x_rescaled, df_train_0_x_rescaled,\n", 1133 | " epochs=nb_epoch,\n", 1134 | " batch_size=batch_size,\n", 1135 | " shuffle=True,\n", 1136 | " validation_data=(df_valid_0_x_rescaled, df_valid_0_x_rescaled),\n", 1137 | " verbose=1,\n", 1138 | " callbacks=[cp, tb]).history" 1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "code", 1143 | "execution_count": 108, 1144 | "metadata": {}, 1145 | "outputs": [], 1146 | "source": [ 1147 | "autoencoder = load_model('autoencoder_classifier.h5')" 1148 | ] 1149 | }, 1150 | { 1151 | "cell_type": "code", 1152 | "execution_count": 109, 1153 | "metadata": {}, 1154 | "outputs": [ 1155 | { 1156 | "data": { 1157 | "image/png": "\n", 1158 | "text/plain": [ 1159 | "
" 1160 | ] 1161 | }, 1162 | "metadata": { 1163 | "needs_background": "light" 1164 | }, 1165 | "output_type": "display_data" 1166 | } 1167 | ], 1168 | "source": [ 1169 | "plt.plot(history['loss'], linewidth=2, label='Train')\n", 1170 | "plt.plot(history['val_loss'], linewidth=2, label='Valid')\n", 1171 | "plt.legend(loc='upper right')\n", 1172 | "plt.title('Model loss')\n", 1173 | "plt.ylabel('Loss')\n", 1174 | "plt.xlabel('Epoch')\n", 1175 | "plt.show()" 1176 | ] 1177 | }, 1178 | { 1179 | "cell_type": "code", 1180 | "execution_count": 110, 1181 | "metadata": {}, 1182 | "outputs": [ 1183 | { 1184 | "data": { 1185 | "image/png": "\n", 1186 | "text/plain": [ 1187 | "
" 1188 | ] 1189 | }, 1190 | "metadata": { 1191 | "needs_background": "light" 1192 | }, 1193 | "output_type": "display_data" 1194 | } 1195 | ], 1196 | "source": [ 1197 | "valid_x_predictions = autoencoder.predict(df_valid_x_rescaled)\n", 1198 | "mse = np.mean(np.power(df_valid_x_rescaled - valid_x_predictions, 2), axis=1)\n", 1199 | "error_df = pd.DataFrame({'Reconstruction_error': mse,\n", 1200 | " 'True_class': df_valid['y']})\n", 1201 | "\n", 1202 | "precision_rt, recall_rt, threshold_rt = precision_recall_curve(error_df.True_class, error_df.Reconstruction_error)\n", 1203 | "plt.plot(threshold_rt, precision_rt[:-1], label=\"Precision\",linewidth=5)  # precision/recall have one more entry than thresholds; [:-1] drops the final (precision=1, recall=0) point, which has no threshold\n", 1204 | "plt.plot(threshold_rt, recall_rt[:-1], label=\"Recall\",linewidth=5)  # recall_rt[i] pairs with threshold_rt[i], so slice off the last element, not the first\n", 1205 | "plt.title('Precision and recall for different threshold values')\n", 1206 | "plt.xlabel('Threshold')\n", 1207 | "plt.ylabel('Precision/Recall')\n", 1208 | "plt.legend()\n", 1209 | "plt.show()" 1210 | ] 1211 | }, 1212 | { 1213 | "cell_type": "code", 1214 | "execution_count": 111, 1215 | "metadata": {}, 1216 | "outputs": [], 1217 | "source": [ 1218 | "test_x_predictions = autoencoder.predict(df_test_x_rescaled)\n", 1219 | "mse = np.mean(np.power(df_test_x_rescaled - test_x_predictions, 2), axis=1)\n", 1220 | "error_df_test = pd.DataFrame({'Reconstruction_error': mse,\n", 1221 | " 'True_class': df_test['y']})\n", 1222 | "error_df_test = error_df_test.reset_index()" 1223 | ] 1224 | }, 1225 | { 1226 | "cell_type": "code", 1227 | "execution_count": 120, 1228 | "metadata": {}, 1229 | "outputs": [ 1230 | { 1231 | "data": { 1232 | "image/png": "\n", 1233 | "text/plain": [ 1234 | "
" 1235 | ] 1236 | }, 1237 | "metadata": { 1238 | "needs_background": "light" 1239 | }, 1240 | "output_type": "display_data" 1241 | } 1242 | ], 1243 | "source": [ 1244 | "threshold_fixed = 0.4\n", 1245 | "groups = error_df_test.groupby('True_class')\n", 1246 | "\n", 1247 | "fig, ax = plt.subplots()\n", 1248 | "\n", 1249 | "for name, group in groups:\n", 1250 | " ax.plot(group.index, group.Reconstruction_error, marker='o', ms=3.5, linestyle='',\n", 1251 | " label= \"Break\" if name == 1 else \"Normal\")\n", 1252 | "ax.hlines(threshold_fixed, ax.get_xlim()[0], ax.get_xlim()[1], colors=\"r\", zorder=100, label='Threshold')\n", 1253 | "ax.legend()\n", 1254 | "plt.title(\"Reconstruction error for different classes\")\n", 1255 | "plt.ylabel(\"Reconstruction error\")\n", 1256 | "plt.xlabel(\"Data point index\")\n", 1257 | "plt.show();" 1258 | ] 1259 | }, 1260 | { 1261 | "cell_type": "code", 1262 | "execution_count": 121, 1263 | "metadata": {}, 1264 | "outputs": [], 1265 | "source": [ 1266 | "pred_y = [1 if e > threshold_fixed else 0 for e in error_df.Reconstruction_error.values]\n" 1267 | ] 1268 | }, 1269 | { 1270 | "cell_type": "code", 1271 | "execution_count": 122, 1272 | "metadata": {}, 1273 | "outputs": [], 1274 | "source": [ 1275 | "predictions = pd.DataFrame({'true': error_df.True_class,\n", 1276 | " 'predicted': pred_y})" 1277 | ] 1278 | }, 1279 | { 1280 | "cell_type": "code", 1281 | "execution_count": 123, 1282 | "metadata": {}, 1283 | "outputs": [ 1284 | { 1285 | "data": { 1286 | "image/png": "\n", 1287 | "text/plain": [ 1288 | "
" 1289 | ] 1290 | }, 1291 | "metadata": { 1292 | "needs_background": "light" 1293 | }, 1294 | "output_type": "display_data" 1295 | } 1296 | ], 1297 | "source": [ 1298 | "conf_matrix = confusion_matrix(error_df.True_class, pred_y)\n", 1299 | "plt.figure(figsize=(8, 8))\n", 1300 | "sns.heatmap(conf_matrix, xticklabels=LABELS, yticklabels=LABELS, annot=True, fmt=\"d\");\n", 1301 | "plt.title(\"Confusion matrix\")\n", 1302 | "plt.ylabel('True class')\n", 1303 | "plt.xlabel('Predicted class')\n", 1304 | "plt.show()" 1305 | ] 1306 | }, 1307 | { 1308 | "cell_type": "code", 1309 | "execution_count": null, 1310 | "metadata": {}, 1311 | "outputs": [], 1312 | "source": [ 1313 | "false_pos_rate, true_pos_rate, thresholds = roc_curve(error_df.True_class, error_df.Reconstruction_error)\n", 1314 | "roc_auc = auc(false_pos_rate, true_pos_rate,)\n", 1315 | "\n", 1316 | "plt.plot(false_pos_rate, true_pos_rate, linewidth=5, label='AUC = %0.3f'% roc_auc)\n", 1317 | "plt.plot([0,1],[0,1], linewidth=5)\n", 1318 | "\n", 1319 | "plt.xlim([-0.01, 1])\n", 1320 | "plt.ylim([0, 1.01])\n", 1321 | "plt.legend(loc='lower right')\n", 1322 | "plt.title('Receiver operating characteristic curve (ROC)')\n", 1323 | "plt.ylabel('True Positive Rate')\n", 1324 | "plt.xlabel('False Positive Rate')\n", 1325 | "plt.show()" 1326 | ] 1327 | } 1328 | ], 1329 | "metadata": { 1330 | "kernelspec": { 1331 | "display_name": "Python 3", 1332 | "language": "python", 1333 | "name": "python3" 1334 | }, 1335 | "language_info": { 1336 | "codemirror_mode": { 1337 | "name": "ipython", 1338 | "version": 3 1339 | }, 1340 | "file_extension": ".py", 1341 | "mimetype": "text/x-python", 1342 | "name": "python", 1343 | "nbconvert_exporter": "python", 1344 | "pygments_lexer": "ipython3", 1345 | "version": "3.7.1" 1346 | } 1347 | }, 1348 | "nbformat": 4, 1349 | "nbformat_minor": 2 1350 | } 1351 | --------------------------------------------------------------------------------