├── 1.png ├── 2.png ├── 3.png ├── Energy_Consumption_Predictions_with_Bayesian_LSTMs_in_PyTorch.ipynb └── README.md /1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PawaritL/BayesianLSTM/251ba71f5c45e81f6083c9e54ecaa39c5d64713f/1.png -------------------------------------------------------------------------------- /2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PawaritL/BayesianLSTM/251ba71f5c45e81f6083c9e54ecaa39c5d64713f/2.png -------------------------------------------------------------------------------- /3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PawaritL/BayesianLSTM/251ba71f5c45e81f6083c9e54ecaa39c5d64713f/3.png -------------------------------------------------------------------------------- /Energy_Consumption_Predictions_with_Bayesian_LSTMs_in_PyTorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Energy Consumption Predictions with Bayesian LSTMs in PyTorch.ipynb", 7 | "provenance": [], 8 | "include_colab_link": true 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "view-in-github", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "\"Open" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "VHRshe9Y-Q0Q" 30 | }, 31 | "source": [ 32 | "# Energy Consumption Predictions with Bayesian LSTMs in PyTorch" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": { 38 | "id": "8GGWUqKU1paW" 39 | }, 40 | "source": [ 41 | "Author: Pawarit Laosunthara\n", 42 | "\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | 
"metadata": { 48 | "id": "3FQQioKQ-Tkr" 49 | }, 50 | "source": [ 51 | "# **Important Note for GitHub Readers:**\n", 52 | "Please click the **Open in Colab** button above to view all **interactive visualizations**.\n", 53 | "\n", 54 | "This notebook demonstrates an implementation of an (approximate) Bayesian Recurrent Neural Network in PyTorch, inspired by the paper *Deep and Confident Prediction for Time Series at Uber* (https://arxiv.org/pdf/1709.01907.pdf).\n", 55 | "\n", 56 | "
\n", 57 | "\n", 58 | "In this approach, Monte Carlo dropout is used to **approximate** Bayesian inference, allowing our predictions to have explicit uncertainties and confidence intervals. This property makes Bayesian Neural Networks highly appealing for critical applications that require uncertainty quantification.\n", 59 | "The *Appliances energy prediction* dataset used in this example is from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/Appliances+energy+prediction).\n", 60 | "\n", 61 | "\n", 62 | "**Note:** this notebook purely serves to demonstrate the implementation of Bayesian LSTM (Long Short-Term Memory) networks in PyTorch. Therefore, extensive data exploration and feature engineering are not part of the scope of this investigation." 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": { 68 | "id": "psI_d17Y_3_9" 69 | }, 70 | "source": [ 71 | "# Preliminary Data Wrangling" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": { 77 | "id": "OFY02qDSAbqI" 78 | }, 79 | "source": [ 80 | "**Selected Columns:**\n", 81 | "\n", 82 | "For simplicity and speed when running this notebook, only temporal and autoregressive features are used.\n", 83 | "\n", 84 | "- date, the timestamp in year-month-day hour:minute:second format, sampled every 10 minutes \\\n", 85 | "- Appliances, energy use in Wh for the corresponding 10-minute timestamp \\\n", 86 | "- day_of_week, where Monday corresponds to 0 \\\n", 87 | "- hour_of_day, from 0 to 23\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "metadata": { 93 | "id": "uWNK0BtB2W0E" 94 | }, 95 | "source": [ 96 | "import pandas as pd" 97 | ], 98 | "execution_count": null, 99 | "outputs": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "metadata": { 104 | "id": "WDYp3nRh-Bpv" 105 | }, 106 | "source": [ 107 | "energy_df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/00374/energydata_complete.csv')\n", 108 | "\n", 109 | "energy_df['date'] = 
pd.to_datetime(energy_df['date'])\n", 110 | "\n", 111 | "energy_df['month'] = energy_df['date'].dt.month.astype(int)\n", 112 | "energy_df['day_of_month'] = energy_df['date'].dt.day.astype(int)\n", 113 | "\n", 114 | "# day_of_week=0 corresponds to Monday\n", 115 | "energy_df['day_of_week'] = energy_df['date'].dt.dayofweek.astype(int)\n", 116 | "energy_df['hour_of_day'] = energy_df['date'].dt.hour.astype(int)\n", 117 | "\n", 118 | "selected_columns = ['date', 'day_of_week', 'hour_of_day', 'Appliances']\n", 119 | "energy_df = energy_df[selected_columns]" 120 | ], 121 | "execution_count": null, 122 | "outputs": [] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "metadata": { 127 | "id": "vv4cV2qr-E-8", 128 | "colab": { 129 | "base_uri": "https://localhost:8080/", 130 | "height": 206 131 | }, 132 | "outputId": "476229ff-04c5-4b20-e035-dc0a6c22e8de" 133 | }, 134 | "source": [ 135 | "energy_df.head()" 136 | ], 137 | "execution_count": null, 138 | "outputs": [ 139 | { 140 | "output_type": "execute_result", 141 | "data": { 142 | "text/html": [ 143 | "
" 206 | ], 207 | "text/plain": [ 208 | " date day_of_week hour_of_day Appliances\n", 209 | "0 2016-01-11 17:00:00 0 17 60\n", 210 | "1 2016-01-11 17:10:00 0 17 60\n", 211 | "2 2016-01-11 17:20:00 0 17 50\n", 212 | "3 2016-01-11 17:30:00 0 17 50\n", 213 | "4 2016-01-11 17:40:00 0 17 60" 214 | ] 215 | }, 216 | "metadata": { 217 | "tags": [] 218 | }, 219 | "execution_count": 3 220 | } 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": { 226 | "id": "d8BJQAqaJqv7" 227 | }, 228 | "source": [ 229 | "## Time Series Transformations\n", 230 | "\n", 231 | "1. The dataset is resampled to an hourly rate (taking the mean of each hour) for more meaningful analytics.\n", 232 | "\n", 233 | "2. To alleviate exponential effects, the target variable is log-transformed as per the Uber paper.\n", 234 | "\n", 235 | "3. For simplicity and speed when running this notebook, only temporal and autoregressive features, namely `day_of_week`, `hour_of_day`, \\\n", 236 | "and previous (log-transformed) values of `Appliances`, are used as features." 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "metadata": { 242 | "id": "LcXp9XEaJqXO" 243 | }, 244 | "source": [ 245 | "import numpy as np\n", 246 | "\n", 247 | "resample_df = energy_df.set_index('date').resample('1H').mean()\n", 248 | "resample_df['date'] = resample_df.index\n", 249 | "resample_df['log_energy_consumption'] = np.log(resample_df['Appliances'])\n", 250 | "\n", 251 | "datetime_columns = ['date', 'day_of_week', 'hour_of_day']\n", 252 | "target_column = 'log_energy_consumption'\n", 253 | "\n", 254 | "feature_columns = datetime_columns + ['log_energy_consumption']\n", 255 | "\n", 256 | "# Keep only the date, the temporal features and the log-transformed target.\n", 257 | "# (The plot below shows only the first 150 hours for clarity.)\n", 258 | "resample_df = resample_df[feature_columns]" 259 | ], 260 | "execution_count": null, 261 | "outputs": [] 262 | }, 263 | { 264 | 
"cell_type": "code", 265 | "metadata": { 266 | "id": "tY1pfEIGDvvc", 267 | "colab": { 268 | "base_uri": 
"https://localhost:8080/", 269 | "height": 542 270 | }, 271 | "outputId": "13c8afe6-a03b-460d-cfa3-d6705e6dee42" 272 | }, 273 | "source": [ 274 | "import plotly.express as px\n", 275 | "\n", 276 | "plot_length = 150\n", 277 | "plot_df = resample_df.copy(deep=True).iloc[:plot_length]\n", 278 | "plot_df['weekday'] = plot_df['date'].dt.day_name()\n", 279 | "\n", 280 | "fig = px.line(plot_df,\n", 281 | " x=\"date\",\n", 282 | " y=\"log_energy_consumption\", \n", 283 | " color=\"weekday\", \n", 284 | " title=\"Log of Appliance Energy Consumption vs Time\")\n", 285 | "fig.show()" 286 | ], 287 | "execution_count": null, 288 | "outputs": [ 289 | { 290 | "output_type": "display_data", 291 | "data": { 292 | "text/html": [ 293 | "\n", 294 | "\n", 295 | "\n", 296 | "
\n", 340 | "\n", 341 | "" 342 | ] 343 | }, 344 | "metadata": { 345 | "tags": [] 346 | } 347 | } 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": { 353 | "id": "aFnQ3txg_2lj" 354 | }, 355 | "source": [ 356 | "# Prepare Training Data" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": { 362 | "id": "9gUa6m83NgDn" 363 | }, 364 | "source": [ 365 | "For this example, we use sliding windows of 10 points each (equivalent to 10 hours) to predict the next point. The window size can be changed via the `sequence_length` variable.\n", 366 | "\n", 367 | "Min-Max scaling is fitted on the training data only and then applied to both the training and test data to aid the convergence of the neural network." 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "metadata": { 373 | "id": "aKQHyh5rRVVj" 374 | }, 375 | "source": [ 376 | "from sklearn.preprocessing import MinMaxScaler\n", 377 | "\n", 378 | "def create_sliding_window(data, sequence_length, stride=1):\n", 379 | " X_list, y_list = [], []\n", 380 | " for i in range(len(data)):\n", 381 | " if (i + sequence_length) < len(data):\n", 382 | " X_list.append(data.iloc[i:i+sequence_length:stride, :].values)\n", 383 | " y_list.append(data.iloc[i+sequence_length, -1])  # target: next value of the last column\n", 384 | " return np.array(X_list), np.array(y_list)\n", 385 | "\n", 386 | "train_split = 0.7\n", 387 | "n_train = int(train_split * len(resample_df))\n", 388 | "n_test = len(resample_df) - n_train\n", 389 | "\n", 390 | "features = ['day_of_week', 'hour_of_day', 'log_energy_consumption']\n", 391 | "feature_array = resample_df[features].values\n", 392 | "\n", 393 | "# Fit Scaler only on Training features\n", 394 | "feature_scaler = MinMaxScaler()\n", 395 | "feature_scaler.fit(feature_array[:n_train])\n", 396 | "# Fit Scaler only on Training target values\n", 397 | "target_scaler = MinMaxScaler()\n", 398 | 
"target_scaler.fit(feature_array[:n_train, -1].reshape(-1, 1))\n", 399 | "\n", 400 | "# Transform both Training and Test data\n", 401 | 
"scaled_array = pd.DataFrame(feature_scaler.transform(feature_array),\n", 402 | " columns=features)\n", 403 | "\n", 404 | "sequence_length = 10\n", 405 | "X, y = create_sliding_window(scaled_array, \n", 406 | " sequence_length)\n", 407 | "\n", 408 | "X_train = X[:n_train]\n", 409 | "y_train = y[:n_train]\n", 410 | "\n", 411 | "X_test = X[n_train:]\n", 412 | "y_test = y[n_train:]" 413 | ], 414 | "execution_count": null, 415 | "outputs": [] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": { 420 | "id": "Ue1JcZm_bgsd" 421 | }, 422 | "source": [ 423 | "# Define Bayesian LSTM Architecture" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "metadata": { 429 | "id": "VNag-wa-04WZ" 430 | }, 431 | "source": [ 432 | "To demonstrate a simple working example of the Bayesian LSTM, a model with an architecture and size similar to those in Uber's paper has been used as a starting point. The network architecture is as follows:\n", 433 | "\n", 434 | "Encoder-Decoder Stage:\n", 435 | " - A uni-directional LSTM with 2 stacked layers & 128 hidden units acting as an encoding layer to construct a fixed-dimension embedding state\n", 436 | " - A uni-directional LSTM with 2 stacked layers & 32 hidden units acting as a decoding layer; its output at the last time step is passed to the predictor stage\n", 437 | " - Dropout is applied to the outputs of both LSTM stages at **both** training and inference time\n", 438 | "\n", 439 | "\n", 440 | " Predictor Stage:\n", 441 | " - 1 fully-connected output layer that produces a single value for the target variable\n", 442 | "\n", 443 | "\n", 444 | "By keeping dropout active at both training and test time, each forward pass uses a randomly thinned network, so repeated predictions for the same input vary. The spread of these predictions can be used to estimate the underlying distribution of the target value, giving explicit model uncertainties.\n" 445 | ] 446 | }, 447 | { 448 | "cell_type": 
"code", 449 | "metadata": { 450 | "id": "OgWyOffPbO0b" 451 | }, 452 | "source": [ 453 | "import 
torch\n", 454 | "import torch.nn as nn\n", 455 | "import torch.nn.functional as F\n", 456 | "from torch.autograd import Variable\n", 457 | "\n", 458 | "class BayesianLSTM(nn.Module):\n", 459 | "\n", 460 | " def __init__(self, n_features, output_length, batch_size):\n", 461 | "\n", 462 | " super(BayesianLSTM, self).__init__()\n", 463 | "\n", 464 | " self.batch_size = batch_size # user-defined\n", 465 | "\n", 466 | " self.hidden_size_1 = 128 # number of encoder cells (from paper)\n", 467 | " self.hidden_size_2 = 32 # number of decoder cells (from paper)\n", 468 | " self.stacked_layers = 2 # number of (stacked) LSTM layers for each stage\n", 469 | " self.dropout_probability = 0.5 # arbitrary value (the paper suggests that performance is generally stable across all ranges)\n", 470 | "\n", 471 | " self.lstm1 = nn.LSTM(n_features, \n", 472 | " self.hidden_size_1, \n", 473 | " num_layers=self.stacked_layers,\n", 474 | " batch_first=True)\n", 475 | " self.lstm2 = nn.LSTM(self.hidden_size_1,\n", 476 | " self.hidden_size_2,\n", 477 | " num_layers=self.stacked_layers,\n", 478 | " batch_first=True)\n", 479 | " \n", 480 | " self.fc = nn.Linear(self.hidden_size_2, output_length)\n", 481 | " self.loss_fn = nn.MSELoss()\n", 482 | " \n", 483 | " def forward(self, x):\n", 484 | " batch_size, seq_len, _ = x.size()\n", 485 | "\n", 486 | " hidden = self.init_hidden1(batch_size)\n", 487 | " output, _ = self.lstm1(x, hidden)\n", 488 | " output = F.dropout(output, p=self.dropout_probability, training=True)\n", 489 | " state = self.init_hidden2(batch_size)\n", 490 | " output, state = self.lstm2(output, state)\n", 491 | " output = F.dropout(output, p=self.dropout_probability, training=True)\n", 492 | " output = output[:, -1, :] # take the last decoder cell's outputs\n", 493 | " y_pred = self.fc(output)\n", 494 | " return y_pred\n", 495 | " \n", 496 | " def init_hidden1(self, batch_size):\n", 497 | " hidden_state = Variable(torch.zeros(self.stacked_layers, batch_size, 
self.hidden_size_1))\n", 498 | " cell_state = Variable(torch.zeros(self.stacked_layers, batch_size, self.hidden_size_1))\n", 499 | " return hidden_state, cell_state\n", 500 | " \n", 501 | " def init_hidden2(self, batch_size):\n", 502 | " hidden_state = Variable(torch.zeros(self.stacked_layers, batch_size, self.hidden_size_2))\n", 503 | " cell_state = Variable(torch.zeros(self.stacked_layers, batch_size, self.hidden_size_2))\n", 504 | " return hidden_state, cell_state\n", 505 | " \n", 506 | " def loss(self, pred, truth):\n", 507 | " return self.loss_fn(pred, truth)\n", 508 | "\n", 509 | " def predict(self, X):  # dropout stays active, so each call returns one stochastic sample\n", 510 | " return self(torch.tensor(X, dtype=torch.float32)).view(-1).detach().numpy()" 511 | ], 512 | "execution_count": null, 513 | "outputs": [] 514 | }, 515 | { 516 | "cell_type": "markdown", 517 | "metadata": { 518 | "id": "DQ8JLm-ShlaU" 519 | }, 520 | "source": [ 521 | "### Begin Training" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "metadata": { 527 | "id": "015pu48r3X1F" 528 | }, 529 | "source": [ 530 | "To train the Bayesian LSTM, we use the Adam optimizer with mini-batch gradient descent (`batch_size = 128`). For quick demonstration purposes, the model is trained for 150 epochs.\n", 531 | "\n", 532 | "The Bayesian LSTM is trained on the first 70% of data points, using the aforementioned sliding windows of size 10. The remaining 30% of the dataset is held out purely for testing." 
533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "metadata": { 538 | "id": "47X-uO6UhbCy" 539 | }, 540 | "source": [ 541 | "n_features = scaled_array.shape[-1]\n", 542 | "sequence_length = 10\n", 543 | "output_length = 1\n", 544 | "\n", 545 | "batch_size = 128\n", 546 | "n_epochs = 150\n", 547 | "learning_rate = 0.01\n", 548 | "\n", 549 | "bayesian_lstm = BayesianLSTM(n_features=n_features,\n", 550 | " output_length=output_length,\n", 551 | " batch_size = batch_size)\n", 552 | "\n", 553 | "criterion = torch.nn.MSELoss()\n", 554 | "optimizer = torch.optim.Adam(bayesian_lstm.parameters(), lr=learning_rate)" 555 | ], 556 | "execution_count": null, 557 | "outputs": [] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "metadata": { 562 | "id": "7iZ__nxaCzZE", 563 | "colab": { 564 | "base_uri": "https://localhost:8080/" 565 | }, 566 | "outputId": "915e737a-6235-4842-b7fe-e94648a86cfb" 567 | }, 568 | "source": [ 569 | "bayesian_lstm.train()\n", 570 | "\n", 571 | "for e in range(1, n_epochs+1):\n", 572 | " for b in range(0, len(X_train), batch_size):\n", 573 | " features = X_train[b:b+batch_size,:,:]\n", 574 | " target = y_train[b:b+batch_size] \n", 575 | "\n", 576 | " X_batch = torch.tensor(features,dtype=torch.float32) \n", 577 | " y_batch = torch.tensor(target,dtype=torch.float32)\n", 578 | "\n", 579 | " output = bayesian_lstm(X_batch)\n", 580 | " loss = criterion(output.view(-1), y_batch) \n", 581 | "\n", 582 | " loss.backward()\n", 583 | " optimizer.step() \n", 584 | " optimizer.zero_grad() \n", 585 | "\n", 586 | " if e % 10 == 0:\n", 587 | " print('epoch', e, 'loss: ', loss.item())" 588 | ], 589 | "execution_count": null, 590 | "outputs": [ 591 | { 592 | "output_type": "stream", 593 | "text": [ 594 | "epoch 10 loss: 0.021331658586859703\n", 595 | "epoch 20 loss: 0.01902003400027752\n", 596 | "epoch 30 loss: 0.01965978927910328\n", 597 | "epoch 40 loss: 0.017205342650413513\n", 598 | "epoch 50 loss: 0.01645931601524353\n", 599 | "epoch 60 loss: 
0.017352448776364326\n", 600 | "epoch 70 loss: 0.011554614640772343\n", 601 | "epoch 80 loss: 0.013196350075304508\n", 602 | "epoch 90 loss: 0.0104022566229105\n", 603 | "epoch 100 loss: 0.008596537634730339\n", 604 | "epoch 110 loss: 0.00793543178588152\n", 605 | "epoch 120 loss: 0.006831762846559286\n", 606 | "epoch 130 loss: 0.007397405803203583\n", 607 | "epoch 140 loss: 0.0039006543811410666\n", 608 | "epoch 150 loss: 0.004320652689784765\n" 609 | ], 610 | "name": "stdout" 611 | } 612 | ] 613 | }, 614 | { 615 | "cell_type": "markdown", 616 | "metadata": { 617 | "id": "vUS459C6ro22" 618 | }, 619 | "source": [ 620 | "# Evaluating Model Performance" 621 | ] 622 | }, 623 | { 624 | "cell_type": "markdown", 625 | "metadata": { 626 | "id": "iGYj2vTl311y" 627 | }, 628 | "source": [ 629 | "The implemented Bayesian LSTM produces reasonably accurate results on both the training and test sets. Its accuracy is expected to be comparable to that of frequentist machine learning and deep learning methods, although no such baseline is trained in this notebook.\n", 630 | "\n" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "metadata": { 636 | "id": "C-1VXE3_VH_0" 637 | }, 638 | "source": [ 639 | "offset = sequence_length\n", 640 | "\n", 641 | "def inverse_transform(y):  # undo the Min-Max scaling and return a 1-D array\n", 642 | " return target_scaler.inverse_transform(y.reshape(-1, 1)).ravel()\n", 643 | "\n", 644 | "training_df = pd.DataFrame()\n", 645 | "training_df['date'] = resample_df['date'].iloc[offset:n_train + offset:1] \n", 646 | "training_predictions = bayesian_lstm.predict(X_train)\n", 647 | "training_df['log_energy_consumption'] = inverse_transform(training_predictions)\n", 648 | "training_df['source'] = 'Training Prediction'\n", 649 | "\n", 650 | "training_truth_df = pd.DataFrame()\n", 651 | "training_truth_df['date'] = training_df['date']\n", 652 | "training_truth_df['log_energy_consumption'] = 
resample_df['log_energy_consumption'].iloc[offset:n_train + offset:1] \n", 653 | "training_truth_df['source'] = 'True Values'\n", 654 | "\n", 655 | "testing_df = 
pd.DataFrame()\n", 656 | "testing_df['date'] = resample_df['date'].iloc[n_train + offset::1] \n", 657 | "testing_predictions = bayesian_lstm.predict(X_test)\n", 658 | "testing_df['log_energy_consumption'] = inverse_transform(testing_predictions)\n", 659 | "testing_df['source'] = 'Test Prediction'\n", 660 | "\n", 661 | "testing_truth_df = pd.DataFrame()\n", 662 | "testing_truth_df['date'] = testing_df['date']\n", 663 | "testing_truth_df['log_energy_consumption'] = resample_df['log_energy_consumption'].iloc[n_train + offset::1] \n", 664 | "testing_truth_df['source'] = 'True Values'\n", 665 | "\n", 666 | "evaluation = pd.concat([training_df, \n", 667 | " testing_df,\n", 668 | " training_truth_df,\n", 669 | " testing_truth_df\n", 670 | " ], axis=0)" 671 | ], 672 | "execution_count": null, 673 | "outputs": [] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "metadata": { 678 | "id": "uXQJwRO2FZio", 679 | "colab": { 680 | "base_uri": "https://localhost:8080/", 681 | "height": 542 682 | }, 683 | "outputId": "0b953344-92e2-481f-e4fd-e10af64fd424" 684 | }, 685 | "source": [ 686 | "fig = px.line(evaluation.loc[evaluation['date'].between('2016-04-14', '2016-04-23')],\n", 687 | " x=\"date\",\n", 688 | " y=\"log_energy_consumption\",\n", 689 | " color=\"source\",\n", 690 | " title=\"Log of Appliance Energy Consumption in Wh vs Time\")\n", 691 | "fig.show()" 692 | ], 693 | "execution_count": null, 694 | "outputs": [ 695 | { 696 | "output_type": "display_data", 697 | "data": { 698 | "text/html": [ 699 | "\n", 700 | "\n", 701 | "\n", 702 | "
\n", 746 | "\n", 747 | "" 748 | ] 749 | }, 750 | "metadata": { 751 | "tags": [] 752 | } 753 | } 754 | ] 755 | }, 756 | { 757 | "cell_type": "markdown", 758 | "metadata": { 759 | "id": "Jrwiy646yq7t" 760 | }, 761 | "source": [ 762 | "# Uncertainty Quantification" 763 | ] 764 | }, 765 | { 766 | "cell_type": "markdown", 767 | "metadata": { 768 | "id": "8FEzMrU147zx" 769 | }, 770 | "source": [ 771 | "Because stochastic dropout is applied after each LSTM stage of the Bayesian LSTM, the model outputs can be interpreted as random samples from the (approximate) posterior predictive distribution of the target variable.\n", 772 | "\n", 773 | "By running multiple experiments (repeated predictions on the same inputs), we can therefore approximate the parameters of this distribution, namely the mean and the variance, and use them to create confidence intervals for each prediction.\n", 774 | "\n", 775 | "In this example, we construct what we refer to as 99% confidence intervals by taking three standard deviations on either side of the approximate mean of each prediction. Under a Gaussian assumption, ±3 standard deviations actually cover about 99.7% of the distribution; a 99% interval would use about ±2.576 standard deviations." 
776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "metadata": { 781 | "id": "Jb4xyVW6DVUV" 782 | }, 783 | "source": [ 784 | "n_experiments = 100\n", 785 | "\n", 786 | "test_uncertainty_df = pd.DataFrame()\n", 787 | "test_uncertainty_df['date'] = testing_df['date']\n", 788 | "\n", 789 | "for i in range(n_experiments):  # each pass samples new dropout masks\n", 790 | " experiment_predictions = bayesian_lstm.predict(X_test)\n", 791 | " test_uncertainty_df['log_energy_consumption_{}'.format(i)] = inverse_transform(experiment_predictions)\n", 792 | "\n", 793 | "log_energy_consumption_df = test_uncertainty_df.filter(like='log_energy_consumption', axis=1)\n", 794 | "test_uncertainty_df['log_energy_consumption_mean'] = log_energy_consumption_df.mean(axis=1)\n", 795 | "test_uncertainty_df['log_energy_consumption_std'] = log_energy_consumption_df.std(axis=1)\n", 796 | "\n", 797 | "test_uncertainty_df = test_uncertainty_df[['date', 'log_energy_consumption_mean', 'log_energy_consumption_std']].copy()" 798 | ], 799 | "execution_count": null, 800 | "outputs": [] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "SNrb70dSdDH0" 806 | }, 807 | "source": [ 808 | "test_uncertainty_df['lower_bound'] = test_uncertainty_df['log_energy_consumption_mean'] - 3*test_uncertainty_df['log_energy_consumption_std']\n", 809 | "test_uncertainty_df['upper_bound'] = test_uncertainty_df['log_energy_consumption_mean'] + 3*test_uncertainty_df['log_energy_consumption_std']" 810 | ], 811 | "execution_count": null, 812 | "outputs": [] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "metadata": { 817 | "id": "WdHylS8OdEHt", 818 | "colab": { 819 | "base_uri": "https://localhost:8080/", 820 | "height": 542 821 | }, 822 | "outputId": "a366f060-d302-40e9-d065-19cdefcbddfb" 823 | }, 824 | "source": [ 825 | "import plotly.graph_objects as go\n", 826 | "\n", 827 | "test_uncertainty_plot_df = test_uncertainty_df.copy(deep=True)\n", 828 | 
"test_uncertainty_plot_df = 
test_uncertainty_plot_df.loc[test_uncertainty_plot_df['date'].between('2016-05-01', '2016-05-09')]\n", 829 | "truth_uncertainty_plot_df = testing_truth_df.copy(deep=True)\n", 830 | "truth_uncertainty_plot_df = truth_uncertainty_plot_df.loc[testing_truth_df['date'].between('2016-05-01', '2016-05-09')]\n", 831 | "\n", 832 | "upper_trace = go.Scatter(\n", 833 | " x=test_uncertainty_plot_df['date'],\n", 834 | " y=test_uncertainty_plot_df['upper_bound'],\n", 835 | " mode='lines',\n", 836 | " fill=None,\n", 837 | " name='99% Upper Confidence Bound'\n", 838 | " )\n", 839 | "lower_trace = go.Scatter(\n", 840 | " x=test_uncertainty_plot_df['date'],\n", 841 | " y=test_uncertainty_plot_df['lower_bound'],\n", 842 | " mode='lines',\n", 843 | " fill='tonexty',\n", 844 | " fillcolor='rgba(255, 211, 0, 0.1)',\n", 845 | " name='99% Lower Confidence Bound'\n", 846 | " )\n", 847 | "real_trace = go.Scatter(\n", 848 | " x=truth_uncertainty_plot_df['date'],\n", 849 | " y=truth_uncertainty_plot_df['log_energy_consumption'],\n", 850 | " mode='lines',\n", 851 | " fill=None,\n", 852 | " name='Real Values'\n", 853 | " )\n", 854 | "\n", 855 | "data = [upper_trace, lower_trace, real_trace]\n", 856 | "\n", 857 | "fig = go.Figure(data=data)\n", 858 | "fig.update_layout(title='Uncertainty Quantification for Energy Consumption Test Data',\n", 859 | " xaxis_title='Time',\n", 860 | " yaxis_title='log_energy_consumption (log Wh)')\n", 861 | "\n", 862 | "fig.show()" 863 | ], 864 | "execution_count": null, 865 | "outputs": [ 866 | { 867 | "output_type": "display_data", 868 | "data": { 869 | "text/html": [ 870 | "\n", 871 | "\n", 872 | "\n", 873 | "
\n", 917 | "\n", 918 | "" 919 | ] 920 | }, 921 | "metadata": { 922 | "tags": [] 923 | } 924 | } 925 | ] 926 | }, 927 | { 928 | "cell_type": "markdown", 929 | "metadata": { 930 | "id": "7THEK4P96J0S" 931 | }, 932 | "source": [ 933 | "#### Evaluating Uncertainty" 934 | ] 935 | }, 936 | { 937 | "cell_type": "markdown", 938 | "metadata": { 939 | "id": "PPuR8L6D6PkL" 940 | }, 941 | "source": [ 942 | "Using the multiple experiments above, 99% confidence intervals have been constructed for each prediction of the target variable (the logarithm of appliance energy consumption). While we can visually observe that the model generally captures the behavior of the time series, only about two thirds (roughly 66%) of the real data points in the plotted window (2016-05-01 to 2016-05-09) lie within the 99% confidence interval around the mean prediction.\n", 943 | "\n", 944 | "Despite the relatively low percentage of points within the confidence interval, note that Bayesian Neural Networks, as used here, only seek to quantify epistemic (model) uncertainty and do not account for aleatoric uncertainty (i.e. noise)." 
945 | ] 946 | }, 947 | { 948 | "cell_type": "code", 949 | "metadata": { 950 | "id": "mV2_6qekxzLn", 951 | "colab": { 952 | "base_uri": "https://localhost:8080/" 953 | }, 954 | "outputId": "31c0cd39-7eb6-4b7d-9f75-f139678ab76f" 955 | }, 956 | "source": [ 957 | "bounds_df = pd.DataFrame()\n", 958 | "\n", 959 | "# Using 99% confidence bounds\n", 960 | "bounds_df['lower_bound'] = test_uncertainty_plot_df['lower_bound']\n", 961 | "bounds_df['prediction'] = test_uncertainty_plot_df['log_energy_consumption_mean']\n", 962 | "bounds_df['real_value'] = truth_uncertainty_plot_df['log_energy_consumption']\n", 963 | "bounds_df['upper_bound'] = test_uncertainty_plot_df['upper_bound']\n", 964 | "\n", 965 | "bounds_df['contained'] = ((bounds_df['real_value'] >= bounds_df['lower_bound']) &\n", 966 | " (bounds_df['real_value'] <= bounds_df['upper_bound']))\n", 967 | "\n", 968 | "print(\"Proportion of points contained within 99% confidence interval:\", \n", 969 | " bounds_df['contained'].mean())" 970 | ], 971 | "execution_count": null, 972 | "outputs": [ 973 | { 974 | "output_type": "stream", 975 | "text": [ 976 | "Proportion of points contained within 99% confidence interval: 0.6632124352331606\n" 977 | ], 978 | "name": "stdout" 979 | } 980 | ] 981 | }, 982 | { 983 | "cell_type": "markdown", 984 | "metadata": { 985 | "id": "CLrYy9qgxp_C" 986 | }, 987 | "source": [ 988 | "# Conclusions\n", 989 | "\n", 990 | "- Bayesian LSTMs have been able to produce comparable performance to their frequentist counterparts (all else being equal)\n", 991 | "- Stochastic dropout enables users to approximate the posterior distribution of the target variable, \\\n", 992 | "and thus construct confidence intervals for each prediction \n", 993 | "- Bayesian Neural Networks only attempt to account for epistemic model uncertainty and do not necessarily address aleatoric uncertainty\n", 994 | "- Computational overhead for repeated/multiple Bayesian LSTM predictions at inference to construct confidence 
intervals represents a potential challenge for real-time inference use cases." 995 | ] 996 | } 997 | ] 998 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian LSTM Implementation in PyTorch 2 | 3 | Inspired by: Deep and Confident Prediction for Time Series at Uber (2017) 4 | https://arxiv.org/pdf/1709.01907.pdf 5 | 6 | Bayesian Neural Networks are gaining interest due to their highly desirable properties 7 | of providing quantifiable uncertainties and confidence intervals, unlike equivalent frequentist methods. 8 | 9 | This repository demonstrates an implementation in PyTorch and summarizes several key 10 | features of Bayesian LSTM (Long Short-Term Memory) networks through a real-world example of forecasting building energy consumption. 11 | The Appliances energy prediction dataset used in this example is from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/Appliances+energy+prediction). 12 | 13 | 14 | The accompanying notebook is shared directly from Google Colab. 15 | As a result, interactive visualizations have not been transferred to GitHub. 16 | 17 | Please view the notebook in Google Colab by clicking the **Open in Colab** button 18 | or by clicking here: 19 | https://colab.research.google.com/github/PawaritL/BayesianLSTM/blob/master/Energy_Consumption_Predictions_with_Bayesian_LSTMs_in_PyTorch.ipynb 20 | --------------------------------------------------------------------------------
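The Monte Carlo dropout procedure used in the notebook can be sketched without PyTorch. Dropout stays active at inference, the same input is predicted repeatedly, and the mean ± 3 standard deviations of those predictions forms the interval. This is a minimal illustration under stated assumptions, not the repository's implementation. A hypothetical toy one-hidden-layer ReLU network with made-up weights stands in for `BayesianLSTM`, and the helper names `mc_dropout_interval`, `coverage` and `normal_coverage` are illustrative only. `normal_coverage` also shows why ±3 standard deviations corresponds to about 99.7%, rather than exactly 99%, under a Gaussian assumption.

```python
import math
import random
import statistics


def dropout(values, p, rng):
    """Inverted dropout: zero each unit with probability p and rescale survivors by 1/(1-p)."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def stochastic_forward(x, w_hidden, w_out, p, rng):
    """One forward pass of a hypothetical toy one-hidden-layer ReLU network with dropout kept ON."""
    hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in w_hidden]
    hidden = dropout(hidden, p, rng)
    return sum(w * h for w, h in zip(w_out, hidden))


def mc_dropout_interval(x, w_hidden, w_out, p=0.5, n_samples=100, k=3.0, seed=0):
    """Monte Carlo dropout: repeat stochastic passes and summarise them.

    Returns (mean, std, lower, upper), with the interval set at mean +/- k * std.
    """
    rng = random.Random(seed)
    samples = [stochastic_forward(x, w_hidden, w_out, p, rng) for _ in range(n_samples)]
    mean = statistics.fmean(samples)
    std = statistics.stdev(samples)  # sample standard deviation (n - 1), like pandas .std()
    return mean, std, mean - k * std, mean + k * std


def coverage(truths, lowers, uppers):
    """Fraction of true values lying inside their [lower, upper] interval."""
    inside = [lo <= t <= hi for t, lo, hi in zip(truths, lowers, uppers)]
    return sum(inside) / len(inside)


def normal_coverage(k):
    """Two-sided probability mass within +/- k standard deviations of a Gaussian."""
    return math.erf(k / math.sqrt(2.0))


if __name__ == "__main__":
    # Hypothetical weights for a network with 2 inputs and 4 hidden units.
    w_hidden = [[0.5, -0.2], [0.3, 0.8], [-0.6, 0.1], [0.9, 0.4]]
    w_out = [1.0, -0.5, 0.7, 0.2]
    mean, std, lower, upper = mc_dropout_interval([1.0, 2.0], w_hidden, w_out)
    print(f"mean={mean:.3f} std={std:.3f} interval=[{lower:.3f}, {upper:.3f}]")
    print(round(normal_coverage(3.0), 4))  # 0.9973
    print(round(normal_coverage(2.576), 3))  # 0.99
```

Fixing `seed` keeps the sketch reproducible. In the notebook, by contrast, each call to `bayesian_lstm.predict` draws fresh dropout masks from PyTorch's global random state, which is what makes the 100 experiments differ.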