├── .gitignore
├── .other
│   └── cover.png
├── LICENSE
├── README.md
├── chapter10
│   ├── Causal_CNN.ipynb
│   ├── RNN.ipynb
│   ├── Time_Series_with_Deep_Learning.ipynb
│   └── passengers.csv
├── chapter11
│   ├── Ranking_with_Bandits.ipynb
│   ├── Trading_with_DQN.ipynb
│   └── jesterfinal151cols.csv
├── chapter12
│   ├── Energy_Demand_Forecasting.ipynb
│   ├── gaussian_process.ipynb
│   └── test_models.ipynb
├── chapter2
│   ├── Air Pollution.ipynb
│   ├── EEG Signals.ipynb
│   ├── README.md
│   ├── monthly_csv.csv
│   └── spm.csv
├── chapter3
│   └── Preprocessing.ipynb
├── chapter5
│   └── Forecasting.ipynb
├── chapter6
│   └── Change-Points_Anomalies.ipynb
├── chapter7
│   ├── KNN_with_dynamic_DTW.ipynb
│   ├── Kats.ipynb
│   ├── Silverkite.ipynb
│   └── XGBoost.ipynb
├── chapter8
│   ├── Drift_Detection.ipynb
│   └── Online_Learning.ipynb
└── chapter9
    ├── Causal_Impact_Volkswagen.ipynb
    ├── Fuzzy_time_series_forecasting.ipynb
    ├── Markov_switching_model.ipynb
    └── Prophet_model.ipynb

/.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | .DS_Store -------------------------------------------------------------------------------- /.other/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PacktPublishing/Machine-Learning-for-Time-Series-with-Python/f76b589b83c260563bee1515e393260f5a68c96a/.other/cover.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Packt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | 
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |


2 | 3 | ## Machine Learning Summit 2025 4 | **Bridging Theory and Practice: ML Solutions for Today’s Challenges** 5 | 6 | 3 days, 20+ experts, and 25+ tech sessions and talks covering critical aspects of: 7 | - **Agentic and Generative AI** 8 | - **Applied Machine Learning in the Real World** 9 | - **ML Engineering and Optimization** 10 | 11 | 👉 [Book your ticket now >>](https://packt.link/mlsumgh) 12 | 13 | --- 14 | 15 | ## Join Our Newsletters 📬 16 | 17 | ### DataPro 18 | *The future of AI is unfolding. Don’t fall behind.* 19 | 20 |


21 | 22 | Stay ahead with [**DataPro**](https://landing.packtpub.com/subscribe-datapronewsletter/?link_from_packtlink=yes), the free weekly newsletter for data scientists, AI/ML researchers, and data engineers. 23 | From trending tools like **PyTorch**, **scikit-learn**, **XGBoost**, and **BentoML** to hands-on insights on **database optimization** and real-world **ML workflows**, you’ll get what matters, fast. 24 | 25 | > Stay sharp with [DataPro](https://landing.packtpub.com/subscribe-datapronewsletter/?link_from_packtlink=yes). Join **115K+ data professionals** who never miss a beat. 26 | 27 | --- 28 | 29 | ### BIPro 30 | *Business runs on data. Make sure yours tells the right story.* 31 | 32 |


33 | 34 | [**BIPro**](https://landing.packtpub.com/subscribe-bipro-newsletter/?link_from_packtlink=yes) is your free weekly newsletter for BI professionals, analysts, and data leaders. 35 | Get practical tips on **dashboarding**, **data visualization**, and **analytics strategy** with tools like **Power BI**, **Tableau**, **Looker**, **SQL**, and **dbt**. 36 | 37 | > Get smarter with [BIPro](https://landing.packtpub.com/subscribe-bipro-newsletter/?link_from_packtlink=yes). Trusted by **35K+ BI professionals**; see what you’re missing. 38 | 39 | 40 | 41 | 42 | # Machine-Learning-for-Time-Series-with-Python 43 | 44 | 45 | 46 | Become proficient in deriving insights from time-series data and analyzing a model’s performance 47 | 48 | ## Links 49 | 50 | * [Amazon](https://www.amazon.com/Machine-Learning-Time-Python-state/dp/1801819629/) 51 | * [Packt Publishing](https://www.packtpub.com/product/machine-learning-for-time-series-with-python/9781801819626) 52 | 53 | ## Key Features 54 | * Explore popular and modern machine learning methods including the latest online and deep learning algorithms 55 | * Learn to increase the accuracy of your predictions by matching the right model with the right problem 56 | * Master time-series via real-world case studies on operations management, digital marketing, finance, and healthcare 57 | 58 | ## What you will learn 59 | - Understand the main classes of time-series and learn how to detect outliers and patterns 60 | - Choose the right method to solve time-series problems 61 | - Characterize seasonal and correlation patterns through autocorrelation and statistical techniques 62 | - Get to grips with time-series data visualization 63 | - Understand classical time-series models like ARMA and ARIMA 64 | - Implement deep learning models such as transformers, as well as Gaussian processes and other state-of-the-art machine learning models 65 | - Become familiar with many libraries 
like Prophet, XGBoost, and TensorFlow 66 | 67 | ## Who This Book Is For 68 | This book is ideal for data analysts, data scientists, and Python developers who are looking to perform time-series analysis to effectively predict outcomes. Basic knowledge of the Python language is essential. Familiarity with statistics is desirable. 69 | 70 | ## Table of Contents 71 | 1. Introduction to Time-Series with Python 72 | 2. Time-Series Analysis with Python 73 | 3. Preprocessing Time-Series 74 | 4. Introduction to Machine Learning for Time-Series 75 | 5. Forecasting with Moving Averages and Autoregressive Models 76 | 6. Unsupervised Methods for Time-Series 77 | 7. Machine Learning Models for Time-Series 78 | 8. Online Learning for Time-Series 79 | 9. Probabilistic Models for Time-Series 80 | 10. Deep Learning for Time-Series 81 | 11. Reinforcement Learning for Time-Series 82 | 12. Multivariate Forecasting 83 | 84 | ## Author Notes 85 | 86 | I've heard from a few people who are struggling with tsfresh and featuretools in chapter 3. 87 | 88 | [My PR](https://github.com/blue-yonder/tsfresh/pull/912) fixing a version incompatibility in tsfresh was merged in mid-December. Featuretools went through many breaking changes with the release of version 1.0.0 (congratulations to the team!). Please see how to fix any problems in the [discussion here](https://github.com/PacktPublishing/Machine-Learning-for-Time-Series-with-Python/issues/2). 89 | ### Download a free PDF 90 | 91 | If you have already purchased a print or Kindle version of this book, you can get a DRM-free PDF version at no cost.
Simply click on the link to claim your free PDF.
92 |

https://packt.link/free-ebook/9781801819626
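
A note on the Causal CNN notebook that follows: `DC_CNN_Model` stacks seven `DC_CNN_Block`s with kernel size 2 and dilations 1, 2, 4, ..., 64. Only the dilated causal convolutions widen the receptive field; the 1x1 convolutions and the residual `Add` do not. For stacked causal convolutions with kernel size k, the receptive field is 1 + sum((k - 1) * d_i), which here gives 1 + 127 = 128 time steps. The sketch below checks that arithmetic with a hypothetical `receptive_field` helper, which is not part of the book's code. It also compares the result with the training window, assuming scikit-learn's default `test_size=0.25` for the unshuffled split of the 144 monthly values.

```python
import math


def receptive_field(kernel_size, dilations):
    """Hypothetical helper: receptive field (in time steps) of stacked causal Conv1D layers.

    Each layer with kernel size k and dilation d reaches (k - 1) * d steps further
    into the past than its input; 1x1 convolutions and residual adds reach no further.
    """
    return 1 + sum((kernel_size - 1) * d for d in dilations)


# Kernel size 2 and dilations 1, 2, 4, ..., 64, as used in DC_CNN_Model.
DILATIONS = [2 ** i for i in range(7)]
rf = receptive_field(2, DILATIONS)

# Assumption: train_test_split's default test_size=0.25 on the 144 monthly values.
n_samples = 144
n_test = math.ceil(0.25 * n_samples)  # 36 test months
n_train = n_samples - n_test          # 108 training months
window = n_train - 1                  # fit_model uses len(timeseries) - 1 steps

print(f"receptive field = {rf}, training window = {window}")
# prints: receptive field = 128, training window = 107
```

Because 128 >= 107, the final output step that `forecast` reads (`[:, -1:, :]`) can depend on every value in the input window.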

-------------------------------------------------------------------------------- /chapter10/Causal_CNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 6, 13 | "metadata": { 14 | "id": "bU3tHC0PArEI", 15 | "tags": [] 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import pandas as pd\n", 21 | "from keras.layers import Conv1D, Input, Add, Activation, Dropout\n", 22 | "from keras.models import Sequential, Model\n", 23 | "from keras.layers import LeakyReLU, ELU\n", 24 | "from keras import optimizers\n", 25 | "import tensorflow as tf\n", 26 | "\n", 27 | "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=10)\n", 28 | "\n", 29 | "\n", 30 | "def DC_CNN_Block(nb_filter, filter_length, dilation):\n", 31 | " def f(input_):\n", 32 | " residual = input_\n", 33 | " layer_out = Conv1D(\n", 34 | " filters=nb_filter, kernel_size=filter_length, \n", 35 | " dilation_rate=dilation, \n", 36 | " activation='linear', padding='causal', use_bias=False\n", 37 | " )(input_) \n", 38 | " layer_out = Activation('selu')(layer_out) \n", 39 | " skip_out = Conv1D(1, 1, activation='linear', use_bias=False)(layer_out) \n", 40 | " network_in = Conv1D(1, 1, activation='linear', use_bias=False)(layer_out) \n", 41 | " network_out = Add()([residual, network_in]) \n", 42 | " return network_out, skip_out \n", 43 | " return f\n", 44 | "\n", 45 | "\n", 46 | "def DC_CNN_Model(length):\n", 47 | " input = Input(shape=(length,1))\n", 48 | " l1a, l1b = DC_CNN_Block(32, 2, 1)(input) \n", 49 | " l2a, l2b = DC_CNN_Block(32, 2, 2)(l1a) \n", 50 | " l3a, l3b = DC_CNN_Block(32, 2, 4)(l2a)\n", 51 | " l4a, l4b = DC_CNN_Block(32, 2, 8)(l3a)\n", 52 | " l5a, l5b = DC_CNN_Block(32, 2, 16)(l4a)\n", 53 | " l6a, l6b = DC_CNN_Block(32, 2, 
32)(l5a)\n", 54 | " l6b = Dropout(0.8)(l6b)\n", 55 | " l7a, l7b = DC_CNN_Block(32, 2, 64)(l6a)\n", 56 | " l7b = Dropout(0.8)(l7b)\n", 57 | " l8 = Add()([l1b, l2b, l3b, l4b, l5b, l6b, l7b])\n", 58 | " l9 = Activation('relu')(l8) \n", 59 | " l21 = Conv1D(1, 1, activation='linear', use_bias=False)(l9)\n", 60 | " model = Model(inputs=input, outputs=l21)\n", 61 | " model.compile(loss='mae', optimizer=optimizers.Adam(), metrics=['mse'])\n", 62 | " return model" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 7, 68 | "metadata": { 69 | "tags": [] 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "def fit_model(timeseries):\n", 74 | " length = len(timeseries)-1\n", 75 | " model = DC_CNN_Model(length)\n", 76 | " model.summary()\n", 77 | " X = timeseries[:-1].reshape(1,length, 1)\n", 78 | " y = timeseries[1:].reshape(1,length, 1)\n", 79 | " model.fit(X, y, epochs=3000, callbacks=[callback])\n", 80 | " return model\n", 81 | " \n", 82 | "def forecast(model, timeseries, horizon: int):\n", 83 | " length = len(timeseries)-1\n", 84 | " pred_array = np.zeros(horizon).reshape(1, horizon, 1)\n", 85 | " X_test_initial = timeseries[1:].reshape(1,length,1)\n", 86 | " \n", 87 | " pred_array[: ,0, :] = model.predict(X_test_initial)[:, -1:, :]\n", 88 | " for i in range(horizon-1):\n", 89 | " pred_array[:, i+1:, :] = model.predict(\n", 90 | " np.append(\n", 91 | " X_test_initial[:, i+1:, :], \n", 92 | " pred_array[:, :i+1, :]\n", 93 | " ).reshape(1, length, 1))[:, -1:, :]\n", 94 | " return pred_array.flatten()\n", 95 | " \n", 96 | "def evaluate_timeseries(series, horizon: int):\n", 97 | " model = fit_model(series)\n", 98 | " pred_array = forecast(model, series, horizon)\n", 99 | " return pred_array, model" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "id": "LCzxf36xxvri", 107 | "tags": [] 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "import pandas as pd\n", 112 | "import 
matplotlib.pyplot as plt\n", 113 | "import seaborn as sns\n", 114 | "\n", 115 | "def show_result(y_test, predicted, ylabel=\"Passengers\"):\n", 116 | " plt.figure(figsize=(16, 6))\n", 117 | " plt.plot(y_test.index, predicted, 'o-', label=\"predicted\")\n", 118 | " plt.plot(y_test.index, y_test, '.-', label=\"actual\")\n", 119 | "\n", 120 | " plt.ylabel(ylabel)\n", 121 | " plt.legend()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 9, 127 | "metadata": { 128 | "id": "Tu-831LAylwK", 129 | "tags": [] 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "import pandas as pd\n", 134 | "\n", 135 | "values = [ \n", 136 | " 112., 118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126.,\n", 137 | " 141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178., 163.,\n", 138 | " 172., 178., 199., 199., 184., 162., 146., 166., 171., 180., 193., 181., 183., 218.,\n", 139 | " 230., 242., 209., 191., 172., 194., 196., 196., 236., 235., 229., 243., 264., 272.,\n", 140 | " 237., 211., 180., 201., 204., 188., 235., 227., 234., 264., 302., 293., 259., 229.,\n", 141 | " 203., 229., 242., 233., 267., 269., 270., 315., 364., 347., 312., 274., 237., 278.,\n", 142 | " 284., 277., 317., 313., 318., 374., 413., 405., 355., 306., 271., 306., 315., 301.,\n", 143 | " 356., 348., 355., 422., 465., 467., 404., 347., 305., 336., 340., 318., 362., 348.,\n", 144 | " 363., 435., 491., 505., 404., 359., 310., 337., 360., 342., 406., 396., 420., 472.,\n", 145 | " 548., 559., 463., 407., 362., 405., 417., 391., 419., 461., 472., 535., 622., 606.,\n", 146 | " 508., 461., 390., 432.,\n", 147 | " ]\n", 148 | "idx = pd.date_range(\"1949-01-01\", periods=len(values), freq=\"M\")\n", 149 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": { 156 | "colab": { 157 | "base_uri": "https://localhost:8080/" 158 | }, 159 | 
"id": "boJVi8Jiya4D", 160 | "outputId": "e4605e96-9198-4ab0-a277-489933daaa0b", 161 | "tags": [] 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "from sklearn.model_selection import train_test_split\n", 166 | "\n", 167 | "\n", 168 | "X_train, X_test, y_train, y_test = train_test_split(\n", 169 | " passengers.passengers, passengers.passengers.shift(-1), shuffle=False\n", 170 | ")\n", 171 | "HORIZON = len(y_test)\n", 172 | "predictions, model = evaluate_timeseries(X_train.values.reshape(-1, 1), horizon=HORIZON)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 14, 178 | "metadata": { 179 | "colab": { 180 | "base_uri": "https://localhost:8080/", 181 | "height": 379 182 | }, 183 | "id": "fxQ1vflLzcmt", 184 | "outputId": "1dfb65e4-fe34-46e0-879a-a8a095493e05", 185 | "tags": [] 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "show_result(y_test[:HORIZON], predictions[:HORIZON], \"Passengers\")" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 15, 195 | "metadata": { 196 | "id": "yH1hqiySztJ1", 197 | "tags": [] 198 | }, 199 | "outputs": [], 200 | "source": [] 201 | } 202 | ], 203 | "metadata": { 204 | "colab": { 205 | "collapsed_sections": [], 206 | "name": "Causal CNN", 207 | "provenance": [] 208 | }, 209 | "kernelspec": { 210 | "display_name": "Python 3", 211 | "language": "python", 212 | "name": "python3" 213 | }, 214 | "language_info": { 215 | "codemirror_mode": { 216 | "name": "ipython", 217 | "version": 3 218 | }, 219 | "file_extension": ".py", 220 | "mimetype": "text/x-python", 221 | "name": "python", 222 | "nbconvert_exporter": "python", 223 | "pygments_lexer": "ipython3", 224 | "version": "3.8.8" 225 | } 226 | }, 227 | "nbformat": 4, 228 | "nbformat_minor": 4 229 | } 230 | -------------------------------------------------------------------------------- /chapter10/RNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "id": "o47CT4PG6dsp" 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd\n", 19 | "\n", 20 | "values = [ \n", 21 | " 112., 118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126.,\n", 22 | " 141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178., 163.,\n", 23 | " 172., 178., 199., 199., 184., 162., 146., 166., 171., 180., 193., 181., 183., 218.,\n", 24 | " 230., 242., 209., 191., 172., 194., 196., 196., 236., 235., 229., 243., 264., 272.,\n", 25 | " 237., 211., 180., 201., 204., 188., 235., 227., 234., 264., 302., 293., 259., 229.,\n", 26 | " 203., 229., 242., 233., 267., 269., 270., 315., 364., 347., 312., 274., 237., 278.,\n", 27 | " 284., 277., 317., 313., 318., 374., 413., 405., 355., 306., 271., 306., 315., 301.,\n", 28 | " 356., 348., 355., 422., 465., 467., 404., 347., 305., 336., 340., 318., 362., 348.,\n", 29 | " 363., 435., 491., 505., 404., 359., 310., 337., 360., 342., 406., 396., 420., 472.,\n", 30 | " 548., 559., 463., 407., 362., 405., 417., 391., 419., 461., 472., 535., 622., 606.,\n", 31 | " 508., 461., 390., 432.,\n", 32 | " ]\n", 33 | "idx = pd.date_range(\"1949-01-01\", periods=len(values), freq=\"M\")\n", 34 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": { 41 | "colab": { 42 | "base_uri": "https://localhost:8080/", 43 | "height": 419 44 | }, 45 | "id": "BNOhWQC46fJU", 46 | "outputId": "1687dde6-4e92-45f5-f40c-e4668cd7ba32" 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "passengers" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "id": "1Fhde-Qg6pPj" 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "LOOKBACK = 10\n", 62 | "\n", 63 | "def 
wrap_data(df, lookback):\n", 64 | " dataset = []\n", 65 | " for index in range(lookback, len(df)+1):\n", 66 | " features = {\n", 67 | " f\"col_{i}\": float(val) for i, val in enumerate(\n", 68 | " df.iloc[index-lookback:index].values\n", 69 | " )\n", 70 | " }\n", 71 | " row = pd.DataFrame.from_dict([features])\n", 72 | " row.index = [df.index[index-1]]\n", 73 | " dataset.append(row)\n", 74 | " return pd.concat(dataset, axis=0)\n", 75 | "\n", 76 | "dataset = wrap_data(passengers, lookback=LOOKBACK)\n", 77 | "dataset = dataset.join(passengers.shift(-1))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": { 84 | "id": "xkD2twYx7lxK" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "import tensorflow.keras as keras\n", 89 | "from tensorflow.keras.layers import Input, Bidirectional, LSTM, Dense\n", 90 | "import tensorflow as tf\n", 91 | "\n", 92 | "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n", 93 | "\n", 94 | "def create_model(passengers):\n", 95 | " input_layer = Input(shape=(LOOKBACK, 1))\n", 96 | " recurrent = Bidirectional(LSTM(20, activation=\"tanh\"))(input_layer)\n", 97 | " output_layer = Dense(1)(recurrent)\n", 98 | " model = keras.models.Model(inputs=input_layer, outputs=output_layer)\n", 99 | " model.compile(\n", 100 | " loss='mse', optimizer=keras.optimizers.Adagrad(),\n", 101 | " metrics=[\n", 102 | " keras.metrics.RootMeanSquaredError(),\n", 103 | " keras.metrics.MeanAbsoluteError()\n", 104 | " ])\n", 105 | " return model\n", 106 | "\n", 107 | "model = create_model(passengers)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "id": "uv0e7MJNBIrl" 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "from sklearn.model_selection import train_test_split\n", 119 | "\n", 120 | "X_train, X_test, y_train, y_test = train_test_split(\n", 121 | " dataset.drop(columns=\"passengers\"),\n", 122 | " dataset[\"passengers\"],\n", 123 | " 
shuffle=False\n", 124 | ")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "colab": { 132 | "base_uri": "https://localhost:8080/" 133 | }, 134 | "id": "bN4gzfAhAK4q", 135 | "outputId": "4c812a58-5a93-421a-fd20-e47bea56d52b" 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "model.fit(X_train, y_train, epochs=1000, callbacks=[callback])" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": { 146 | "id": "o6lhB_7VApOJ" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "import pandas as pd\n", 151 | "import matplotlib.pyplot as plt\n", 152 | "import seaborn as sns\n", 153 | "\n", 154 | "def show_result(y_test, predicted):\n", 155 | " plt.figure(figsize=(16, 6))\n", 156 | " plt.plot(y_test.index, predicted, 'o-', label=\"predicted\")\n", 157 | " plt.plot(y_test.index, y_test, '.-', label=\"actual\")\n", 158 | "\n", 159 | " plt.ylabel(\"Passengers\")\n", 160 | " plt.legend()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "colab": { 168 | "base_uri": "https://localhost:8080/", 169 | "height": 379 170 | }, 171 | "id": "9TN2UDG_Bu0c", 172 | "outputId": "4c84996c-8906-4d53-b84b-2204811b017c" 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "predicted = model.predict(X_test)\n", 177 | "show_result(y_test, predicted)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "id": "V668DSiQBzV1" 185 | }, 186 | "outputs": [], 187 | "source": [] 188 | } 189 | ], 190 | "metadata": { 191 | "colab": { 192 | "name": "RNN", 193 | "provenance": [] 194 | }, 195 | "kernelspec": { 196 | "display_name": "Python 3", 197 | "language": "python", 198 | "name": "python3" 199 | }, 200 | "language_info": { 201 | "codemirror_mode": { 202 | "name": "ipython", 203 | "version": 3 204 | }, 205 | "file_extension": ".py", 206 | "mimetype": "text/x-python", 207 | 
"name": "python", 208 | "nbconvert_exporter": "python", 209 | "pygments_lexer": "ipython3", 210 | "version": "3.8.8" 211 | } 212 | }, 213 | "nbformat": 4, 214 | "nbformat_minor": 4 215 | } 216 | -------------------------------------------------------------------------------- /chapter10/Time_Series_with_Deep_Learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# from https://github.com/FinYang/tsdl/blob/56e091544cb81e573ee6db20c6f9cd39c70e6243/data-raw/boxjenk/seriesg.dat" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "id": "jEiBBpYjLCj3" 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "values = [ \n", 28 | " 112., 118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126.,\n", 29 | " 141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178., 163.,\n", 30 | " 172., 178., 199., 199., 184., 162., 146., 166., 171., 180., 193., 181., 183., 218.,\n", 31 | " 230., 242., 209., 191., 172., 194., 196., 196., 236., 235., 229., 243., 264., 272.,\n", 32 | " 237., 211., 180., 201., 204., 188., 235., 227., 234., 264., 302., 293., 259., 229.,\n", 33 | " 203., 229., 242., 233., 267., 269., 270., 315., 364., 347., 312., 274., 237., 278.,\n", 34 | " 284., 277., 317., 313., 318., 374., 413., 405., 355., 306., 271., 306., 315., 301.,\n", 35 | " 356., 348., 355., 422., 465., 467., 404., 347., 305., 336., 340., 318., 362., 348.,\n", 36 | " 363., 435., 491., 505., 404., 359., 310., 337., 360., 342., 406., 396., 420., 472.,\n", 37 | " 548., 559., 463., 407., 362., 405., 417., 391., 419., 461., 472., 535., 622., 606.,\n", 38 | " 508., 461., 390., 432.,\n", 39 | " ]" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 
44 | "execution_count": null, 45 | "metadata": { 46 | "id": "9bBfTTdMK9AE" 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import pandas as pd\n", 51 | "idx = pd.date_range(\"1949-01-01\", periods=len(values), freq=\"M\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "id": "z3sG4R4VKnWT" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "id": "PBlu9I2p8uib" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "from sklearn.model_selection import train_test_split\n", 74 | "\n", 75 | "X_train, X_test, y_train, y_test = train_test_split(passengers, passengers.passengers.shift(-1), shuffle=False)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "id": "DkejrcL8aFt1" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "import tensorflow.keras as keras\n", 87 | "import tensorflow as tf\n", 88 | "\n", 89 | "DROPOUT_RATIO = 0.2\n", 90 | "HIDDEN_NEURONS = 10\n", 91 | "\n", 92 | "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n", 93 | "\n", 94 | "def create_model(passengers):\n", 95 | " input_layer = keras.layers.Input(shape=(len(passengers.columns),))\n", 96 | "\n", 97 | " hidden_layer = keras.layers.Dropout(DROPOUT_RATIO)(input_layer)\n", 98 | " hidden_layer = keras.layers.Dense(HIDDEN_NEURONS, activation='relu')(hidden_layer)\n", 99 | "\n", 100 | " output_layer = keras.layers.Dropout(DROPOUT_RATIO)(hidden_layer)\n", 101 | " output_layer = keras.layers.Dense(1)(output_layer)\n", 102 | "\n", 103 | " model = keras.models.Model(inputs=input_layer, outputs=output_layer)\n", 104 | "\n", 105 | " model.compile(loss='mse', optimizer=keras.optimizers.Adagrad(),\n", 106 | " metrics=[keras.metrics.RootMeanSquaredError(), keras.metrics.MeanAbsoluteError()])\n", 107 | " return model\n", 108 |
"\n", 109 | "model = create_model(passengers)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "colab": { 117 | "base_uri": "https://localhost:8080/" 118 | }, 119 | "id": "qJCCbRTib8Up", 120 | "outputId": "63136aae-e096-4ff3-f4fb-2d728054d099" 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "model.fit(X_train, y_train, epochs=1000, callbacks=[callback])" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": { 131 | "id": "BgHreXsUdc6a" 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "predicted = model.predict(X_test)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": { 142 | "id": "mJC6mB74dkh9" 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "import matplotlib.pyplot as plt\n", 147 | "\n", 148 | "def show_result(y_test, predicted):\n", 149 | " plt.figure(figsize=(16, 6))\n", 150 | " plt.plot(y_test.index, predicted, 'o-', label=\"predicted\")\n", 151 | " plt.plot(y_test.index, y_test, '.-', label=\"actual\")\n", 152 | "\n", 153 | " plt.ylabel(\"Passengers\")\n", 154 | " plt.legend()" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "colab": { 162 | "base_uri": "https://localhost:8080/", 163 | "height": 396 164 | }, 165 | "id": "4spy29-UdzRS", 166 | "outputId": "86904074-d6ec-48a2-ab70-d5085ff21535" 167 | }, 168 | "outputs": [], 169 | "source": [ 170 | "show_result(y_test, predicted)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "metadata": { 177 | "colab": { 178 | "base_uri": "https://localhost:8080/" 179 | }, 180 | "id": "Mh4E54N_egps", 181 | "outputId": "3930edda-49da-4ba5-ba5b-2d9c09241a94" 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "passengers[\"month\"] = passengers.index.month.values\n", 186 | "passengers[\"year\"] = passengers.index.year.values\n", 187 | "\n", 188 | 
"model = create_model(passengers)\n", 189 | "X_train, X_test, y_train, y_test = train_test_split(passengers, passengers.passengers.shift(-1), shuffle=False)\n", 190 | "model.fit(X_train, y_train, epochs=100, callbacks=[callback])\n", 191 | "predicted = model.predict(X_test)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": { 198 | "colab": { 199 | "base_uri": "https://localhost:8080/", 200 | "height": 379 201 | }, 202 | "id": "Y9_1dwAofTM7", 203 | "outputId": "6f8515de-b3d6-4aee-a65b-cde8dc442631" 204 | }, 205 | "outputs": [], 206 | "source": [ 207 | "show_result(y_test, predicted)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "id": "3eSnjPMYrcbN" 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "from tensorflow.keras.layers.experimental import preprocessing\n", 219 | "import tensorflow as tf\n", 220 | "\n", 221 | "\n", 222 | "DROPOUT_RATIO = 0.1\n", 223 | "HIDDEN_NEURONS = 5\n", 224 | "\n", 225 | "\n", 226 | "def create_model(passengers):\n", 227 | " scale = tf.constant(passengers.passengers.std())\n", 228 | "\n", 229 | " continuous_input_layer = keras.layers.Input(shape=1)\n", 230 | "\n", 231 | " categorical_input_layer = keras.layers.Input(shape=1)\n", 232 | " embedded = keras.layers.Embedding(12, 5)(categorical_input_layer)\n", 233 | " embedded_flattened = keras.layers.Flatten()(embedded)\n", 234 | "\n", 235 | " year_input = keras.layers.Input(shape=1)\n", 236 | " year_layer = keras.layers.Dense(1)(year_input)\n", 237 | "\n", 238 | " hidden_output = keras.layers.Concatenate(-1)([embedded_flattened, year_layer, continuous_input_layer])\n", 239 | " output_layer = keras.layers.Dense(1)(hidden_output)\n", 240 | " output = output_layer * scale + continuous_input_layer\n", 241 | "\n", 242 | " model = keras.models.Model(inputs=[\n", 243 | " continuous_input_layer, categorical_input_layer, year_input\n", 244 | " ], outputs=output)\n", 245 | "\n", 246 
| " model.compile(loss='mse', optimizer=keras.optimizers.Adam(),\n", 247 | " metrics=[keras.metrics.RootMeanSquaredError(), keras.metrics.MeanAbsoluteError()])\n", 248 | " return model" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": { 255 | "colab": { 256 | "base_uri": "https://localhost:8080/" 257 | }, 258 | "id": "zSyGPARMt9zx", 259 | "outputId": "df8b02fb-fc3a-49da-ac2f-d6e951eca334" 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()\n", 264 | "passengers[\"year\"] = passengers.index.year.values - passengers.index.year.values.min()\n", 265 | "passengers[\"month\"] = passengers.index.month.values - 1\n", 266 | "\n", 267 | "X_train, X_test, y_train, y_test = train_test_split(passengers, passengers.passengers.shift(-1), shuffle=False)\n", 268 | "model = create_model(X_train)\n", 269 | "model.fit(\n", 270 | " (X_train[\"passengers\"], X_train[\"month\"], X_train[\"year\"]),\n", 271 | " y_train, epochs=1000,\n", 272 | " callbacks=[callback]\n", 273 | ")\n", 274 | "predicted = model.predict((X_test[\"passengers\"], X_test[\"month\"], X_test[\"year\"]))" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "colab": { 282 | "base_uri": "https://localhost:8080/", 283 | "height": 379 284 | }, 285 | "id": "D2fH073LAYYc", 286 | "outputId": "c3bc19a2-7bf6-47dc-f753-75fdfb215aa2" 287 | }, 288 | "outputs": [], 289 | "source": [ 290 | "show_result(y_test, predicted)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": { 297 | "id": "SMfFWDb8Vs6x" 298 | }, 299 | "outputs": [], 300 | "source": [] 301 | } 302 | ], 303 | "metadata": { 304 | "colab": { 305 | "collapsed_sections": [], 306 | "name": "Time-Series with Deep Learning", 307 | "provenance": [] 308 | }, 309 | "kernelspec": { 310 | "display_name": "Python 3", 311 | "language": "python", 312 |
"name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.8.8" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 4 329 | } 330 | -------------------------------------------------------------------------------- /chapter10/passengers.csv: -------------------------------------------------------------------------------- 1 | date,passengers 2 | 1949-01-31,112.0 3 | 1949-02-28,118.0 4 | 1949-03-31,132.0 5 | 1949-04-30,129.0 6 | 1949-05-31,121.0 7 | 1949-06-30,135.0 8 | 1949-07-31,148.0 9 | 1949-08-31,148.0 10 | 1949-09-30,136.0 11 | 1949-10-31,119.0 12 | 1949-11-30,104.0 13 | 1949-12-31,118.0 14 | 1950-01-31,115.0 15 | 1950-02-28,126.0 16 | 1950-03-31,141.0 17 | 1950-04-30,135.0 18 | 1950-05-31,125.0 19 | 1950-06-30,149.0 20 | 1950-07-31,170.0 21 | 1950-08-31,170.0 22 | 1950-09-30,158.0 23 | 1950-10-31,133.0 24 | 1950-11-30,114.0 25 | 1950-12-31,140.0 26 | 1951-01-31,145.0 27 | 1951-02-28,150.0 28 | 1951-03-31,178.0 29 | 1951-04-30,163.0 30 | 1951-05-31,172.0 31 | 1951-06-30,178.0 32 | 1951-07-31,199.0 33 | 1951-08-31,199.0 34 | 1951-09-30,184.0 35 | 1951-10-31,162.0 36 | 1951-11-30,146.0 37 | 1951-12-31,166.0 38 | 1952-01-31,171.0 39 | 1952-02-29,180.0 40 | 1952-03-31,193.0 41 | 1952-04-30,181.0 42 | 1952-05-31,183.0 43 | 1952-06-30,218.0 44 | 1952-07-31,230.0 45 | 1952-08-31,242.0 46 | 1952-09-30,209.0 47 | 1952-10-31,191.0 48 | 1952-11-30,172.0 49 | 1952-12-31,194.0 50 | 1953-01-31,196.0 51 | 1953-02-28,196.0 52 | 1953-03-31,236.0 53 | 1953-04-30,235.0 54 | 1953-05-31,229.0 55 | 1953-06-30,243.0 56 | 1953-07-31,264.0 57 | 1953-08-31,272.0 58 | 1953-09-30,237.0 59 | 1953-10-31,211.0 60 | 1953-11-30,180.0 61 | 1953-12-31,201.0 62 | 1954-01-31,204.0 63 | 1954-02-28,188.0 64 | 1954-03-31,235.0 65 | 
1954-04-30,227.0 66 | 1954-05-31,234.0 67 | 1954-06-30,264.0 68 | 1954-07-31,302.0 69 | 1954-08-31,293.0 70 | 1954-09-30,259.0 71 | 1954-10-31,229.0 72 | 1954-11-30,203.0 73 | 1954-12-31,229.0 74 | 1955-01-31,242.0 75 | 1955-02-28,233.0 76 | 1955-03-31,267.0 77 | 1955-04-30,269.0 78 | 1955-05-31,270.0 79 | 1955-06-30,315.0 80 | 1955-07-31,364.0 81 | 1955-08-31,347.0 82 | 1955-09-30,312.0 83 | 1955-10-31,274.0 84 | 1955-11-30,237.0 85 | 1955-12-31,278.0 86 | 1956-01-31,284.0 87 | 1956-02-29,277.0 88 | 1956-03-31,317.0 89 | 1956-04-30,313.0 90 | 1956-05-31,318.0 91 | 1956-06-30,374.0 92 | 1956-07-31,413.0 93 | 1956-08-31,405.0 94 | 1956-09-30,355.0 95 | 1956-10-31,306.0 96 | 1956-11-30,271.0 97 | 1956-12-31,306.0 98 | 1957-01-31,315.0 99 | 1957-02-28,301.0 100 | 1957-03-31,356.0 101 | 1957-04-30,348.0 102 | 1957-05-31,355.0 103 | 1957-06-30,422.0 104 | 1957-07-31,465.0 105 | 1957-08-31,467.0 106 | 1957-09-30,404.0 107 | 1957-10-31,347.0 108 | 1957-11-30,305.0 109 | 1957-12-31,336.0 110 | 1958-01-31,340.0 111 | 1958-02-28,318.0 112 | 1958-03-31,362.0 113 | 1958-04-30,348.0 114 | 1958-05-31,363.0 115 | 1958-06-30,435.0 116 | 1958-07-31,491.0 117 | 1958-08-31,505.0 118 | 1958-09-30,404.0 119 | 1958-10-31,359.0 120 | 1958-11-30,310.0 121 | 1958-12-31,337.0 122 | 1959-01-31,360.0 123 | 1959-02-28,342.0 124 | 1959-03-31,406.0 125 | 1959-04-30,396.0 126 | 1959-05-31,420.0 127 | 1959-06-30,472.0 128 | 1959-07-31,548.0 129 | 1959-08-31,559.0 130 | 1959-09-30,463.0 131 | 1959-10-31,407.0 132 | 1959-11-30,362.0 133 | 1959-12-31,405.0 134 | 1960-01-31,417.0 135 | 1960-02-29,391.0 136 | 1960-03-31,419.0 137 | 1960-04-30,461.0 138 | 1960-05-31,472.0 139 | 1960-06-30,535.0 140 | 1960-07-31,622.0 141 | 1960-08-31,606.0 142 | 1960-09-30,508.0 143 | 1960-10-31,461.0 144 | 1960-11-30,390.0 145 | 1960-12-31,432.0 146 | -------------------------------------------------------------------------------- /chapter11/Ranking_with_Bandits.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "708ef2ed-176f-42e3-a47a-bcc324221daa", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "55093d0c-238e-4ac3-ad08-2fe7ff52ca13", 14 | "metadata": {}, 15 | "source": [ 16 | "After https://github.com/Kenza-AI/mab-ranking/blob/master/examples/jester/example.ipynb" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 3, 22 | "id": "c5fcc1c0-2e0c-48b3-9f10-5263e1f7ae33", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "URL = 'https://raw.githubusercontent.com/PacktPublishing/Machine-Learning-for-Time-Series-with-Python/main/chapter11/jesterfinal151cols.csv'" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 4, 32 | "id": "533e06c6-62dd-48f7-be50-3356666d364b", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import pandas as pd\n", 37 | "jester_data = pd.read_csv(URL, header=None)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "a9a7e5a7-bdfa-4a11-ad83-f319735b44cc", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "# jester_data.columns = [f\"joke_{col}\" for col in jester_data.columns]\n", 48 | "jester_data.index.name = \"users\"" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "7490c980-6f0f-453f-8b02-8ca5f101aecb", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "jester_data.head()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "5990c5cc-66ce-482c-b7f4-1f4911470d6e", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "for col in jester_data.columns:\n", 69 | " jester_data[col] = jester_data[col].apply(lambda x: 0.0 if x>=99 or x<7.0 else 1.0)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | 
"id": "f4b220de-0657-4e4f-a01f-33c879a20aa5", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "jester_data" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "578d30fa-7e4f-41a9-bbdd-86c1f3898fbe", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "# keep users who rated at least one joke 7.0 or higher\n", 90 | "jester_data = jester_data[jester_data.sum(axis=1) > 0]" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "9432df2f-4500-46c7-a47c-1a1f0e592b38", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "jester_data" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "85cf3f6c-c823-4883-b250-8bf7a7c2cf3e", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "pip install git+https://github.com/benman1/mab-ranking" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "f5302c81-3478-4ac3-b1f9-96f721f75515", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# setting up the bandits:\n", 121 | "from mab_ranking.bandits.rank_bandits import IndependentBandits\n", 122 | "from mab_ranking.bandits.bandits import BetaThompsonSampling, DirichletThompsonSampling\n", 123 | "\n", 124 | "independent_bandits = IndependentBandits(\n", 125 | " num_arms=jester_data.shape[1],\n", 126 | " num_ranks=10,\n", 127 | " bandit_class=DirichletThompsonSampling\n", 128 | ")" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "329ba955-7490-402a-a007-428121d130f4", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "from tqdm import trange\n", 139 | "\n", 140 | "num_steps = 7000\n", 141 | "hit_rates = []\n", 142 | "for _ in trange(1, num_steps + 1):\n", 143 | " selected_items = set(independent_bandits.choose())\n", 144 | " # Pick a random user's choices\n", 145 | " random_user = 
jester_data.sample().iloc[0, :]\n", 146 | " ground_truth = set(random_user[random_user == 1].index)\n", 147 | " hit_rate = len(ground_truth.intersection(selected_items)) / len(ground_truth)\n", 148 | " feedback_list = [1.0 if item in ground_truth else 0.0 for item in selected_items]\n", 149 | " independent_bandits.update(selected_items, feedback_list)\n", 150 | " hit_rates.append(hit_rate)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "6c1acf3d-b2ca-4f00-9b22-3d703c22713a", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "import matplotlib.pyplot as plt\n", 161 | "\n", 162 | "stats = pd.Series(hit_rates)\n", 163 | "plt.figure(figsize=(12, 6))\n", 164 | "plt.plot(stats.index, stats.rolling(200).mean(), \"--\")\n", 165 | "plt.xlabel('Iteration')\n", 166 | "plt.ylabel('Hit rate')" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "d1535111-fa30-45b1-b723-5bfaa24eba8e", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "from sklearn.cluster import KMeans\n", 177 | "from sklearn.preprocessing import StandardScaler\n", 178 | "\n", 179 | "scaler = StandardScaler().fit(jester_data)\n", 180 | "kmeans = KMeans(n_clusters=5, random_state=0).fit(scaler.transform(jester_data))\n", 181 | "contexts = pd.Series(kmeans.labels_, index=jester_data.index)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "ac755bbe-306e-4086-a24f-a8ef55792287", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "contexts.value_counts()" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "id": "128d0fe1-5edb-47e1-aeab-62531a620a56", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "independent_bandits = IndependentBandits(\n", 202 | " num_arms=jester_data.shape[1],\n", 203 | " num_ranks=10,\n", 204 | " bandit_class=DirichletThompsonSampling\n", 205 | ")\n", 
206 | "\n", 207 | "num_steps = 7000\n", 208 | "hit_rates = []\n", 209 | "for _ in trange(1, num_steps + 1):\n", 210 | " # Pick a random user's choices\n", 211 | " random_user = jester_data.sample().iloc[0, :]\n", 212 | " context = {\"previous_action\": contexts.loc[random_user.name]}\n", 213 | " selected_items = set(independent_bandits.choose(\n", 214 | " context=context\n", 215 | " ))\n", 216 | " ground_truth = set(random_user[random_user == 1].index)\n", 217 | " hit_rate = len(ground_truth.intersection(selected_items)) / len(ground_truth)\n", 218 | " feedback_list = [1.0 if item in ground_truth else 0.0 for item in selected_items]\n", 219 | " independent_bandits.update(selected_items, feedback_list, context=context)\n", 220 | " hit_rates.append(hit_rate)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "daefac60-eba6-4429-8a33-067c26483402", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "import matplotlib.pyplot as plt\n", 231 | "\n", 232 | "stats = pd.Series(hit_rates)\n", 233 | "plt.figure(figsize=(12, 6))\n", 234 | "plt.plot(stats.index, stats.rolling(200).mean(), \"--\")\n", 235 | "plt.xlabel('Iteration')\n", 236 | "plt.ylabel('Hit rate')" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "016e11fb-b2e7-4705-8aae-251468762bc4", 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.8.8" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 5 269 | } 270 | 
-------------------------------------------------------------------------------- /chapter11/Trading_with_DQN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "id": "P6xXgYonWE8n" 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "# based on https://github.com/tensortrade-org/tensortrade/blob/master/examples/train_and_evaluate.ipynb" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "colab": { 26 | "base_uri": "https://localhost:8080/" 27 | }, 28 | "id": "hp1Fk0dGJXpT", 29 | "outputId": "c7215ab3-26d0-45a9-8c47-6973f527ace6" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "pip install git+https://github.com/tensortrade-org/tensortrade.git" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": { 40 | "id": "UkrQ9KU3JXpU" 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "# all imports:\n", 45 | "import pandas as pd\n", 46 | "import tensortrade.env.default as default\n", 47 | "\n", 48 | "from tensortrade.data.cdd import CryptoDataDownload\n", 49 | "from tensortrade.feed.core import Stream, DataFeed\n", 50 | "from tensortrade.oms.exchanges import Exchange\n", 51 | "from tensortrade.oms.services.execution.simulated import execute_order\n", 52 | "from tensortrade.oms.instruments import USD, BTC, ETH\n", 53 | "from tensortrade.oms.wallets import Wallet, Portfolio\n", 54 | "from tensortrade.agents import DQNAgent\n", 55 | "\n", 56 | "\n", 57 | "%matplotlib inline" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "id": "QpPq3bKwJXpU" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "cdd = CryptoDataDownload()\n", 69 | "\n", 70 | "data = cdd.fetch(\"Bitstamp\", \"USD\", \"BTC\", \"1h\")" 
71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": { 77 | "colab": { 78 | "base_uri": "https://localhost:8080/", 79 | "height": 204 80 | }, 81 | "id": "yepqF5x7JXpV", 82 | "outputId": "166d22d5-e050-4f44-9b27-96d5c4f7b6e1" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "data.head()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "id": "8PVhnDBvJXpW" 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "# defining two technical indicators (RSI and MACD):\n", 98 | "def rsi(price: Stream[float], period: float) -> Stream[float]:\n", 99 | " r = price.diff()\n", 100 | " upside = r.clamp_min(0).abs()\n", 101 | " downside = r.clamp_max(0).abs()\n", 102 | " rs = upside.ewm(alpha=1 / period).mean() / downside.ewm(alpha=1 / period).mean()\n", 103 | " return 100*(1 - (1 + rs) ** -1)\n", 104 | "\n", 105 | "\n", 106 | "def macd(price: Stream[float], fast: float, slow: float, signal: float) -> Stream[float]:\n", 107 | " fm = price.ewm(span=fast, adjust=False).mean()\n", 108 | " sm = price.ewm(span=slow, adjust=False).mean()\n", 109 | " md = fm - sm\n", 110 | " signal = md - md.ewm(span=signal, adjust=False).mean()\n", 111 | " return signal" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# choosing the closing price:\n", 121 | "features = []\n", 122 | "for c in data.columns[1:]:\n", 123 | " s = Stream.source(list(data[c]), dtype=\"float\").rename(data[c].name)\n", 124 | " features += [s]\n", 125 | "\n", 126 | "cp = Stream.select(features, lambda s: s.name == \"close\")" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "id": "hYjM8JoJJXpW" 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "# adding the three features (log returns, RSI, MACD):\n", 138 | "features = [\n", 139 | 
cp.log().diff().rename(\"lr\"),\n", 140 | " rsi(cp, period=20).rename(\"rsi\"),\n", 141 | " macd(cp, fast=10, slow=50, signal=5).rename(\"macd\")\n", 142 | "]\n", 143 | "\n", 144 | "feed = DataFeed(features)\n", 145 | "feed.compile()" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": { 152 | "colab": { 153 | "base_uri": "https://localhost:8080/" 154 | }, 155 | "id": "uqXBeRwrJXpX", 156 | "outputId": "69b29f71-03e6-4c70-a6e3-3d8aa3b23717" 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "for i in range(5):\n", 161 | " print(feed.next())" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "# setting up broker and the portfolio:\n", 171 | "bitstamp = Exchange(\"bitstamp\", service=execute_order)(\n", 172 | " Stream.source(list(data[\"close\"]), dtype=\"float\").rename(\"USD-BTC\")\n", 173 | ")\n", 174 | "\n", 175 | "portfolio = Portfolio(USD, [\n", 176 | " Wallet(bitstamp, 10000 * USD),\n", 177 | " Wallet(bitstamp, 10 * BTC)\n", 178 | "])" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "# renderer:\n", 188 | "renderer_feed = DataFeed([\n", 189 | " Stream.source(list(data[\"date\"])).rename(\"date\"),\n", 190 | " Stream.source(list(data[\"open\"]), dtype=\"float\").rename(\"open\"),\n", 191 | " Stream.source(list(data[\"high\"]), dtype=\"float\").rename(\"high\"),\n", 192 | " Stream.source(list(data[\"low\"]), dtype=\"float\").rename(\"low\"),\n", 193 | " Stream.source(list(data[\"close\"]), dtype=\"float\").rename(\"close\"), \n", 194 | " Stream.source(list(data[\"volume\"]), dtype=\"float\").rename(\"volume\") \n", 195 | "])" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": { 202 | "id": "IzWYGEyZJXpY" 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "# the 
trading environment:\n", 207 | "env = default.create(\n", 208 | " portfolio=portfolio,\n", 209 | " action_scheme=\"managed-risk\",\n", 210 | " reward_scheme=\"risk-adjusted\",\n", 211 | " feed=feed,\n", 212 | " renderer_feed=renderer_feed,\n", 213 | " renderer=default.renderers.PlotlyTradingChart(),\n", 214 | " window_size=20\n", 215 | ")" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": { 222 | "colab": { 223 | "base_uri": "https://localhost:8080/" 224 | }, 225 | "id": "hYUd95vjJXpY", 226 | "outputId": "45d3f699-abe3-49d4-fd87-d7bba3e953f9" 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "env.observer.feed.next()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": { 237 | "colab": { 238 | "base_uri": "https://localhost:8080/", 239 | "height": 1000, 240 | "referenced_widgets": [ 241 | "9bcf0bbf1c1c49859a2dfa40acce1038" 242 | ] 243 | }, 244 | "id": "YoxUET9ZJXpZ", 245 | "outputId": "a7672ac6-9c9c-42ff-b9d3-8a128fd882ae" 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "# training a DQN trading agent\n", 250 | "agent = DQNAgent(env)\n", 251 | "\n", 252 | "agent.train(n_steps=200, n_episodes=2, save_path=\"agents/\")" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": { 259 | "colab": { 260 | "base_uri": "https://localhost:8080/", 261 | "height": 282 262 | }, 263 | "id": "dtoYz5bOKGdo", 264 | "outputId": "fc4372f2-1d3f-4072-e452-2d91d41a2abe" 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "%matplotlib inline\n", 269 | "\n", 270 | "performance = pd.DataFrame.from_dict(env.action_scheme.portfolio.performance, orient='index')\n", 271 | "performance.plot()" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": { 278 | "colab": { 279 | "base_uri": "https://localhost:8080/", 280 | "height": 282 281 | }, 282 | "id": "URT6-E9wVtIt", 283 | "outputId": 
"0b293bc2-46a6-4031-ea0c-9c8415f60d77" 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "performance[\"net_worth\"].plot()" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": { 294 | "id": "GymCxiqfV0VM" 295 | }, 296 | "outputs": [], 297 | "source": [] 298 | } 299 | ], 300 | "metadata": { 301 | "colab": { 302 | "collapsed_sections": [], 303 | "name": "train_and_evaluate.ipynb", 304 | "provenance": [] 305 | }, 306 | "kernelspec": { 307 | "display_name": "Python 3", 308 | "language": "python", 309 | "name": "python3" 310 | }, 311 | "language_info": { 312 | "codemirror_mode": { 313 | "name": "ipython", 314 | "version": 3 315 | }, 316 | "file_extension": ".py", 317 | "mimetype": "text/x-python", 318 | "name": "python", 319 | "nbconvert_exporter": "python", 320 | "pygments_lexer": "ipython3", 321 | "version": "3.8.8" 322 | } 323 | }, 324 | "nbformat": 4, 325 | "nbformat_minor": 4 326 | } 327 | -------------------------------------------------------------------------------- /chapter2/README.md: -------------------------------------------------------------------------------- 1 | # Chapter 2 - Time-Series Analysis with Python 2 | 3 | The CSV file `spm.csv` was downloaded from the [Our World in Data GitHub repository](https://github.com/owid/owid-datasets/blob/master/datasets) (Air pollution by city - Fouquet and DPCC). 4 | 5 | The global temperatures were downloaded from [Datahub](https://datahub.io/core/global-temp).
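
A minimal sketch for loading `spm.csv` with pandas, assuming the file is read from this directory. `load_spm` is a hypothetical helper, not part of the chapter's notebooks; it takes the SPM column name from the CSV header and pivots the long layout (one row per city and year) into one column per city.

```python
import pandas as pd

# Column name copied from the header row of spm.csv
SPM_COL = "Suspended Particulate Matter (SPM) (Fouquet and DPCC (2011))"


def load_spm(source) -> pd.DataFrame:
    """Load spm.csv (a path or file-like object) as a Year x Entity table of SPM values."""
    df = pd.read_csv(source)
    # Long format (one row per city and year) -> wide format (one column per city)
    return df.pivot(index="Year", columns="Entity", values=SPM_COL)
```

For example, `load_spm("spm.csv")["London"].plot()` would draw London's series from 1700 to 2016. Delhi only covers 1997 to 2010, so most of its column would be NaN.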
6 | -------------------------------------------------------------------------------- /chapter2/spm.csv: -------------------------------------------------------------------------------- 1 | Entity,Year,Smoke (Fouquet and DPCC (2011)),Suspended Particulate Matter (SPM) (Fouquet and DPCC (2011)) 2 | Delhi,1997,,363 3 | Delhi,1998,,378 4 | Delhi,1999,,375 5 | Delhi,2000,,431 6 | Delhi,2001,,382 7 | Delhi,2002,,456 8 | Delhi,2003,,391 9 | Delhi,2004,,390 10 | Delhi,2005,,373 11 | Delhi,2006,,433 12 | Delhi,2007,,365 13 | Delhi,2008,,416 14 | Delhi,2009,,492 15 | Delhi,2010,,481 16 | London,1700,142.8571429,259.7402597 17 | London,1701,144.2857143,262.3376623 18 | London,1702,145.7142857,264.9350649 19 | London,1703,147.1428571,267.5324675 20 | London,1704,148.5714286,270.1298701 21 | London,1705,150,272.7272727 22 | London,1706,151.4285714,275.3246753 23 | London,1707,152.8571429,277.9220779 24 | London,1708,154.2857143,280.5194805 25 | London,1709,155.7142857,283.1168831 26 | London,1710,157.1428571,285.7142857 27 | London,1711,158.5714286,288.3116883 28 | London,1712,160,290.9090909 29 | London,1713,161.4285714,293.5064935 30 | London,1714,162.8571429,296.1038961 31 | London,1715,164.2857143,298.7012987 32 | London,1716,165.7142857,301.2987013 33 | London,1717,167.1428571,303.8961039 34 | London,1718,168.5714286,306.4935065 35 | London,1719,170,309.0909091 36 | London,1720,171.4285714,311.6883117 37 | London,1721,172.8571429,314.2857143 38 | London,1722,174.2857143,316.8831169 39 | London,1723,175.7142857,319.4805195 40 | London,1724,177.1428571,322.0779221 41 | London,1725,178.5714286,324.6753247 42 | London,1726,179.2857143,325.974026 43 | London,1727,180,327.2727273 44 | London,1728,180.7142857,328.5714286 45 | London,1729,181.4285714,329.8701299 46 | London,1730,182.1428571,331.1688312 47 | London,1731,182.8571429,332.4675325 48 | London,1732,183.5714286,333.7662338 49 | London,1733,184.2857143,335.0649351 50 | London,1734,185,336.3636364 51 | 
London,1735,185.7142857,337.6623377 52 | London,1736,186.4285714,338.961039 53 | London,1737,187.1428571,340.2597403 54 | London,1738,187.8571429,341.5584416 55 | London,1739,188.5714286,342.8571429 56 | London,1740,189.2857143,344.1558442 57 | London,1741,190,345.4545455 58 | London,1742,190.7142857,346.7532468 59 | London,1743,191.4285714,348.0519481 60 | London,1744,192.1428571,349.3506494 61 | London,1745,192.8571429,350.6493506 62 | London,1746,193.5714286,351.9480519 63 | London,1747,194.2857143,353.2467532 64 | London,1748,195,354.5454545 65 | London,1749,195.7142857,355.8441558 66 | London,1750,196.4285714,357.1428571 67 | London,1751,196.7857143,357.7922078 68 | London,1752,197.1428571,358.4415584 69 | London,1753,197.5,359.0909091 70 | London,1754,197.8571429,359.7402597 71 | London,1755,198.2142857,360.3896104 72 | London,1756,198.5714286,361.038961 73 | London,1757,198.9285714,361.6883117 74 | London,1758,199.2857143,362.3376623 75 | London,1759,199.6428571,362.987013 76 | London,1760,200,363.6363636 77 | London,1761,200.3571429,364.2857143 78 | London,1762,200.7142857,364.9350649 79 | London,1763,201.0714286,365.5844156 80 | London,1764,201.4285714,366.2337662 81 | London,1765,201.7857143,366.8831169 82 | London,1766,202.1428571,367.5324675 83 | London,1767,202.5,368.1818182 84 | London,1768,202.8571429,368.8311688 85 | London,1769,203.2142857,369.4805195 86 | London,1770,203.5714286,370.1298701 87 | London,1771,203.9285714,370.7792208 88 | London,1772,204.2857143,371.4285714 89 | London,1773,204.6428571,372.0779221 90 | London,1774,205,372.7272727 91 | London,1775,205.3571429,373.3766234 92 | London,1776,205.7142857,374.025974 93 | London,1777,206.0714286,374.6753247 94 | London,1778,206.4285714,375.3246753 95 | London,1779,206.7857143,375.974026 96 | London,1780,207.1428571,376.6233766 97 | London,1781,207.5,377.2727273 98 | London,1782,207.8571429,377.9220779 99 | London,1783,208.2142857,378.5714286 100 | London,1784,208.5714286,379.2207792 101 | 
London,1785,208.9285714,379.8701299 102 | London,1786,209.2857143,380.5194805 103 | London,1787,209.6428571,381.1688312 104 | London,1788,210,381.8181818 105 | London,1789,210.3571429,382.4675325 106 | London,1790,210.7142857,383.1168831 107 | London,1791,211.0714286,383.7662338 108 | London,1792,211.4285714,384.4155844 109 | London,1793,211.7857143,385.0649351 110 | London,1794,212.1428571,385.7142857 111 | London,1795,212.5,386.3636364 112 | London,1796,212.8571429,387.012987 113 | London,1797,213.2142857,387.6623377 114 | London,1798,213.5714286,388.3116883 115 | London,1799,213.9285714,388.961039 116 | London,1800,214.2857143,389.6103896 117 | London,1801,216.1428571,392.987013 118 | London,1802,218,396.3636364 119 | London,1803,219.8571429,399.7402597 120 | London,1804,221.7142857,403.1168831 121 | London,1805,223.5714286,406.4935065 122 | London,1806,225.4285714,409.8701299 123 | London,1807,227.2857143,413.2467532 124 | London,1808,229.1428571,416.6233766 125 | London,1809,231,420 126 | London,1810,232.8571429,423.3766234 127 | London,1811,234.7142857,426.7532468 128 | London,1812,236.5714286,430.1298701 129 | London,1813,238.4285714,433.5064935 130 | London,1814,240.2857143,436.8831169 131 | London,1815,242.1428571,440.2597403 132 | London,1816,244,443.6363636 133 | London,1817,245.8571429,447.012987 134 | London,1818,247.7142857,450.3896104 135 | London,1819,249.5714286,453.7662338 136 | London,1820,251.4285714,457.1428571 137 | London,1821,253.2857143,460.5194805 138 | London,1822,255.1428571,463.8961039 139 | London,1823,257,467.2727273 140 | London,1824,258.8571429,470.6493506 141 | London,1825,260.7142857,474.025974 142 | London,1826,262.5714286,477.4025974 143 | London,1827,264.4285714,480.7792208 144 | London,1828,266.2857143,484.1558442 145 | London,1829,268.1428571,487.5324675 146 | London,1830,270,490.9090909 147 | London,1831,271.8571429,494.2857143 148 | London,1832,273.7142857,497.6623377 149 | London,1833,275.5714286,501.038961 150 | 
London,1834,277.4285714,504.4155844 151 | London,1835,279.2857143,507.7922078 152 | London,1836,281.1428571,511.1688312 153 | London,1837,283,514.5454545 154 | London,1838,284.8571429,517.9220779 155 | London,1839,286.7142857,521.2987013 156 | London,1840,288.5714286,524.6753247 157 | London,1841,290.4285714,528.0519481 158 | London,1842,292.2857143,531.4285714 159 | London,1843,294.1428571,534.8051948 160 | London,1844,296,538.1818182 161 | London,1845,297.8571429,541.5584416 162 | London,1846,299.7142857,544.9350649 163 | London,1847,301.5714286,548.3116883 164 | London,1848,303.4285714,551.6883117 165 | London,1849,305.2857143,555.0649351 166 | London,1850,307.1428571,558.4415584 167 | London,1851,307.7142857,559.4805195 168 | London,1852,308.2857143,560.5194805 169 | London,1853,308.8571429,561.5584416 170 | London,1854,309.4285714,562.5974026 171 | London,1855,310,563.6363636 172 | London,1856,310.5714286,564.6753247 173 | London,1857,311.1428571,565.7142857 174 | London,1858,311.7142857,566.7532468 175 | London,1859,312.2857143,567.7922078 176 | London,1860,312.8571429,568.8311688 177 | London,1861,313.4285714,569.8701299 178 | London,1862,314,570.9090909 179 | London,1863,314.5714286,571.9480519 180 | London,1864,315.1428571,572.987013 181 | London,1865,315.7142857,574.025974 182 | London,1866,316.2857143,575.0649351 183 | London,1867,316.8571429,576.1038961 184 | London,1868,317.4285714,577.1428571 185 | London,1869,318,578.1818182 186 | London,1870,318.5714286,579.2207792 187 | London,1871,319.1428571,580.2597403 188 | London,1872,319.7142857,581.2987013 189 | London,1873,320.2857143,582.3376623 190 | London,1874,320.8571429,583.3766234 191 | London,1875,321.4285714,584.4155844 192 | London,1876,322.7678571,586.8506494 193 | London,1877,324.1071429,589.2857143 194 | London,1878,325.4464286,591.7207792 195 | London,1879,326.7857143,594.1558442 196 | London,1880,328.125,596.5909091 197 | London,1881,329.4642857,599.025974 198 | 
London,1882,330.8035714,601.461039 199 | London,1883,332.1428571,603.8961039 200 | London,1884,333.4821429,606.3311688 201 | London,1885,334.8214286,608.7662338 202 | London,1886,336.1607143,611.2012987 203 | London,1887,337.5,613.6363636 204 | London,1888,338.8392857,616.0714286 205 | London,1889,340.1785714,618.5064935 206 | London,1890,341.5178571,620.9415584 207 | London,1891,342.8571429,623.3766234 208 | London,1892,339.6825397,617.6046176 209 | London,1893,336.5079365,611.8326118 210 | London,1894,333.3333333,606.0606061 211 | London,1895,330.1587302,600.2886003 212 | London,1896,326.984127,594.5165945 213 | London,1897,323.8095238,588.7445887 214 | London,1898,320.6349206,582.972583 215 | London,1899,317.4603175,577.2005772 216 | London,1900,314.2857143,571.4285714 217 | London,1901,314.1142857,571.1168831 218 | London,1902,313.9428571,570.8051948 219 | London,1903,313.7714286,570.4935065 220 | London,1904,313.6,570.1818182 221 | London,1905,313.4285714,569.8701299 222 | London,1906,313.2571429,569.5584416 223 | London,1907,313.0857143,569.2467532 224 | London,1908,312.9142857,568.9350649 225 | London,1909,312.7428571,568.6233766 226 | London,1910,312.5714286,568.3116883 227 | London,1911,312.4,568 228 | London,1912,312.2285714,567.6883117 229 | London,1913,310,563.6363636 230 | London,1914,305.4545455,555.3719008 231 | London,1915,300.9090909,547.107438 232 | London,1916,296.3636364,538.8429752 233 | London,1917,291.8181818,530.5785124 234 | London,1918,287.2727273,522.3140496 235 | London,1919,282.7272727,514.0495868 236 | London,1920,278.1818182,505.785124 237 | London,1921,273.6363636,497.5206612 238 | London,1922,269.0909091,489.2561983 239 | London,1923,264.5454545,480.9917355 240 | London,1924,260,472.7272727 241 | London,1925,225,409.0909091 242 | London,1926,200,363.6363636 243 | London,1927,198,360 244 | London,1928,197,358.1818182 245 | London,1929,198,360 246 | London,1930,199,361.8181818 247 | London,1931,200,363.6363636 248 | 
London,1932,210,381.8181818 249 | London,1933,220,400 250 | London,1934,222,403.6363636 251 | London,1935,225,409.0909091 252 | London,1936,223,405.4545455 253 | London,1937,221,401.8181818 254 | London,1938,220,400 255 | London,1939,211.25,384.0909091 256 | London,1940,202.5,368.1818182 257 | London,1941,193.75,352.2727273 258 | London,1942,185,336.3636364 259 | London,1943,176.25,320.4545455 260 | London,1944,167.5,304.5454545 261 | London,1945,158.75,288.6363636 262 | London,1946,150,272.7272727 263 | London,1947,141.25,256.8181818 264 | London,1948,132.5,240.9090909 265 | London,1949,123.75,225 266 | London,1950,115,209.0909091 267 | London,1951,112.6709302,204.8562368 268 | London,1952,110.3418605,200.6215645 269 | London,1953,108.0127907,196.3868922 270 | London,1954,105.6837209,192.1522199 271 | London,1955,103.3546512,187.9175476 272 | London,1956,101.0255814,183.6828753 273 | London,1957,98.69651163,179.448203 274 | London,1958,96.36744186,175.2135307 275 | London,1959,94.03837209,170.9788584 276 | London,1960,91.70930233,166.744186 277 | London,1961,89.38023256,162.5095137 278 | London,1962,87.05116279,158.2748414 279 | London,1963,84.72209302,154.0401691 280 | London,1964,82.39302326,149.8054968 281 | London,1965,80.06395349,145.5708245 282 | London,1966,77.73488372,141.3361522 283 | London,1967,75.40581395,137.1014799 284 | London,1968,73.07674419,132.8668076 285 | London,1969,70.74767442,128.6321353 286 | London,1970,68.41860465,124.397463 287 | London,1971,66.08953488,120.1627907 288 | London,1972,63.76046512,115.9281184 289 | London,1973,61.43139535,111.6934461 290 | London,1974,59.10232558,107.4587738 291 | London,1975,56.77325581,103.2241015 292 | London,1976,54.44418605,98.98942918 293 | London,1977,52.11511628,94.75475687 294 | London,1978,49.78604651,90.52008457 295 | London,1979,47.45697674,86.28541226 296 | London,1980,45.12790698,82.05073996 297 | London,1981,42.79883721,77.81606765 298 | London,1982,40.46976744,73.58139535 299 | 
London,1983,38.14069767,69.34672304 300 | London,1984,35.81162791,65.11205074 301 | London,1985,33.48255814,60.87737844 302 | London,1986,31.15348837,56.64270613 303 | London,1987,28.8244186,52.40803383 304 | London,1988,26.49534884,48.17336152 305 | London,1989,24.16627907,43.93868922 306 | London,1990,21.8372093,39.70401691 307 | London,1991,19.50813953,35.46934461 308 | London,1992,17.17906977,31.2346723 309 | London,1993,14.85,27 310 | London,1994,14.3,26 311 | London,1995,14.3,26 312 | London,1996,14.3,26 313 | London,1997,13.75,25 314 | London,1998,13.75,25 315 | London,1999,13.75,25 316 | London,2000,13.75,25 317 | London,2001,,23.22 318 | London,2002,,23 319 | London,2003,,24.79 320 | London,2004,,22.03 321 | London,2005,,22.35 322 | London,2006,,23.82 323 | London,2007,,21.56 324 | London,2008,,18.81 325 | London,2009,,19 326 | London,2010,,20 327 | London,2011,,20 328 | London,2012,,17 329 | London,2013,,17 330 | London,2014,,17 331 | London,2015,,15 332 | London,2016,,16 333 | -------------------------------------------------------------------------------- /chapter3/Preprocessing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6a06a10a-ee82-40ce-8821-3da680af4b95", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "614a3348-214c-4472-969d-e7db10664075", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "pip install -U tsfresh workalendar astral \"featuretools[tsfresh]\" sktime" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "9a82dfb1-f90b-4f38-91fb-34756dece1e5", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import datetime\n", 29 | "import pandas as pd\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "import seaborn as sns" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": 
"96db671d-bfdb-48f2-b572-f1d35e2b14a1", 37 | "metadata": {}, 38 | "source": [ 39 | "# Transformations" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "0bd9698d-3905-42a6-92b4-1d3d6c80f79f", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import numpy as np\n", 50 | "\n", 51 | "np.random.seed(0)\n", 52 | "pts = 10000\n", 53 | "vals = np.random.lognormal(0, 1.0, pts)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "55097e4a-c0d8-4cd9-a7e5-5e59738c0475", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "plt.hist(vals, bins=20, density=True)\n", 64 | "plt.yscale(\"log\")\n", 65 | "plt.ylabel(\"frequency\")\n", 66 | "plt.xlabel(\"value range\");" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "1a0306b4-c03d-46e4-9ed1-1c8e30c3f7bb", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "from sklearn.preprocessing import minmax_scale\n", 77 | "from sklearn.preprocessing import StandardScaler\n", 78 | "from scipy.stats import normaltest\n", 79 | "\n", 80 | "vals_mm = minmax_scale(vals)\n", 81 | "scaler = StandardScaler()\n", 82 | "vals_ss = scaler.fit_transform(vals.reshape(-1, 1))\n", 83 | "_, p = normaltest(vals_ss.squeeze())\n", 84 | "print(f\"significance: {p:.2f}\")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "f32df59b-5443-48b2-b2ef-07a4f79bb869", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "_, p = normaltest(vals_mm.squeeze())\n", 95 | "print(f\"significance: {p:.2f}\")" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "9fc25e7e-8bca-44ad-8d14-100cd62ee1ec", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "plt.scatter(vals, vals_ss, alpha=0.3)\n", 106 | "plt.ylabel(\"standard scaled\")\n", 107 | "plt.xlabel(\"original\");" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | 
"execution_count": null, 113 | "id": "e88633cd-5542-4366-84bd-ef59e46f253a", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "from statsmodels.stats.diagnostic import kstest_normal\n", 118 | "\n", 119 | "log_transformed = np.log(vals)\n", 120 | "_, p = kstest_normal(log_transformed) # stats.normaltest\n", 121 | "print(f\"significance: {p:.2f}\")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "8ed96d42-2023-463d-82cf-250b5bb69733", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "np.std(log_transformed)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "0f9faf67-1496-4383-a93e-abfb67e4dc9c", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "np.mean(log_transformed)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": "ae3a9a4b-f856-468e-8f7a-07178544f52d", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "plt.hist(log_transformed, bins=20, density=True)\n", 152 | "#plt.yscale(\"log\")\n", 153 | "plt.ylabel(\"frequency\")\n", 154 | "plt.xlabel(\"value range\");" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "89f9f296-456d-4789-88c0-2e2eb35bc8a8", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "from scipy.stats import boxcox\n", 165 | "vals_bc = boxcox(vals, 0.0)\n", 166 | "_, p = normaltest(vals_bc)\n", 167 | "print(f\"significance: {p:.2f}\")" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "40697ad8-fbc7-4c62-ab4a-64f813c13ec6", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "plt.hist(vals_bc, bins=20, density=True)\n", 178 | "plt.ylabel(\"frequency\")\n", 179 | "plt.xlabel(\"value range\");" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "id": "da2aa62a-fe40-49e2-b2d9-ebcb1a91cda8", 185 | "metadata": {}, 186 
| "source": [ 187 | "# Imputation" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "id": "40bddaee-38d3-4d40-9773-e6a4721e4b62", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "import numpy as np\n", 198 | "from sklearn.impute import SimpleImputer\n", 199 | "imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')\n", 200 | "imp_mean.fit([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]])\n", 201 | "\n", 202 | "df = [[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]]\n", 203 | "print(imp_mean.transform(df))\n" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "id": "503b0b9a-71a9-475c-9906-e59fb9853932", 209 | "metadata": {}, 210 | "source": [ 211 | "# Derived Date Features" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "id": "5aa0e489-638d-46e7-8850-518c616e22e5", 217 | "metadata": {}, 218 | "source": [ 219 | "## Holidays" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "9efee10c-2d3d-4265-886a-50e87e380a97", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "from workalendar.europe.united_kingdom import UnitedKingdom\n", 230 | "UnitedKingdom().holidays()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "1539b652-a1f8-41a6-8410-cccad3d5edb6", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "from typing import Dict\n", 241 | "from dateutil.relativedelta import relativedelta, TH\n", 242 | "import datetime\n", 243 | "from workalendar.usa import California\n", 244 | "\n", 245 | "def create_custom_holidays(year) -> Dict[datetime.date, str]:\n", 246 | " custom_holidays = California().holidays(year)\n", 247 | " custom_holidays.append((\n", 248 | " (datetime.datetime(year, 11, 1) + relativedelta(weekday=TH(+4)) + datetime.timedelta(days=1)).date(),\n", 249 | " \"Black Friday\"\n", 250 | " ))\n", 251 | " return {k: v for (k, v) in custom_holidays}\n", 252 | "\n", 253 | 
"custom_holidays = create_custom_holidays(2021)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "bec50434-e7b4-409b-8f37-9b400ce0a2bf", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "custom_holidays" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "id": "3e8c9a76-0ec7-40c8-8eda-e7e480281501", 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "def is_holiday(current_date: datetime.date):\n", 274 | " \"\"\"Determine if we have a holiday.\"\"\"\n", 275 | " return custom_holidays.get(current_date, False)\n", 276 | "\n", 277 | "today = datetime.date(2021, 4, 11)\n", 278 | "is_holiday(today)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "id": "ddb63805-b782-4944-a88b-4996783177f7", 284 | "metadata": {}, 285 | "source": [ 286 | "## Date Annotations" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "id": "eab731b5-fa14-4c8a-bd39-0a805c36aa80", 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "import calendar\n", 297 | "\n", 298 | "calendar.monthrange(2021, 1)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "id": "f9d4a8b3-fea7-4e75-86e4-fff5cc3ffa9e", 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "from datetime import date\n", 309 | "def year_anchor(current_date: datetime.date):\n", 310 | " return (\n", 311 | " (current_date - date(current_date.year, 1, 1)).days,\n", 312 | " (date(current_date.year, 12, 31) - current_date).days,\n", 313 | " )\n", 314 | "\n", 315 | "year_anchor(today)\n" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "id": "d6a4e6d4-3a68-48a5-b564-29e9f7b0b72d", 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "def month_anchor(current_date: datetime.date):\n", 326 | " last_day = calendar.monthrange(current_date.year, 
current_date.month)[1]\n", 327 | "\n", 328 | " return (\n", 329 | " (current_date - datetime.date(current_date.year, current_date.month, 1)).days,\n", 330 | " (datetime.date(current_date.year, current_date.month, last_day) - current_date).days,\n", 331 | " )\n", 332 | "\n", 333 | "month_anchor(today)\n" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "id": "5453766c-e8d5-4fd8-88a8-4b87602c506f", 339 | "metadata": {}, 340 | "source": [ 341 | "## Paydays" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "id": "5fa16210-e016-4cb6-85b9-84f964255f81", 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "def get_last_friday(current_date: datetime.date, weekday=calendar.FRIDAY):\n", 352 | " return max(week[weekday]\n", 353 | " for week in calendar.monthcalendar(\n", 354 | " current_date.year, current_date.month\n", 355 | " ))\n", 356 | "\n", 357 | "get_last_friday(today)\n" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "2a2cca85-1e9f-4e3c-ad34-a6b1122e217a", 363 | "metadata": {}, 364 | "source": [ 365 | "## Seasons" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "id": "6abd71ed-2478-4c8a-a578-1b961e8a1de4", 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "YEAR = 2021\n", 376 | "seasons = [\n", 377 | " ('winter', (date(YEAR, 1, 1), date(YEAR, 3, 20))),\n", 378 | " ('spring', (date(YEAR, 3, 21), date(YEAR, 6, 20))),\n", 379 | " ('summer', (date(YEAR, 6, 21), date(YEAR, 9, 22))),\n", 380 | " ('autumn', (date(YEAR, 9, 23), date(YEAR, 12, 20))),\n", 381 | " ('winter', (date(YEAR, 12, 21), date(YEAR, 12, 31)))\n", 382 | "]\n", 383 | "\n", 384 | "def is_in_interval(current_date: datetime.date, seasons):\n", 385 | " return next(season for season, (start, end) in seasons\n", 386 | " if start <= current_date.replace(year=YEAR) <= end)\n", 387 | "\n", 388 | "is_in_interval(today, seasons)\n" 389 | ] 390 | }, 391 | { 392 | 
"cell_type": "markdown", 393 | "id": "bff1e4f1-03a6-4710-8fcc-93ca10e529c1", 394 | "metadata": {}, 395 | "source": [ 396 | "## Sun and Moon" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "id": "2e79e8c5-a2c9-4ce0-bc55-b15c8564a877", 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "!pip install astral" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "92c55bb2-caf3-4a70-b51a-84bfc073b2ce", 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "from astral.sun import sun\n", 417 | "from astral import LocationInfo\n", 418 | "CITY = LocationInfo(\"London\", \"England\", \"Europe/London\", 51.5, -0.116)\n", 419 | "def get_sunrise_dusk(current_date: datetime.date, city_name='London'):\n", 420 | " s = sun(CITY.observer, date=current_date)\n", 421 | " sunrise = s['sunrise']\n", 422 | " dusk = s['dusk']\n", 423 | " return (dusk - sunrise).total_seconds() / 3600\n", 424 | "\n", 425 | "get_sunrise_dusk(today)\n" 426 | ] 427 | }, 428 | { 429 | "cell_type": "markdown", 430 | "id": "98f041ac-5c80-408e-909a-3364975f6be0", 431 | "metadata": {}, 432 | "source": [ 433 | "## Business Days" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "id": "aac6eeef-2d1b-4bf2-a2cb-75a9c4648a93", 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "def get_business_days(current_date: datetime.date):\n", 444 | " last_day = calendar.monthrange(current_date.year, current_date.month)[1]\n", 445 | " rng = pd.date_range(current_date.replace(day=1), periods=last_day, freq='D')\n", 446 | " business_days = pd.bdate_range(rng[0], rng[-1])\n", 447 | " return len(business_days), last_day - len(business_days)\n", 448 | "\n", 449 | "get_business_days(date.today())\n" 450 | ] 451 | }, 452 | { 453 | "cell_type": "markdown", 454 | "id": "91f21e9b-c326-472b-a59a-85a908502fe4", 455 | "metadata": {}, 456 | "source": [ 457 | "# Automated Feature
Extraction" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "id": "7df98e51-5f84-49fd-b13c-a7a488bcc4ea", 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "import featuretools as ft\n", 468 | "from featuretools.primitives import Minute, Hour, Day, Month, Year, Weekday\n", 469 | "\n", 470 | "data = pd.DataFrame(\n", 471 | " {'Time': ['2014-01-01 01:41:50',\n", 472 | " '2014-01-01 02:06:50',\n", 473 | " '2014-01-01 02:31:50',\n", 474 | " '2014-01-01 02:56:50',\n", 475 | " '2014-01-01 03:21:50'],\n", 476 | " 'Target': [0, 0, 0, 0, 1]}\n", 477 | ") \n", 478 | "data['index'] = data.index\n", 479 | "es = ft.EntitySet('My EntitySet')\n", 480 | "es.entity_from_dataframe(\n", 481 | " entity_id='main_data_table',\n", 482 | " index='index',\n", 483 | " dataframe=data,\n", 484 | " time_index='Time'\n", 485 | ")\n", 486 | "fm, features = ft.dfs(\n", 487 | " entityset=es, \n", 488 | " target_entity='main_data_table', \n", 489 | " trans_primitives=[Minute, Hour, Day, Month, Year, Weekday]\n", 490 | ")\n" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "id": "39b85175-58d9-44bc-8374-09af58f245d3", 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "fm" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "id": "11f5581e-0ce2-4a90-9f0c-5aeefbaa2096", 507 | "metadata": {}, 508 | "outputs": [], 509 | "source": [ 510 | "from tsfresh.feature_extraction import extract_features\n", 511 | "from tsfresh.feature_extraction import ComprehensiveFCParameters\n", 512 | "\n", 513 | "settings = ComprehensiveFCParameters()\n", 514 | "extract_features(data, column_id='Time', default_fc_parameters=settings)\n" 515 | ] 516 | }, 517 | { 518 | "cell_type": "markdown", 519 | "id": "d78b955f-04a1-4751-8142-e3b7c8002278", 520 | "metadata": {}, 521 | "source": [ 522 | "## ROCKET" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 
null, 528 | "id": "cc173b6d-be20-424a-b8b9-3ae617cf5f39", 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "from sktime.datasets import load_arrow_head\n", 533 | "from sktime.utils.data_processing import from_nested_to_2d_array\n", 534 | "# please note that this import changes in version 0.8:\n", 535 | "# from sktime.datatypes._panel._convert import from_nested_to_2d_array\n", 536 | "\n", 537 | "X_train, y_train = load_arrow_head(split=\"train\", return_X_y=True)\n", 538 | "from_nested_to_2d_array(X_train).head()" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": null, 544 | "id": "b865bd92-8ba9-46f7-a15d-4745ad8bf5ea", 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "from sktime.transformations.panel.rocket import Rocket\n", 549 | "rocket = Rocket(num_kernels=1000)\n", 550 | "rocket.fit(X_train)\n", 551 | "X_train_transform = rocket.transform(X_train)\n" 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "id": "474f4e40-d2b1-43da-98d9-767d922c6f91", 557 | "metadata": {}, 558 | "source": [ 559 | "## Shapelets" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": null, 565 | "id": "0d99e98f-17ec-4b01-b8e0-01b062e1dfbb", 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [ 569 | "from sktime.transformations.panel.shapelets import ContractedShapeletTransform\n", 570 | "shapelets_transform = ContractedShapeletTransform(\n", 571 | " time_contract_in_mins=1,\n", 572 | " num_candidates_to_sample_per_case=10,\n", 573 | " verbose=0,\n", 574 | ")\n", 575 | "shapelets_transform.fit(X_train, y_train)\n" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "id": "a5468a48-7669-49e9-87e3-827f25cc8b2b", 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "X_train_transform = shapelets_transform.transform(X_train)" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": null, 591 | "id": 
"6812d1d2-7de2-48ee-91c5-70741ff8a320", 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "X_train_transform" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": null, 601 | "id": "6d3410b6-a2ab-40fb-a942-b1fcfd423512", 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [] 605 | } 606 | ], 607 | "metadata": { 608 | "kernelspec": { 609 | "display_name": "Python 3", 610 | "language": "python", 611 | "name": "python3" 612 | }, 613 | "language_info": { 614 | "codemirror_mode": { 615 | "name": "ipython", 616 | "version": 3 617 | }, 618 | "file_extension": ".py", 619 | "mimetype": "text/x-python", 620 | "name": "python", 621 | "nbconvert_exporter": "python", 622 | "pygments_lexer": "ipython3", 623 | "version": "3.8.8" 624 | } 625 | }, 626 | "nbformat": 4, 627 | "nbformat_minor": 5 628 | } 629 | -------------------------------------------------------------------------------- /chapter5/Forecasting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6a06a10a-ee82-40ce-8821-3da680af4b95", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "614a3348-214c-4472-969d-e7db10664075", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!pip install statsmodels yfinance pmdarima arch" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "cb11cb6d-0e34-418b-9085-2b5d88ed343d", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import matplotlib.pyplot as plt\n", 29 | "import seaborn as sns\n", 30 | "import matplotlib as mpl\n", 31 | "import pandas as pd\n", 32 | "\n", 33 | "plt.style.use('seaborn-whitegrid')\n", 34 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 35 | "plt.rcParams[\"font.size\"] = \"17\"\n", 36 | "mpl.rcParams['lines.linewidth'] = 2\n", 37 | 
"mpl.rcParams['lines.markersize'] = 1\n", 38 | "# plt.style.use('.matplotlibrc')" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "d90f30a7-9ad6-4773-bf95-d6a07298757d", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "from datetime import datetime\n", 49 | "import yfinance as yf\n", 50 | " \n", 51 | "start_date = datetime(2005, 1, 1)\n", 52 | "end_date = datetime(2021, 1, 1)\n", 53 | "\n", 54 | "df = yf.download(\n", 55 | " 'SPY',\n", 56 | " start=start_date,\n", 57 | " end = end_date\n", 58 | ")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "f11d3381-2b1c-400a-98d9-b41d53061e8b", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "plt.figure(figsize = (12, 6))\n", 69 | "plt.title('Opening Prices between {} and {}'.format(\n", 70 | " start_date.date().isoformat(),\n", 71 | " end_date.date().isoformat()\n", 72 | "))\n", 73 | "df['Open'].plot()\n", 74 | "plt.ylabel('Price')\n", 75 | "plt.xlabel('Date');" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "c6b1629a-92eb-4c52-a523-24427d120786", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "df.head()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "6289d877-f934-4543-9849-87bc3ef8a3c5", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "df1 = df.reset_index().resample('W', on=\"Date\")['Open'].mean()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "6db93897-2993-4490-9558-dec619e30f85", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# some years have 53 weeks. 
We can't handle that, so we'll get rid of the 53rd week.\n", 106 | "df1 = df1[df1.index.week < 53]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "567145da-61fc-4bc2-859c-3e36ad5c7421", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# final check: \n", 117 | "df1.index.week.value_counts().plot.bar()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "3df5304c-7c73-45bb-8d4d-94af87ac8dc3", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "df1 = df1[~df1.isnull()]" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "171c9318-b8cd-46d0-8018-0cac48888d96", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "# let's fix the frequency:\n", 138 | "df1 = df1.asfreq('W').fillna(method='ffill')" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "38a4ba79-453a-4e18-bbb5-652838270b90", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "df1.index.freq" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "c70e30be-4481-4aa4-9142-b3614dfb39e9", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "import statsmodels.api as sm\n", 159 | "\n", 160 | "fig, axs = plt.subplots(2)\n", 161 | "sm.graphics.tsa.plot_pacf(df1, lags=20, ax=axs[0])\n", 162 | "axs[0].set_ylabel('R')\n", 163 | "axs[0].set_xlabel('Lag')\n", 164 | "sm.graphics.tsa.plot_acf(df1, lags=20, ax=axs[1]);\n", 165 | "axs[1].set_ylabel('R')\n", 166 | "axs[1].set_xlabel('Lag')\n", 167 | "fig.tight_layout()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "e88633cd-5542-4366-84bd-ef59e46f253a", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "from statsmodels.tsa.seasonal import seasonal_decompose\n", 178 | "result = seasonal_decompose(df1, 
model='additive', period=52)\n", 179 | "result.plot();" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "id": "da2aa62a-fe40-49e2-b2d9-ebcb1a91cda8", 185 | "metadata": {}, 186 | "source": [ 187 | "# Finding a value for d" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "id": "fbeefe1e-67ce-403d-a990-a7cf19aa9318", 193 | "metadata": {}, 194 | "source": [ 195 | "We are using the ARCH package, which has more convenient versions of both the ADF and the KPSS tests." 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "18eb1bc7-85a8-4380-8346-4c28aeb16116", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "from arch.unitroot import KPSS, ADF\n", 206 | "\n", 207 | "ADF(df1)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "40bddaee-38d3-4d40-9773-e6a4721e4b62", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "from pmdarima.arima.utils import ndiffs\n", 218 | "\n", 219 | "# ADF Test:\n", 220 | "ndiffs(df1, test='adf') # 1; same values for the KPSS and the PP test" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "6d3410b6-a2ab-40fb-a942-b1fcfd423512", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "# what happens if we forget differencing?\n", 231 | "# We get a helpful warning: 'Non-stationary starting autoregressive parameters'\n", 232 | "mod = sm.tsa.arima.ARIMA(endog=df1, order=(1, 0, 0))\n", 233 | "res = mod.fit()\n", 234 | "print(res.summary())" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "id": "12cf8ad9-bcd8-42c0-ad6f-164ff81d7dec", 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "# let's try again and this time, we'll take into account the seasonality:\n", 245 | "from statsmodels.tsa.forecasting.stl import STLForecast\n", 246 | "\n", 247 | "mod = STLForecast(df1, sm.tsa.arima.ARIMA, 
model_kwargs=dict(order=(1, 1, 0), trend=\"t\"))\n", 248 | "res = mod.fit().model_result\n", 249 | "print(res.summary())" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "454a5679-45ac-45b5-860a-8a4685d7e9eb", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "# doing a forecast:\n", 260 | "STEPS = 20\n", 261 | "forecasts_df = res.get_forecast(steps=STEPS).summary_frame() \n", 262 | "ax = df1.plot(figsize=(12, 6))\n", 263 | "plt.ylabel('SPY')\n", 264 | "forecasts_df['mean'].plot(style='k--')\n", 265 | "ax.fill_between(\n", 266 | " forecasts_df.index,\n", 267 | " forecasts_df['mean_ci_lower'],\n", 268 | " forecasts_df['mean_ci_upper'],\n", 269 | " color='k',\n", 270 | " alpha=0.1\n", 271 | ")" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "id": "90cc47f4-1444-40af-9d84-e64e4518e0e0", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "forecasts = []\n", 282 | "qs = []\n", 283 | "for q in range(0, 30, 10):\n", 284 | " mod = STLForecast(df1, sm.tsa.arima.ARIMA, model_kwargs=dict(order=(0, 1, q), trend=\"t\"))\n", 285 | " res = mod.fit().model_result\n", 286 | " print(f\"aic ({q}): {res.aic}\")\n", 287 | " forecasts.append(\n", 288 | " res.get_forecast(steps=STEPS).summary_frame()['mean']\n", 289 | " )\n", 290 | " qs.append(q)\n", 291 | "\n", 292 | "forecasts_df = pd.concat(forecasts, axis=1)\n", 293 | "forecasts_df.columns = qs" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "1afc0557-d41d-4540-95c5-9244b1ceca24", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "# plotting the three forecasts:\n", 304 | "ax = df1.plot(figsize=(12, 6))\n", 305 | "plt.ylabel('SPY')\n", 306 | "forecasts_df.plot(ax=ax)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "id": "f9b65061-b5ca-453f-92a9-b5c7341f8932", 313 | "metadata": {}, 314 | "outputs": [], 315 | 
"source": [ 316 | "mod = sm.tsa.ExponentialSmoothing(\n", 317 | " endog=df1, trend='add', seasonal_periods=52, use_boxcox=True, initialization_method=\"heuristic\"\n", 318 | " )\n", 319 | "res = mod.fit()\n", 320 | "print(res.summary())" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "id": "92209cc0-b046-4c5b-a9ee-4a3afb5d615a", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "forecasts = pd.Series(res.forecast(steps=STEPS))\n", 331 | "ax = df1.plot(figsize=(12, 6))\n", 332 | "plt.ylabel('SPY')\n", 333 | "forecasts.plot(style='k--')" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "id": "c79a4000-663b-4264-a0f2-c00427ccfc11", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "from statsmodels.tsa.forecasting.theta import ThetaModel\n", 344 | "\n", 345 | "train_length = int(len(df1) * 0.8)\n", 346 | "tm = ThetaModel(df1[:train_length], method=\"auto\", deseasonalize=True)\n", 347 | "res = tm.fit()\n", 348 | "forecasts = res.forecast(steps=len(df1)-train_length)\n", 349 | "ax = df1.plot(figsize=(12, 6))\n", 350 | "plt.ylabel('SPY')\n", 351 | "forecasts.plot(style='k--')" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "id": "949b7c6c-1326-44a6-8e5d-c58d0f4191e7", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "from sklearn import metrics\n", 362 | "\n", 363 | "metrics.mean_squared_error(forecasts, df1[train_length:], squared=False)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "id": "9733e78f-ee79-44c4-a26e-496670a68cad", 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "Python 3", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | 
"file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.8.8" 392 | } 393 | }, 394 | "nbformat": 4, 395 | "nbformat_minor": 5 396 | } 397 | -------------------------------------------------------------------------------- /chapter6/Change-Points_Anomalies.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# Anomaly detection" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "!pip install alibi_detect" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from alibi_detect.datasets import fetch_kdd\n", 33 | "\n", 34 | "intrusions = fetch_kdd()" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "intrusions[\"target\"].sum() / len(intrusions[\"target\"])" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "intrusions[\"feature_names\"]" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import pandas as pd\n", 62 | "\n", 63 | "pd.Series(intrusions[\"data\"][:, 0]).plot();" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "from alibi_detect.od import SpectralResidual\n", 73 | "\n", 74 | "od = SpectralResidual(\n", 75 | " threshold=1.0, window_amp=20, window_local=20, n_est_points=10, n_grad_points=5\n", 76 | ")\n", 77 | "intrusion_outliers = od.predict(intrusions[\"data\"][:,0])" 78 | ] 79 | }, 
80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "scores = od.score(intrusions[\"data\"][:, 0])" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "import matplotlib\n", 96 | "\n", 97 | "ax = pd.Series(intrusions[\"data\"][:, 0], name=\"data\").plot(\n", 98 | " legend=False, figsize=(12, 6)\n", 99 | ")\n", 100 | "ax2 = ax.twinx()\n", 101 | "ax = pd.Series(scores, name=\"scores\").plot(\n", 102 | " ax=ax2, legend=False, color=\"r\", marker=matplotlib.markers.CARETDOWNBASE\n", 103 | ")\n", 104 | "ax.figure.legend(bbox_to_anchor=(1, 1), loc=\"upper left\");" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "# Change point detection" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "!pip install ruptures" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "import matplotlib.pyplot as plt\n", 130 | "import numpy as np\n", 131 | "import ruptures as rpt\n", 132 | "\n", 133 | "plt.style.use(\"seaborn-whitegrid\")\n", 134 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 135 | "plt.rcParams[\"font.size\"] = \"17\"" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "signal, bkps = rpt.pw_constant(\n", 145 | " n_samples=500, n_features=3, n_bkps=2, noise_std=5.0, delta=(1, 20)\n", 146 | ")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | 
"rpt.display(signal, bkps)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "signal.shape" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "bkps" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "# \"l1\", \"rbf\", \"linear\", \"normal\", \"ar\"\n", 183 | "algo = rpt.Binseg(model=\"l1\").fit(signal)\n", 184 | "my_bkps = algo.predict(n_bkps=3)\n", 185 | "\n", 186 | "# show results\n", 187 | "rpt.show.display(signal, bkps, my_bkps, figsize=(10, 6))" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.8.8" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 4 212 | } 213 | -------------------------------------------------------------------------------- /chapter7/KNN_with_dynamic_DTW.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "colab": { 15 | "base_uri": "https://localhost:8080/" 16 | }, 17 | "id": "r1kjFHY6xlbo", 18 | "outputId": "4b3f00df-2b21-4b3f-809f-7df77ada4790" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "!pip install tsfresh \"statsmodels<=0.12\"" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | 
"execution_count": null, 28 | "metadata": { 29 | "colab": { 30 | "base_uri": "https://localhost:8080/" 31 | }, 32 | "id": "Pg0UkW57-VlF", 33 | "outputId": "8275dc45-501c-40cd-ec9e-19d7d42935ad" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "!pip install tslearn" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": { 44 | "id": "5VtsgQxuzMUt" 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "import pandas as pd\n", 49 | "import matplotlib.pyplot as plt\n", 50 | "\n", 51 | "plt.style.use('seaborn-whitegrid')\n", 52 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 53 | "plt.rcParams[\"font.size\"] = \"17\"" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "id": "iPQPBG849jB2", 64 | "outputId": "13f96036-b472-4445-d0ca-75d02a48016a" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "from tsfresh.examples import load_robot_execution_failures\n", 69 | "from tsfresh.examples.robot_execution_failures import download_robot_execution_failures\n", 70 | "\n", 71 | "download_robot_execution_failures()\n", 72 | "df_ts, y = load_robot_execution_failures()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "colab": { 80 | "base_uri": "https://localhost:8080/" 81 | }, 82 | "id": "UOHxdjmAWKZc", 83 | "outputId": "67d575b7-d649-40ab-8edd-2d0c7950a6a2" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "from tsfresh import extract_features\n", 88 | "from tsfresh import select_features\n", 89 | "from tsfresh.utilities.dataframe_functions import impute\n", 90 | "\n", 91 | "extracted_features = impute(extract_features(df_ts, column_id=\"id\", column_sort=\"time\"))\n", 92 | "features_filtered = select_features(extracted_features, y)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": { 99 | "colab": { 100 | "base_uri": 
"https://localhost:8080/" 101 | }, 102 | "id": "7D0JxrCpQL6x", 103 | "outputId": "b75b4054-afab-4f55-8c61-9a6c7a7f0d8d" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "from tsfresh.transformers import RelevantFeatureAugmenter\n", 108 | "import pandas as pd\n", 109 | "\n", 110 | "X = pd.DataFrame(index=y.index)\n", 111 | "TRAINING_SIZE = (len(X) // 10) * 8\n", 112 | "augmenter = RelevantFeatureAugmenter(column_id='id', column_sort='time')\n", 113 | "augmenter.set_timeseries_container(df_ts)  # fit/transform only use the ids found in the index of X\n", 114 | "augmenter.fit(X[:TRAINING_SIZE], y[:TRAINING_SIZE])\n", 115 | "X_transformed = augmenter.transform(X)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "colab": { 123 | "base_uri": "https://localhost:8080/" 124 | }, 125 | "id": "jzaRzQc63Iiv", 126 | "outputId": "384b8608-39d3-4b52-8e23-4baefd4c05cc" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "from sklearn.model_selection import TimeSeriesSplit, GridSearchCV\n", 131 | "from tslearn.neighbors import KNeighborsTimeSeriesClassifier\n", 132 | "\n", 133 | "knn = KNeighborsTimeSeriesClassifier()\n", 134 | "param_search = {\n", 135 | " 'metric': ['dtw'],  # other tslearn metrics such as 'ctw' or 'softdtw' can be added here\n", 136 | " 'n_neighbors': [1, 2, 3]\n", 137 | "}\n", 138 | "# 'param_grid': {'metric': ['ctw', 'dtw', 'gak', 'sax', 'softdtw', 'lcss']},\n", 139 | "tscv = TimeSeriesSplit(n_splits=2)\n", 140 | "\n", 141 | "gsearch = GridSearchCV(\n", 142 | " estimator=knn,\n", 143 | " cv=tscv,\n", 144 | " param_grid=param_search\n", 145 | ")\n", 146 | "gsearch.fit(\n", 147 | " features_filtered,\n", 148 | " y\n", 149 | ")" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": { 156 | "id": "j7XePgIYK1WQ" 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "# Adapted from comments on this discussion thread on stackoverflow:
https://stackoverflow.com/questions/37161563/how-to-graph-grid-scores-from-gridsearchcv\n", 161 | "import seaborn as sns\n", 162 | "import pandas as pd\n", 163 | "\n", 164 | "def plot_cv_results(cv_results, param_x, param_z, metric='mean_test_score'):\n", 165 | " \"\"\"\n", 166 | " cv_results - cv_results_ attribute of a GridSearchCV instance (or similar)\n", 167 | " param_x - name of grid search parameter to plot on x axis\n", 168 | " param_z - name of grid search parameter to plot by line color\n", 169 | " \"\"\"\n", 170 | " cv_results = pd.DataFrame(cv_results)\n", 171 | " col_x = 'param_' + param_x\n", 172 | " col_z = 'param_' + param_z\n", 173 | " fig, ax = plt.subplots(1, 1, figsize=(11, 8))\n", 174 | " sns.pointplot(x=col_x, y=metric, hue=col_z, data=cv_results, ci=99, n_boot=64, ax=ax)\n", 175 | " ax.set_title(\"CV Grid Search Results\")\n", 176 | " ax.set_xlabel(param_x)\n", 177 | " ax.set_ylabel(metric)\n", 178 | " ax.legend(title=param_z)\n", 179 | " return fig\n" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": { 186 | "colab": { 187 | "base_uri": "https://localhost:8080/", 188 | "height": 564 189 | }, 190 | "id": "-mHYRnizXwsz", 191 | "outputId": "41bb85fe-a2e4-4f51-911d-54630be0645a" 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "fig = plot_cv_results(gsearch.cv_results_, 'metric', 'n_neighbors')" 196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "colab": { 201 | "collapsed_sections": [], 202 | "name": "KNN with dynamic DTW", 203 | "provenance": [] 204 | }, 205 | "kernelspec": { 206 | "display_name": "Python 3", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.8.8" 221 | } 222 | }, 223 | 
"nbformat": 4, 224 | "nbformat_minor": 4 225 | } 226 | -------------------------------------------------------------------------------- /chapter7/Kats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7ea3881f-ddfc-475b-bfa9-bb46456ffee4", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "bde1b68e-20b6-42e1-a5a8-81824e365545", 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "bde1b68e-20b6-42e1-a5a8-81824e365545", 20 | "outputId": "f290afb2-8026-4792-e11f-2b5572336c3a" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "!MINIMAL=1 pip install kats" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "Jb7RFf_D-YzN", 31 | "metadata": { 32 | "colab": { 33 | "base_uri": "https://localhost:8080/" 34 | }, 35 | "id": "Jb7RFf_D-YzN", 36 | "outputId": "681437c3-35c7-461a-f512-170d929c1a38" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "!pip install \"numpy==1.20\"" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "e9f62717-49c6-4a08-bd69-67a4c3ebe2d5", 47 | "metadata": { 48 | "id": "e9f62717-49c6-4a08-bd69-67a4c3ebe2d5" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "import matplotlib.pyplot as plt\n", 53 | "import seaborn as sns\n", 54 | "%matplotlib inline\n", 55 | "\n", 56 | "plt.style.use('seaborn-whitegrid')\n", 57 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 58 | "plt.rcParams[\"font.size\"] = \"17\"" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "2045898b-51b2-446a-bf30-4102725d36d9", 65 | "metadata": { 66 | "id": "2045898b-51b2-446a-bf30-4102725d36d9" 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "import pandas as pd\n", 71 | "\n", 72 | "owid_covid = 
pd.read_csv(\"https://covid.ourworldindata.org/data/owid-covid-data.csv\")\n", 73 | "owid_covid[\"date\"] = pd.to_datetime(owid_covid[\"date\"])\n", 74 | "df = owid_covid[owid_covid.location == \"France\"].set_index(\"date\", drop=True).resample('D').interpolate(method='linear').reset_index()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "d196768c-ccd4-487b-b88a-825f3280f8b8", 81 | "metadata": { 82 | "id": "d196768c-ccd4-487b-b88a-825f3280f8b8" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "from kats.models.ensemble.ensemble import EnsembleParams, BaseModelParams\n", 87 | "from kats.models.ensemble.kats_ensemble import KatsEnsemble\n", 88 | "from kats.models import (\n", 89 | " linear_model,\n", 90 | " quadratic_model\n", 91 | ")\n", 92 | "\n", 93 | "\n", 94 | "model_params = EnsembleParams(\n", 95 | " [\n", 96 | " BaseModelParams(\"linear\", linear_model.LinearModelParams()),\n", 97 | " BaseModelParams(\"quadratic\", quadratic_model.QuadraticModelParams()),\n", 98 | " ]\n", 99 | " )\n", 100 | "\n", 101 | "# create `KatsEnsembleParam` with detailed configurations \n", 102 | "KatsEnsembleParam = {\n", 103 | " \"models\": model_params,\n", 104 | " \"aggregation\": \"weightedavg\",\n", 105 | " \"seasonality_length\": 30,\n", 106 | " \"decomposition_method\": \"additive\",\n", 107 | "}" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "8379225b-c620-491f-8357-b4ac822deb50", 114 | "metadata": { 115 | "id": "8379225b-c620-491f-8357-b4ac822deb50" 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "from kats.consts import TimeSeriesData\n", 120 | "TARGET_COL = \"new_cases\"\n", 121 | "\n", 122 | "df_ts = TimeSeriesData(\n", 123 | " value=df[TARGET_COL], time=df[\"date\"]\n", 124 | ")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "b03eca14-47e3-4280-a178-3b78e9f4d58f", 131 | "metadata": { 132 | "colab": { 133 | "base_uri": 
"https://localhost:8080/" 134 | }, 135 | "id": "b03eca14-47e3-4280-a178-3b78e9f4d58f", 136 | "outputId": "be5d4373-935c-497c-f706-d2522e37c5c5" 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "m = KatsEnsemble(\n", 141 | " data=df_ts, \n", 142 | " params=KatsEnsembleParam\n", 143 | ").fit()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "I-QkTf2kBiRE", 150 | "metadata": { 151 | "colab": { 152 | "base_uri": "https://localhost:8080/", 153 | "height": 419 154 | }, 155 | "id": "I-QkTf2kBiRE", 156 | "outputId": "a42a5154-4f23-466b-cf00-9d9616ab84f4" 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "m.predict(steps=90).aggregate()" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "c6c05f39-8e78-40cd-a43c-0ae2eed3c200", 167 | "metadata": { 168 | "colab": { 169 | "base_uri": "https://localhost:8080/", 170 | "height": 426 171 | }, 172 | "id": "c6c05f39-8e78-40cd-a43c-0ae2eed3c200", 173 | "outputId": "f9454627-6e0e-4217-e789-80802cd87d22" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "m.predict(steps=90)\n", 178 | "m.aggregate()\n", 179 | "m.plot()\n", 180 | "plt.ylabel(TARGET_COL)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "22489efc-329a-43d7-8524-ebc46a7d0494", 187 | "metadata": { 188 | "id": "22489efc-329a-43d7-8524-ebc46a7d0494" 189 | }, 190 | "outputs": [], 191 | "source": [] 192 | } 193 | ], 194 | "metadata": { 195 | "colab": { 196 | "collapsed_sections": [], 197 | "name": "Kats", 198 | "provenance": [] 199 | }, 200 | "kernelspec": { 201 | "display_name": "Python 3", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | 
"version": "3.8.8" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 5 220 | } 221 | -------------------------------------------------------------------------------- /chapter7/Silverkite.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "colab": { 15 | "base_uri": "https://localhost:8080/" 16 | }, 17 | "id": "ZQ3Ym1YDe2FQ", 18 | "outputId": "0be39c08-a455-4902-b876-9228b6a0eb42" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "!pip install greykite" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "id": "1gEv7oq8e2FY" 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "import matplotlib.pyplot as plt\n", 34 | "import seaborn as sns\n", 35 | "%matplotlib inline\n", 36 | "\n", 37 | "plt.style.use('seaborn-whitegrid')\n", 38 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 39 | "plt.rcParams[\"font.size\"] = \"17\"" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "id": "TECK_SIde2FX" 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "import pandas as pd\n", 51 | "\n", 52 | "owid_covid = pd.read_csv(\"https://covid.ourworldindata.org/data/owid-covid-data.csv\")\n", 53 | "owid_covid[\"date\"] = pd.to_datetime(owid_covid[\"date\"])" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "id": "ezjWoCAxfZfX", 64 | "outputId": "60fb7d72-bed1-481d-96f9-97e86bd79f84" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "owid_covid.location.unique()" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "colab": { 76 | "base_uri": 
"https://localhost:8080/", 77 | "height": 309 78 | }, 79 | "id": "mtiFkfn9hEgV", 80 | "outputId": "6f198322-c183-4936-d2af-688dfb9316a0" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "owid_covid.head()" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": { 91 | "id": "JslH5iW-gthQ" 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "df = owid_covid[owid_covid.location == \"France\"].set_index(\"date\", drop=True).resample('D').interpolate(method='linear')" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "id": "lPBBQ8eLe2Fb" 103 | }, 104 | "outputs": [], 105 | "source": [ 106 | "from greykite.framework.templates.autogen.forecast_config import (\n", 107 | " ForecastConfig, MetadataParam\n", 108 | ")\n", 109 | "\n", 110 | "metadata = MetadataParam(\n", 111 | " time_col=\"date\",\n", 112 | " value_col=\"new_cases\",\n", 113 | " freq=\"D\"\n", 114 | ")" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": { 121 | "colab": { 122 | "base_uri": "https://localhost:8080/" 123 | }, 124 | "id": "Bv40xaJde2Fc", 125 | "outputId": "8305ef45-89f0-4067-e2c5-96f3d9b7dfc6" 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "import warnings\n", 130 | "from greykite.framework.templates.forecaster import Forecaster\n", 131 | "from greykite.framework.templates.model_templates import ModelTemplateEnum\n", 132 | "\n", 133 | "forecaster = Forecaster()\n", 134 | "\n", 135 | "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", 136 | "result = forecaster.run_forecast_config(\n", 137 | " df=df.reset_index(),\n", 138 | " config=ForecastConfig(\n", 139 | " model_template=ModelTemplateEnum.SILVERKITE_DAILY_90.name,\n", 140 | " forecast_horizon=90,\n", 141 | " coverage=0.95,\n", 142 | " metadata_param=metadata,\n", 143 | " )\n", 144 | ")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": 
{ 151 | "colab": { 152 | "base_uri": "https://localhost:8080/", 153 | "height": 542 154 | }, 155 | "id": "Djt_cYAPqdES", 156 | "outputId": "e22d6805-d3a5-4500-d923-3ea1e9778657" 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "forecast = result.forecast\n", 161 | "forecast.plot().show(renderer=\"colab\") # leave out the renderer argument if you are not on google colab!" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": { 168 | "colab": { 169 | "base_uri": "https://localhost:8080/", 170 | "height": 204 171 | }, 172 | "id": "6MGAevXyvEyA", 173 | "outputId": "132fb3e9-be00-485b-badb-05547f65a13b" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "forecast.df.head().round(2)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "colab": { 185 | "base_uri": "https://localhost:8080/", 186 | "height": 204 187 | }, 188 | "id": "El0vJRace2Fe", 189 | "outputId": "816da56c-de9b-458b-bcd1-3956e3331169" 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "from collections import defaultdict\n", 194 | "\n", 195 | "backtest = result.backtest\n", 196 | "backtest_eval = defaultdict(list)\n", 197 | "for metric, value in backtest.train_evaluation.items():\n", 198 | " backtest_eval[metric].append(value)\n", 199 | " backtest_eval[metric].append(backtest.test_evaluation[metric])\n", 200 | "metrics = pd.DataFrame(backtest_eval, index=[\"train\", \"test\"]).T\n", 201 | "metrics.head()" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": { 208 | "colab": { 209 | "base_uri": "https://localhost:8080/", 210 | "height": 241 211 | }, 212 | "id": "oL9GMtJgtHGO", 213 | "outputId": "f7555d36-66b9-4f7d-a4f8-fdced9397806" 214 | }, 215 | "outputs": [], 216 | "source": [ 217 | "model = result.model\n", 218 | "future_df = result.timeseries.make_future_dataframe(\n", 219 | " periods=4,\n", 220 | " include_history=False\n", 221 | ")\n", 
222 | "model.predict(future_df)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": { 229 | "colab": { 230 | "base_uri": "https://localhost:8080/", 231 | "height": 241 232 | }, 233 | "id": "uIclGJzatLl9", 234 | "outputId": "ae1832c4-10df-4cae-c5f3-3ac9066f8745" 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "model.predict(future_df)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [] 247 | } 248 | ], 249 | "metadata": { 250 | "colab": { 251 | "collapsed_sections": [ 252 | "JHuPMCuAe2Ff", 253 | "xJe34gyie2Fi", 254 | "B9u12G4ne2Fm" 255 | ], 256 | "name": "Silverkite", 257 | "provenance": [] 258 | }, 259 | "kernelspec": { 260 | "display_name": "Python 3", 261 | "language": "python", 262 | "name": "python3" 263 | }, 264 | "language_info": { 265 | "codemirror_mode": { 266 | "name": "ipython", 267 | "version": 3 268 | }, 269 | "file_extension": ".py", 270 | "mimetype": "text/x-python", 271 | "name": "python", 272 | "nbconvert_exporter": "python", 273 | "pygments_lexer": "ipython3", 274 | "version": "3.8.8" 275 | } 276 | }, 277 | "nbformat": 4, 278 | "nbformat_minor": 4 279 | } 280 | -------------------------------------------------------------------------------- /chapter7/XGBoost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "id": "cjiXyuy5e2Fj", 15 | "outputId": "26e9d719-07e0-4a60-8005-c05a24e1e009" 16 | }, 17 | "outputs": [], 18 | "source": [ 19 | "!pip install xgboost" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "id": "3CaAZqSGe2Fi" 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "import matplotlib.pyplot as plt\n", 
31 | "import seaborn as sns\n", 32 | "%matplotlib inline\n", 33 | "\n", 34 | "plt.style.use('seaborn-whitegrid')\n", 35 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 36 | "plt.rcParams[\"font.size\"] = \"17\"" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": { 43 | "id": "UpEJozDdyizN" 44 | }, 45 | "outputs": [], 46 | "source": [ 47 | "import pandas as pd\n", 48 | "\n", 49 | "owid_covid = pd.read_csv(\"https://covid.ourworldindata.org/data/owid-covid-data.csv\")\n", 50 | "owid_covid[\"date\"] = pd.to_datetime(owid_covid[\"date\"])\n", 51 | "df = owid_covid[owid_covid.location == \"France\"].set_index(\"date\", drop=True).resample('D').interpolate(method='linear').reset_index()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "id": "WqB6Vwn3e2Fj" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "from sklearn.base import TransformerMixin, BaseEstimator\n", 63 | "from typing import List\n", 64 | "\n", 65 | "class DateFeatures(TransformerMixin, BaseEstimator):\n", 66 | " \"\"\"DateFeatures transformer.\"\"\"\n", 67 | " features = [\n", 68 | " \"hour\",\n", 69 | " \"year\",\n", 70 | " \"day\",\n", 71 | " \"weekday\",\n", 72 | " \"month\",\n", 73 | " \"quarter\",\n", 74 | " ]\n", 75 | " \n", 76 | " def __init__(self):\n", 77 | " \"\"\"Nothing much to do.\"\"\"\n", 78 | " super().__init__()\n", 79 | " self.feature_names: List[str] = []\n", 80 | "\n", 81 | " def get_feature_names(self):\n", 82 | " \"\"\"Feature names.\"\"\"\n", 83 | " return self.feature_names\n", 84 | " \n", 85 | " def transform(self, df: pd.DataFrame):\n", 86 | " \"\"\"Annotate date features.\"\"\"\n", 87 | " Xt = []\n", 88 | " for col in df.columns:\n", 89 | " for feature in self.features:\n", 90 | " date_feature = getattr(\n", 91 | " getattr(\n", 92 | " df[col], \"dt\"\n", 93 | " ), feature\n", 94 | " )\n", 95 | " date_feature.name = f\"{col}_{feature}\"\n", 96 | " 
Xt.append(date_feature)\n", 97 | " \n", 98 | " df2 = pd.concat(Xt, axis=1)\n", 99 | " self.feature_names = list(df2.columns)\n", 100 | " return df2\n", 101 | "\n", 102 | " def fit(self, df: pd.DataFrame, y=None, **fit_params):\n", 103 | " \"\"\"No fitting needed.\"\"\"\n", 104 | " return self" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "id": "IhEkSi5azUrM" 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "import numpy as np\n", 116 | "from sklearn.base import TransformerMixin, BaseEstimator\n", 117 | "from typing import Dict\n", 118 | "\n", 119 | "class CyclicalFeatures(TransformerMixin, BaseEstimator):\n", 120 | " \"\"\"CyclicalFeatures transformer.\"\"\"\n", 121 | " \n", 122 | " def __init__(self, max_vals: Dict[str, float] = {}):\n", 123 | " \"\"\"Nothing much to do.\"\"\"\n", 124 | " super().__init__()\n", 125 | " self.feature_names: List[str] = []\n", 126 | " self.max_vals = max_vals\n", 127 | "\n", 128 | " def get_feature_names(self):\n", 129 | " \"\"\"Feature names.\"\"\"\n", 130 | " return self.feature_names\n", 131 | " \n", 132 | " def transform(self, df: pd.DataFrame):\n", 133 | " \"\"\"Annotate date features.\"\"\"\n", 134 | " Xt = []\n", 135 | " for col in df.columns:\n", 136 | " if col in self.max_vals:\n", 137 | " max_val = self.max_vals[col]\n", 138 | " else:\n", 139 | " max_val = df[col].max()\n", 140 | " for fun_name, fun in [(\"cos\", np.cos), (\"sin\", np.sin)]:\n", 141 | " date_feature = fun(2 * np.pi * df[col] / max_val)\n", 142 | " date_feature.name = f\"{col}_{fun_name}\"\n", 143 | " Xt.append(date_feature)\n", 144 | " \n", 145 | " df2 = pd.concat(Xt, axis=1)\n", 146 | " self.feature_names = list(df2.columns)\n", 147 | " return df2\n", 148 | "\n", 149 | " def fit(self, df: pd.DataFrame, y=None, **fit_params):\n", 150 | " \"\"\"No fitting needed.\"\"\"\n", 151 | " return self" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | 
"metadata": { 158 | "id": "87bELgcAziFi" 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "from sklearn.compose import ColumnTransformer\n", 163 | "from sklearn.pipeline import Pipeline, make_pipeline\n", 164 | "#from sklearn import linear_model\n", 165 | "\n", 166 | "\n", 167 | "preprocessor = ColumnTransformer(\n", 168 | " transformers=[(\n", 169 | " \"date\",\n", 170 | " make_pipeline(\n", 171 | " DateFeatures(),\n", 172 | " ColumnTransformer(transformers=[\n", 173 | " (\"cyclical\", CyclicalFeatures(),\n", 174 | " [\"date_day\", \"date_weekday\", \"date_month\"]\n", 175 | " )\n", 176 | " ], remainder=\"passthrough\")\n", 177 | " ), [\"date\"],\n", 178 | " ),], remainder=\"passthrough\"\n", 179 | ")\n", 180 | "\n", 181 | "pipeline = Pipeline(\n", 182 | " [\n", 183 | " (\"preprocessing\", preprocessor),\n", 184 | " #(\"clf\", linear_model.LinearRegression(),),\n", 185 | " ]\n", 186 | ")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "colab": { 194 | "base_uri": "https://localhost:8080/" 195 | }, 196 | "id": "UvhfdtVre2Fj", 197 | "outputId": "453afa7a-c3f5-4a51-f693-55d2b7698282" 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "FEATURE_COLS = [\"date\"]\n", 202 | "date_features = pipeline.fit_transform(df[FEATURE_COLS])\n", 203 | "date_features" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": { 210 | "id": "WmhAZE-Me2Fk" 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "TRAIN_SIZE = int(len(df) * 0.9)\n", 215 | "HORIZON = 1\n", 216 | "TARGET_COL = \"new_cases\"\n", 217 | "\n", 218 | "X_train, X_test = df.iloc[HORIZON:TRAIN_SIZE], df.iloc[TRAIN_SIZE+HORIZON:]\n", 219 | "y_train = df.shift(periods=HORIZON).iloc[HORIZON:TRAIN_SIZE][TARGET_COL]\n", 220 | "y_test = df.shift(periods=HORIZON).iloc[TRAIN_SIZE+HORIZON:][TARGET_COL]" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": { 227 | 
"colab": { 228 | "base_uri": "https://localhost:8080/" 229 | }, 230 | "id": "aA7GL52je2Fl", 231 | "outputId": "d19d55fd-a683-4eea-a087-5676a4333077" 232 | }, 233 | "outputs": [], 234 | "source": [ 235 | "from xgboost import XGBRegressor\n", 236 | "\n", 237 | "pipeline = Pipeline(\n", 238 | " [\n", 239 | " (\"preprocessing\", preprocessor),\n", 240 | " (\"xgb\", XGBRegressor(objective=\"reg:squarederror\", n_estimators=1000))\n", 241 | " ]\n", 242 | ")\n", 243 | "pipeline.fit(X_train[FEATURE_COLS], y_train)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "id": "6b3PUD4CH5Gx" 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "MAX_HORIZON = 90\n", 255 | "X_test_horizon = pd.Series(pd.date_range(\n", 256 | " start=df.date.min(), \n", 257 | " periods=len(df) + MAX_HORIZON,\n", 258 | " name=\"date\"\n", 259 | ")).reset_index()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": { 266 | "id": "t8VoXOzCe2Fl" 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "forecasted = pd.concat(\n", 271 | " [pd.Series(pipeline.predict(X_test_horizon[FEATURE_COLS])), pd.Series(X_test_horizon.date)],\n", 272 | " axis=1\n", 273 | ")\n", 274 | "forecasted.columns = [TARGET_COL, \"date\"]" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "id": "7FGnoSbre2Fl" 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "actual = pd.concat(\n", 286 | " [pd.Series(df[TARGET_COL]), pd.Series(df.date)],\n", 287 | " axis=1\n", 288 | ")\n", 289 | "actual.columns = [TARGET_COL, \"date\"]" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": { 296 | "colab": { 297 | "base_uri": "https://localhost:8080/", 298 | "height": 443 299 | }, 300 | "id": "XtTKGkoTe2Fm", 301 | "outputId": "4c3c830f-3df9-4c5b-ad06-41cd623e5e70" 302 | }, 303 | "outputs": [], 304 | "source": [ 305 | "fig, ax = 
plt.subplots(figsize=(12, 6))\n", 306 | "forecasted.set_index(\"date\").plot(linestyle='--', ax=ax)\n", 307 | "actual.set_index(\"date\").plot(linestyle='-.', ax=ax)\n", 308 | "plt.legend([\"forecast\", \"actual\"])" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": { 315 | "colab": { 316 | "base_uri": "https://localhost:8080/" 317 | }, 318 | "id": "B4oPSOTse2F1", 319 | "outputId": "ad6760a9-e9a8-4cd8-812f-39fce7c4e636" 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "from sklearn.metrics import mean_squared_error\n", 324 | "\n", 325 | "test_data = actual.merge(forecasted, on=\"date\", suffixes=(\"_actual\", \"_predicted\"))\n", "# evaluate only on the held-out period that follows the training split\n", "test_data = test_data[test_data.date >= X_test.date.min()]\n", 326 | "\n", 327 | "rmse = np.sqrt(mean_squared_error(test_data.new_cases_actual, test_data.new_cases_predicted))\n", 328 | "print(\"The root mean squared error (RMSE) on test set: {:.2f}\".format(rmse))" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": { 335 | "id": "HU1DQT1o5Vpn" 336 | }, 337 | "outputs": [], 338 | "source": [] 339 | } 340 | ], 341 | "metadata": { 342 | "colab": { 343 | "collapsed_sections": [], 344 | "name": "XGBoost", 345 | "provenance": [] 346 | }, 347 | "kernelspec": { 348 | "display_name": "Python 3", 349 | "language": "python", 350 | "name": "python3" 351 | }, 352 | "language_info": { 353 | "codemirror_mode": { 354 | "name": "ipython", 355 | "version": 3 356 | }, 357 | "file_extension": ".py", 358 | "mimetype": "text/x-python", 359 | "name": "python", 360 | "nbconvert_exporter": "python", 361 | "pygments_lexer": "ipython3", 362 | "version": "3.8.8" 363 | } 364 | }, 365 | "nbformat": 4, 366 | "nbformat_minor": 4 367 | } 368 | -------------------------------------------------------------------------------- /chapter8/Drift_Detection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id":
"6d73980e-a0c4-4163-8819-1d70e60c8454", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "4546ab51-c2cb-4261-ae78-fdbd6d6a07ac", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!pip install river" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "3caf1ccd-709a-49d2-8f0a-36b54bbc9cf9", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import numpy as np\n", 29 | "\n", 30 | "np.random.seed(12345)\n", 31 | "data_stream = np.concatenate(\n", 32 | " (np.random.randint(2, size=1000), np.random.randint(8, size=1000))\n", 33 | ")" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "c0b2fb45-abc9-4ac4-9e10-78405620c3d3", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "import matplotlib.pyplot as plt\n", 44 | "\n", 45 | "plt.style.use('seaborn-whitegrid')\n", 46 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n", 47 | "plt.rcParams[\"font.size\"] = \"17\"" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "d1551a3f-cf55-4269-a1da-c697ccb8a271", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "plt.figure(figsize=(16, 6))\n", 58 | "plt.plot(data_stream)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "034a8574-c193-49f3-ab23-44cd7c2e2c2d", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "def perform_test(drift_detector, data_stream):\n", 69 | " detected_indices = []\n", 70 | " for i, val in enumerate(data_stream):\n", 71 | " drift_detector.update(val)\n", 72 | " if drift_detector.drift_detected:\n", 73 | " detected_indices.append(i)\n", 74 | " return detected_indices" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "aa3f4530-72b5-4733-8db7-e6e71de8f292", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [
84 | "import matplotlib.pyplot as plt\n", 85 | "from river.drift import ADWIN, PageHinkley\n", 86 | "\n", 87 | "def show_drift(data_stream, indices):\n", 88 | " fig, ax = plt.subplots(figsize=(16, 6))\n", 89 | " ax.plot(data_stream)\n", 90 | " ax.plot(\n", 91 | " indices,\n", 92 | " data_stream[indices],\n", 93 | " \"r\",\n", 94 | " alpha=0.6,\n", 95 | " marker=r'$\\circ$',\n", 96 | " markersize=22,\n", 97 | " linewidth=4\n", 98 | " )\n", 99 | " plt.tight_layout()\n", 100 | "\n", 101 | "\n", 102 | "detected_indices = perform_test(PageHinkley(), data_stream)\n", 103 | "show_drift(data_stream, detected_indices)\n", 104 | "plt.title(\"Page-Hinkley Drift Detection\");" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "2e19844c-b56f-48cf-8cba-75f2b431804f", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "detected_indices = perform_test(ADWIN(), data_stream)\n", 115 | "show_drift(data_stream, detected_indices)\n", 116 | "plt.title(\"ADWIN Drift Detection\");" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "47d3a7da-f866-4192-8c7d-46261312bc15", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [] 126 | } 127 | ], 128 | "metadata": { 129 | "kernelspec": { 130 | "display_name": "Python 3", 131 | "language": "python", 132 | "name": "python3" 133 | }, 134 | "language_info": { 135 | "codemirror_mode": { 136 | "name": "ipython", 137 | "version": 3 138 | }, 139 | "file_extension": ".py", 140 | "mimetype": "text/x-python", 141 | "name": "python", 142 | "nbconvert_exporter": "python", 143 | "pygments_lexer": "ipython3", 144 | "version": "3.8.8" 145 | } 146 | }, 147 | "nbformat": 4, 148 | "nbformat_minor": 5 149 | } 150 | -------------------------------------------------------------------------------- /chapter8/Online_Learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": 
"markdown", 5 | "id": "6d0ec63d-ef95-48fa-b83d-eef4e02733e0", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "514a7acc-81bf-49cb-8fea-632387aa44bf", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "!pip install river" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "7e34e1a6-2d8d-4347-b96b-6927c210cb85", 24 | "metadata": {}, 25 | "source": [ 26 | "# Regression" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "d410bd23-82aa-4291-b034-16e9dc0cf606", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "from river import stream\n", 37 | "from river.datasets import base\n", 38 | "\n", 39 | "\n", 40 | "class SolarFlare(base.FileDataset):\n", 41 | " def __init__(self):\n", 42 | " super().__init__(\n", 43 | " n_samples=1066,\n", 44 | " n_features=12,\n", 45 | " n_outputs=1,\n", 46 | " task=base.REG,\n", 47 | " filename=\"solar-flare.csv.zip\",\n", 48 | " )\n", 49 | "\n", 50 | " def __iter__(self):\n", 51 | " return stream.iter_csv(\n", 52 | " self.path,\n", 53 | " target=\"m-class-flares\",\n", 54 | " converters={\n", 55 | " \"zurich-class\": str,\n", 56 | " \"largest-spot-size\": str,\n", 57 | " \"spot-distribution\": str,\n", 58 | " \"activity\": int,\n", 59 | " \"evolution\": int,\n", 60 | " \"previous-24h-flare-activity\": int,\n", 61 | " \"hist-complex\": int,\n", 62 | " \"hist-complex-this-pass\": int,\n", 63 | " \"area\": int,\n", 64 | " \"largest-spot-area\": int,\n", 65 | " \"c-class-flares\": int,\n", 66 | " \"m-class-flares\": int,\n", 67 | " \"x-class-flares\": int,\n", 68 | " },\n", 69 | " )\n", 70 | " " 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "6a8bdf8c-7a6f-4a47-a035-50ba4861dfd2", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "from pprint import pprint\n", 81 | "from river import datasets\n", 82 | "\n", 83 | "for x, y in 
SolarFlare():\n", 84 | " pprint(x)\n", 85 | " pprint(y)\n", 86 | " break" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "db7b13ad-4303-443d-840f-00aef7f99f51", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "import numbers\n", 97 | "from river import compose\n", 98 | "from river import datasets\n", 99 | "from river import evaluate\n", 100 | "from river import linear_model\n", 101 | "from river import metrics\n", 102 | "from river import preprocessing\n", 103 | "from river import tree\n", 104 | "\n", 105 | "dataset = SolarFlare()\n", 106 | "num = compose.SelectType(numbers.Number) | preprocessing.MinMaxScaler()\n", 107 | "cat = compose.SelectType(str) | preprocessing.OneHotEncoder(sparse=False)\n", 108 | "model = tree.HoeffdingTreeRegressor()\n", 109 | "pipeline = (num + cat) | model\n", 110 | "metric = metrics.MAE()\n", 111 | "#evaluate.progressive_val_score(dataset, pipeline, metric)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "3597bf77-9fdb-4358-bb6b-9543a65881c2", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "import matplotlib.pyplot as plt\n", 122 | "\n", 123 | "errors = []\n", 124 | "for x, y in SolarFlare():\n", 125 | " y_pred = pipeline.predict_one(x)\n", 126 | " metric = metric.update(y, y_pred)\n", 127 | " errors.append(metric.get())\n", 128 | " pipeline = pipeline.learn_one(x, y)\n", 129 | "\n", 130 | "fig, ax = plt.subplots(figsize=(16, 6))\n", 131 | "ax.plot(\n", 132 | " errors,\n", 133 | " \"ro\",\n", 134 | " alpha=0.6,\n", 135 | " markersize=2,\n", 136 | " linewidth=4\n", 137 | ")" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "30f435e0-56ed-448e-ba40-b53c5dcf3ef5", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "fig, ax = plt.subplots(figsize=(16, 6))\n", 148 | "ax.plot(\n", 149 | " errors,\n", 150 | " \"ro\",\n", 151 | " alpha=0.6,\n", 152 | " 
markersize=2,\n", 153 | " linewidth=4\n", 154 | ")\n", 155 | "ax.set_xlabel(\"number of points\")\n", 156 | "ax.set_ylabel(\"MAE\");" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "id": "7a709167-335a-4968-b8ce-c1faee11af30", 162 | "metadata": {}, 163 | "source": [ 164 | "# Adaptive Models on a Concept Drift Data Stream" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "b0b6a6f8-48d0-478c-825a-d2ed313e7d6b", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "from river import (\n", 175 | " synth, ensemble, tree,\n", 176 | " evaluate, metrics\n", 177 | ")\n", 178 | "\n", 179 | "models = [\n", 180 | " tree.HoeffdingTreeRegressor(),\n", 181 | " tree.HoeffdingAdaptiveTreeRegressor(),\n", 182 | " ensemble.AdaptiveRandomForestRegressor(seed=42),\n", 183 | "]\n", 184 | "\n", 185 | "results = {}\n", 186 | "for model in models:\n", 187 | " metric = metrics.MSE()\n", 188 | " errors = []\n", 189 | " dataset = synth.ConceptDriftStream(\n", 190 | " seed=42, position=500, width=40\n", 191 | " ).take(1000) \n", 192 | " for i, (x, y) in enumerate(dataset):\n", 193 | " y_pred = model.predict_one(x)\n", 194 | " metric = metric.update(y, y_pred)\n", 195 | " model = model.learn_one(x, y)\n", 196 | " if (i % 100) == 0:\n", 197 | " errors.append(dict(step=i, error=metric.get()))\n", 198 | " results[type(model).__name__] = errors" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "3e3a3135-fc1e-4f38-b2f0-97b945bc3ea6", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "results" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "id": "024745ef-b4ef-4ee8-ac08-9c839043ec6d", 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "import pandas as pd\n", 219 | "import seaborn as sns\n", 220 | "\n", 221 | "plt.figure(figsize=(16, 6))\n", 222 | "styles = 
[\"-\",\"--\",\"-.\",\":\"]\n", 223 | "markers = [\n", 224 | " '.', ',', 'o', 'v', '^', '<', '>',\n", 225 | " '1', '2', '3', '4', '8', 's', 'p',\n", 226 | " '*', 'h', 'H', '+', 'x', 'D', 'd',\n", 227 | " '|', '_', 'P', 'X', 0, 1, 2, 3, 4,\n", 228 | " 5, 6, 7, 8, 9, 10, 11\n", 229 | "]\n", 230 | "\n", 231 | "for i, (model, errors) in enumerate(results.items()):\n", 232 | " df = pd.DataFrame(errors)\n", 233 | " sns.lineplot(\n", 234 | " data=df,\n", 235 | " x=\"step\",\n", 236 | " y=\"error\",\n", 237 | " linestyle=styles[i%len(styles)],\n", 238 | " alpha=0.5,\n", 239 | " markersize=22,\n", 240 | " marker=markers[i%len(markers)],\n", 241 | " label=model,\n", 242 | " linewidth=4\n", 243 | " )\n", 244 | " \n", 245 | "plt.ylabel(\"MSE\")\n", 246 | "plt.xlabel(\"Step\")\n", 247 | "sns.set_style(\"ticks\")" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "id": "ad7c4902-96ae-4abc-b48b-801929473533", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [] 257 | } 258 | ], 259 | "metadata": { 260 | "kernelspec": { 261 | "display_name": "Python 3", 262 | "language": "python", 263 | "name": "python3" 264 | }, 265 | "language_info": { 266 | "codemirror_mode": { 267 | "name": "ipython", 268 | "version": 3 269 | }, 270 | "file_extension": ".py", 271 | "mimetype": "text/x-python", 272 | "name": "python", 273 | "nbconvert_exporter": "python", 274 | "pygments_lexer": "ipython3", 275 | "version": "3.8.8" 276 | } 277 | }, 278 | "nbformat": 4, 279 | "nbformat_minor": 5 280 | } 281 | --------------------------------------------------------------------------------
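The drift-detection notebook above relies on river's `PageHinkley` and `ADWIN`. To make clear what the Page-Hinkley test computes, here is a minimal from-scratch sketch of its textbook form: keep a running mean, accumulate the deviations `m_t = sum(x_i - mean_i - delta)`, track their running minimum `M_t`, and raise an alarm when `m_t - M_t` exceeds a threshold. This is not river's implementation; the class, its defaults (`delta=0.005`, `threshold=50`, `min_instances=30`) and the reset after each alarm are assumptions of this sketch. The deterministic stream is a stand-in for the notebook's random one, with the mean jumping from 0.5 to 3.5 at index 1000. Note also that, depending on the installed river version, `update` may not return the `(in_drift, in_warning)` tuple that `perform_test` unpacks; newer releases expose a `drift_detected` attribute instead.

```python
from typing import Iterable, List


class PageHinkley:
    """Minimal Page-Hinkley test for an upward shift in the mean (sketch).

    m_t = sum_{i<=t} (x_i - mean_i - delta), M_t = min(0, m_1, ..., m_t);
    a drift is signalled when m_t - M_t > threshold.
    """

    def __init__(self, delta: float = 0.005, threshold: float = 50.0,
                 min_instances: int = 30) -> None:
        self.delta = delta                  # tolerated deviation per sample
        self.threshold = threshold          # alarm level (lambda)
        self.min_instances = min_instances  # samples required before alarming
        self.reset()

    def reset(self) -> None:
        self.n = 0          # samples seen since the last reset
        self.mean = 0.0     # running mean of those samples
        self.cum = 0.0      # m_t, cumulative deviation
        self.cum_min = 0.0  # M_t, smallest m_t so far

    def update(self, x: float) -> bool:
        """Consume one value and return True if a drift is detected."""
        self.n += 1
        self.mean += (x - self.mean) / self.n
        self.cum += x - self.mean - self.delta
        self.cum_min = min(self.cum_min, self.cum)
        if self.n >= self.min_instances and self.cum - self.cum_min > self.threshold:
            self.reset()  # assumption of this sketch: start a fresh test after an alarm
            return True
        return False


def perform_test(detector: PageHinkley, data_stream: Iterable[float]) -> List[int]:
    """Return the indices at which the detector raised an alarm."""
    return [i for i, x in enumerate(data_stream) if detector.update(x)]


# Deterministic stand-in for the notebook's stream: mean 0.5, then mean 3.5.
stream = [i % 2 for i in range(1000)] + [3 + i % 2 for i in range(1000)]
print(perform_test(PageHinkley(), stream))
```

Before the change, the deviations from the running mean cancel out and the small `delta` keeps `m_t - M_t` near zero; after it, each sample adds roughly 3 to `m_t`, so the alarm fires a few dozen samples past index 1000 at most. Raising `threshold` trades slower detection for fewer false alarms.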
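The online-learning notebook scores every SolarFlare sample with `predict_one` before calling `learn_one` on it, which is progressive (test-then-train) validation, the scheme river's `evaluate.progressive_val_score` automates. The sketch below reproduces that loop without river, under explicit assumptions: `OnlineLinearRegression` is a hypothetical plain-SGD linear model standing in for `HoeffdingTreeRegressor`, `RunningMAE` is a hypothetical stand-in for `metrics.MAE`, and the data is a synthetic noise-free linear target, not the solar-flare file.

```python
import random
from typing import Dict, List, Tuple


class OnlineLinearRegression:
    """Hypothetical stand-in model: linear regression trained by plain SGD."""

    def __init__(self, lr: float = 0.1) -> None:
        self.lr = lr
        self.weights: Dict[str, float] = {}
        self.bias = 0.0

    def predict_one(self, x: Dict[str, float]) -> float:
        return self.bias + sum(self.weights.get(k, 0.0) * v for k, v in x.items())

    def learn_one(self, x: Dict[str, float], y: float) -> None:
        err = self.predict_one(x) - y  # gradient of 0.5 * (y_pred - y) ** 2
        self.bias -= self.lr * err
        for k, v in x.items():
            self.weights[k] = self.weights.get(k, 0.0) - self.lr * err * v


class RunningMAE:
    """Hypothetical stand-in metric: mean absolute error over all updates."""

    def __init__(self) -> None:
        self.n = 0
        self.total = 0.0

    def update(self, y_true: float, y_pred: float) -> None:
        self.n += 1
        self.total += abs(y_true - y_pred)

    def get(self) -> float:
        return self.total / self.n if self.n else 0.0


def progressive_validation(model, metric, data) -> List[float]:
    """Test-then-train: score each sample before the model learns from it."""
    history = []
    for x, y in data:
        y_pred = model.predict_one(x)  # 1) predict on a sample the model has not seen
        metric.update(y, y_pred)       # 2) update the running error
        history.append(metric.get())
        model.learn_one(x, y)          # 3) only then learn from it
    return history


# Synthetic, noise-free stream: y = 2a - b + 0.5 (an assumption, not SolarFlare).
rng = random.Random(42)
data: List[Tuple[Dict[str, float], float]] = []
for _ in range(2000):
    a, b = rng.random(), rng.random()
    data.append(({"a": a, "b": b}, 2.0 * a - b + 0.5))

model = OnlineLinearRegression(lr=0.1)
errors = progressive_validation(model, RunningMAE(), data)
print(f"running MAE after 100 samples: {errors[99]:.3f}, after 2000: {errors[-1]:.3f}")
```

Because each sample is scored before the model has seen it, the running MAE is an out-of-sample estimate without a separate hold-out set; it is cumulative, so early mistakes keep weighing on it long after the model has converged.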