├── .gitignore
├── .other
│   └── cover.png
├── LICENSE
├── README.md
├── chapter10
│   ├── Causal_CNN.ipynb
│   ├── RNN.ipynb
│   ├── Time_Series_with_Deep_Learning.ipynb
│   └── passengers.csv
├── chapter11
│   ├── Ranking_with_Bandits.ipynb
│   ├── Trading_with_DQN.ipynb
│   └── jesterfinal151cols.csv
├── chapter12
│   ├── Energy_Demand_Forecasting.ipynb
│   ├── gaussian_process.ipynb
│   └── test_models.ipynb
├── chapter2
│   ├── Air Pollution.ipynb
│   ├── EEG Signals.ipynb
│   ├── README.md
│   ├── monthly_csv.csv
│   └── spm.csv
├── chapter3
│   └── Preprocessing.ipynb
├── chapter5
│   └── Forecasting.ipynb
├── chapter6
│   └── Change-Points_Anomalies.ipynb
├── chapter7
│   ├── KNN_with_dynamic_DTW.ipynb
│   ├── Kats.ipynb
│   ├── Silverkite.ipynb
│   └── XGBoost.ipynb
├── chapter8
│   ├── Drift_Detection.ipynb
│   └── Online_Learning.ipynb
└── chapter9
    ├── Causal_Impact_Volkswagen.ipynb
    ├── Fuzzy_time_series_forecasting.ipynb
    ├── Markov_switching_model.ipynb
    └── Prophet_model.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | .DS_Store
--------------------------------------------------------------------------------
/.other/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PacktPublishing/Machine-Learning-for-Time-Series-with-Python/f76b589b83c260563bee1515e393260f5a68c96a/.other/cover.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Packt
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Machine Learning Summit 2025
4 | **Bridging Theory and Practice: ML Solutions for Today’s Challenges**
5 |
6 | 3 days, 20+ experts, and 25+ tech sessions and talks covering critical aspects of:
7 | - **Agentic and Generative AI**
8 | - **Applied Machine Learning in the Real World**
9 | - **ML Engineering and Optimization**
10 |
11 | 👉 [Book your ticket now >>](https://packt.link/mlsumgh)
12 |
13 | ---
14 |
15 | ## Join Our Newsletters 📬
16 |
17 | ### DataPro
18 | *The future of AI is unfolding. Don’t fall behind.*
19 |
20 | 
21 |
22 | Stay ahead with [**DataPro**](https://landing.packtpub.com/subscribe-datapronewsletter/?link_from_packtlink=yes), the free weekly newsletter for data scientists, AI/ML researchers, and data engineers.
23 | From trending tools like **PyTorch**, **scikit-learn**, **XGBoost**, and **BentoML** to hands-on insights on **database optimization** and real-world **ML workflows**, you’ll get what matters, fast.
24 |
25 | > Stay sharp with [DataPro](https://landing.packtpub.com/subscribe-datapronewsletter/?link_from_packtlink=yes). Join **115K+ data professionals** who never miss a beat.
26 |
27 | ---
28 |
29 | ### BIPro
30 | *Business runs on data. Make sure yours tells the right story.*
31 |
32 | 
33 |
34 | [**BIPro**](https://landing.packtpub.com/subscribe-bipro-newsletter/?link_from_packtlink=yes) is your free weekly newsletter for BI professionals, analysts, and data leaders.
35 | Get practical tips on **dashboarding**, **data visualization**, and **analytics strategy** with tools like **Power BI**, **Tableau**, **Looker**, **SQL**, and **dbt**.
36 |
37 | > Get smarter with [BIPro](https://landing.packtpub.com/subscribe-bipro-newsletter/?link_from_packtlink=yes). Trusted by **35K+ BI professionals**, see what you’re missing.
38 |
39 |
40 |
41 |
42 | # Machine-Learning-for-Time-Series-with-Python
43 |
44 | [Machine Learning for Time-Series with Python on Amazon](https://www.amazon.com/Machine-Learning-Time-Python-state/dp/1801819629/)
45 |
46 | Become proficient in deriving insights from time-series data and analyzing a model’s performance
47 |
48 | ## Links
49 |
50 | * [Amazon](https://www.amazon.com/Machine-Learning-Time-Python-state/dp/1801819629/)
51 | * [Packt Publishing](https://www.packtpub.com/product/machine-learning-for-time-series-with-python/9781801819626)
52 |
53 | ## Key Features
54 | * Explore popular and modern machine learning methods including the latest online and deep learning algorithms
55 | * Learn to increase the accuracy of your predictions by matching the right model with the right problem
56 | * Master time-series via real-world case studies on operations management, digital marketing, finance, and healthcare
57 |
58 | ## What you will learn
59 | - Understand the main classes of time-series and learn how to detect outliers and patterns
60 | - Choose the right method to solve time-series problems
61 | - Characterize seasonal and correlation patterns through autocorrelation and statistical techniques
62 | - Get to grips with time-series data visualization
63 | - Understand classical time-series models like ARMA and ARIMA
64 | - Implement deep learning models such as transformers, along with Gaussian processes and other state-of-the-art machine learning models
65 | - Become familiar with many libraries like Prophet, XGBoost, and TensorFlow
66 |
67 | ## Who This Book Is For
68 | This book is ideal for data analysts, data scientists, and Python developers who are looking to perform time-series analysis to effectively predict outcomes. Basic knowledge of the Python language is essential. Familiarity with statistics is desirable.
69 |
70 | ## Table of Contents
71 | 1. Introduction to Time-Series with Python
72 | 2. Time-Series Analysis with Python
73 | 3. Preprocessing Time-Series
74 | 4. Introduction to Machine Learning for Time-Series
75 | 5. Forecasting with Moving Averages and Autoregressive Models
76 | 6. Unsupervised Methods for Time-Series
77 | 7. Machine Learning Models for Time-Series
78 | 8. Online Learning for Time-Series
79 | 9. Probabilistic Models for Time-Series
80 | 10. Deep Learning for Time-Series
81 | 11. Reinforcement Learning for Time-Series
82 | 12. Multivariate Forecasting
83 |
84 | ## Author Notes
85 |
86 | I've heard from a few people struggling with tsfresh and featuretools for chapter 3.
87 |
88 | [My PR](https://github.com/blue-yonder/tsfresh/pull/912) for tsfresh was merged in mid-December, fixing a version incompatibility. Featuretools went through many breaking changes with the release of version 1.0.0 (congratulations to the team!). Please see how to fix any problems in the [discussion here](https://github.com/PacktPublishing/Machine-Learning-for-Time-Series-with-Python/issues/2).
89 | ### Download a free PDF
90 |
91 | If you have already purchased a print or Kindle version of this book, you can get a DRM-free PDF version at no cost. Simply click on the link to claim your free PDF.
92 | https://packt.link/free-ebook/9781801819626
--------------------------------------------------------------------------------
/chapter10/Causal_CNN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 6,
13 | "metadata": {
14 | "id": "bU3tHC0PArEI",
15 | "tags": []
16 | },
17 | "outputs": [],
18 | "source": [
19 | "import numpy as np\n",
20 | "import pandas as pd\n",
21 | "from tensorflow.keras.layers import Conv1D, Input, Add, Activation, Dropout\n",
22 | "from tensorflow.keras.models import Model\n",
23 | "from tensorflow.keras import optimizers\n",
24 | "import tensorflow as tf\n",
25 | "\n",
26 | "# Stop training once the training loss has not improved for 10 epochs.\n",
27 | "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=10)\n",
28 | "\n",
29 | "\n",
30 | "def DC_CNN_Block(nb_filter, filter_length, dilation):\n",
31 | "    def f(input_):\n",
32 | "        residual = input_\n",
33 | "        layer_out = Conv1D(  # dilated causal convolution: step t only sees inputs up to t\n",
34 | "            filters=nb_filter, kernel_size=filter_length,\n",
35 | "            dilation_rate=dilation,\n",
36 | "            activation='linear', padding='causal', use_bias=False\n",
37 | "        )(input_)\n",
38 | "        layer_out = Activation('selu')(layer_out)\n",
39 | "        skip_out = Conv1D(1, 1, activation='linear', use_bias=False)(layer_out)  # skip connection to the output\n",
40 | "        network_in = Conv1D(1, 1, activation='linear', use_bias=False)(layer_out)\n",
41 | "        network_out = Add()([residual, network_in])  # residual connection to the next block\n",
42 | "        return network_out, skip_out\n",
43 | "    return f\n",
44 | "\n",
45 | "\n",
46 | "def DC_CNN_Model(length):\n",
47 | "    inputs = Input(shape=(length, 1))\n",
48 | "    l1a, l1b = DC_CNN_Block(32, 2, 1)(inputs)\n",
49 | "    l2a, l2b = DC_CNN_Block(32, 2, 2)(l1a)\n",
50 | "    l3a, l3b = DC_CNN_Block(32, 2, 4)(l2a)\n",
51 | "    l4a, l4b = DC_CNN_Block(32, 2, 8)(l3a)\n",
52 | "    l5a, l5b = DC_CNN_Block(32, 2, 16)(l4a)\n",
53 | "    l6a, l6b = DC_CNN_Block(32, 2, 32)(l5a)\n",
54 | "    l6b = Dropout(0.8)(l6b)\n",
55 | "    l7a, l7b = DC_CNN_Block(32, 2, 64)(l6a)\n",
56 | "    l7b = Dropout(0.8)(l7b)\n",
57 | "    l8 = Add()([l1b, l2b, l3b, l4b, l5b, l6b, l7b])  # sum of all skip outputs\n",
58 | "    l9 = Activation('relu')(l8)\n",
59 | "    l21 = Conv1D(1, 1, activation='linear', use_bias=False)(l9)\n",
60 | "    model = Model(inputs=inputs, outputs=l21)\n",
61 | "    model.compile(loss='mae', optimizer=optimizers.Adam(), metrics=['mse'])\n",
62 | "    return model"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 7,
68 | "metadata": {
69 | "tags": []
70 | },
71 | "outputs": [],
72 | "source": [
73 | "def fit_model(timeseries):\n",
74 | "    length = len(timeseries) - 1\n",
75 | "    model = DC_CNN_Model(length)\n",
76 | "    model.summary()\n",
77 | "    X = timeseries[:-1].reshape(1, length, 1)  # inputs: steps 0..n-2\n",
78 | "    y = timeseries[1:].reshape(1, length, 1)   # targets: steps 1..n-1 (one step ahead)\n",
79 | "    model.fit(X, y, epochs=3000, callbacks=[callback])\n",
80 | "    return model\n",
81 | "\n",
82 | "def forecast(model, timeseries, horizon: int):\n",
83 | "    length = len(timeseries) - 1\n",
84 | "    pred_array = np.zeros(horizon).reshape(1, horizon, 1)\n",
85 | "    X_test_initial = timeseries[1:].reshape(1, length, 1)  # the most recent `length` observations\n",
86 | "    # Recursive forecasting: each prediction is fed back as input for the next step.\n",
87 | "    pred_array[:, 0, :] = model.predict(X_test_initial)[:, -1, :]\n",
88 | "    for i in range(horizon - 1):\n",
89 | "        pred_array[:, i + 1, :] = model.predict(\n",
90 | "            np.concatenate(\n",
91 | "                [X_test_initial[:, i + 1:, :], pred_array[:, :i + 1, :]],\n",
92 | "                axis=1,\n",
93 | "            ))[:, -1, :]\n",
94 | "    return pred_array.flatten()\n",
95 | "\n",
96 | "def evaluate_timeseries(series, horizon: int):\n",
97 | "    model = fit_model(series)\n",
98 | "    pred_array = forecast(model, series, horizon)\n",
99 | "    return pred_array, model"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {
106 | "id": "LCzxf36xxvri",
107 | "tags": []
108 | },
109 | "outputs": [],
110 | "source": [
111 | "import pandas as pd\n",
112 | "import matplotlib.pyplot as plt\n",
113 | "import seaborn as sns\n",
114 | "\n",
115 | "def show_result(y_test, predicted, ylabel=\"Passengers\"):\n",
116 | "    plt.figure(figsize=(16, 6))\n",
117 | "    plt.plot(y_test.index, predicted, 'o-', label=\"predicted\")\n",
118 | "    plt.plot(y_test.index, y_test, '.-', label=\"actual\")\n",
119 | "\n",
120 | "    plt.ylabel(ylabel)\n",
121 | "    plt.legend()"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 9,
127 | "metadata": {
128 | "id": "Tu-831LAylwK",
129 | "tags": []
130 | },
131 | "outputs": [],
132 | "source": [
133 | "import pandas as pd\n",
134 | "\n",
135 | "values = [ \n",
136 | " 112., 118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126.,\n",
137 | " 141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178., 163.,\n",
138 | " 172., 178., 199., 199., 184., 162., 146., 166., 171., 180., 193., 181., 183., 218.,\n",
139 | " 230., 242., 209., 191., 172., 194., 196., 196., 236., 235., 229., 243., 264., 272.,\n",
140 | " 237., 211., 180., 201., 204., 188., 235., 227., 234., 264., 302., 293., 259., 229.,\n",
141 | " 203., 229., 242., 233., 267., 269., 270., 315., 364., 347., 312., 274., 237., 278.,\n",
142 | " 284., 277., 317., 313., 318., 374., 413., 405., 355., 306., 271., 306., 315., 301.,\n",
143 | " 356., 348., 355., 422., 465., 467., 404., 347., 305., 336., 340., 318., 362., 348.,\n",
144 | " 363., 435., 491., 505., 404., 359., 310., 337., 360., 342., 406., 396., 420., 472.,\n",
145 | " 548., 559., 463., 407., 362., 405., 417., 391., 419., 461., 472., 535., 622., 606.,\n",
146 | " 508., 461., 390., 432.,\n",
147 | " ]\n",
148 | "idx = pd.date_range(\"1949-01-01\", periods=len(values), freq=\"M\")\n",
149 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {
156 | "colab": {
157 | "base_uri": "https://localhost:8080/"
158 | },
159 | "id": "boJVi8Jiya4D",
160 | "outputId": "e4605e96-9198-4ab0-a277-489933daaa0b",
161 | "tags": []
162 | },
163 | "outputs": [],
164 | "source": [
165 | "from sklearn.model_selection import train_test_split\n",
166 | "\n",
167 | "\n",
168 | "X_train, X_test, y_train, y_test = train_test_split(\n",
169 | " passengers.passengers, passengers.passengers.shift(-1), shuffle=False\n",
170 | ")\n",
171 | "HORIZON = len(y_test)\n",
172 | "predictions, model = evaluate_timeseries(X_train.values.reshape(-1, 1), horizon=HORIZON)"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 14,
178 | "metadata": {
179 | "colab": {
180 | "base_uri": "https://localhost:8080/",
181 | "height": 379
182 | },
183 | "id": "fxQ1vflLzcmt",
184 | "outputId": "1dfb65e4-fe34-46e0-879a-a8a095493e05",
185 | "tags": []
186 | },
187 | "outputs": [],
188 | "source": [
189 | "show_result(X_test[:HORIZON], predictions[:HORIZON], \"Passengers\")  # forecasts start at the first test month, i.e. X_test[0]"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": 15,
195 | "metadata": {
196 | "id": "yH1hqiySztJ1",
197 | "tags": []
198 | },
199 | "outputs": [],
200 | "source": []
201 | }
202 | ],
203 | "metadata": {
204 | "colab": {
205 | "collapsed_sections": [],
206 | "name": "Causal CNN",
207 | "provenance": []
208 | },
209 | "kernelspec": {
210 | "display_name": "Python 3",
211 | "language": "python",
212 | "name": "python3"
213 | },
214 | "language_info": {
215 | "codemirror_mode": {
216 | "name": "ipython",
217 | "version": 3
218 | },
219 | "file_extension": ".py",
220 | "mimetype": "text/x-python",
221 | "name": "python",
222 | "nbconvert_exporter": "python",
223 | "pygments_lexer": "ipython3",
224 | "version": "3.8.8"
225 | }
226 | },
227 | "nbformat": 4,
228 | "nbformat_minor": 4
229 | }
230 |
--------------------------------------------------------------------------------
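Each `DC_CNN_Block` in chapter10/Causal_CNN.ipynb wraps a `Conv1D` with `padding='causal'` and a growing `dilation_rate`. The NumPy sketch below is an illustration, not code from the notebook. It models a single channel with no bias, and the helper names `receptive_field` and `causal_dilated_conv1d` are hypothetical. It shows what such a layer computes and how far back the stack can look. With kernel size 2 and dilations 1 to 64, the receptive field is 128 steps. That is longer than the 107-step input window the notebook builds from 108 training months, assuming scikit-learn's default 25% test split of the 144 months.

```python
import numpy as np


def receptive_field(kernel_size, dilations):
    """Number of time steps (current one included) that can influence a
    single output of a stack of causal dilated convolutions."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


def causal_dilated_conv1d(x, weights, dilation):
    """Single-channel causal dilated convolution without bias.

    output[t] = sum_j weights[j] * x[t - (k - 1 - j) * dilation], with x
    treated as zero before t = 0, so weights[-1] multiplies x[t] and no
    output ever depends on a future value.
    """
    x = np.asarray(x, dtype=float)
    weights = np.asarray(weights, dtype=float)
    k = len(weights)
    pad = (k - 1) * dilation
    padded = np.concatenate([np.zeros(pad), x])  # left padding only ("causal")
    offsets = np.arange(k - 1, -1, -1) * dilation  # lags for weights[0], ..., weights[-1]
    out = np.empty(len(x))
    for t in range(len(x)):
        # x[t] sits at position t + pad in the padded array.
        out[t] = np.dot(weights, padded[t + pad - offsets])
    return out


# Seven blocks with kernel size 2 and dilations 1, 2, ..., 64, as in DC_CNN_Model.
print(receptive_field(2, [1, 2, 4, 8, 16, 32, 64]))  # 128
```

Changing only the last input value changes only the last output. That is the property `padding='causal'` provides, and it is what lets the notebook train on a whole sequence at once while still predicting one step ahead.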
/chapter10/RNN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "id": "o47CT4PG6dsp"
15 | },
16 | "outputs": [],
17 | "source": [
18 | "import pandas as pd\n",
19 | "\n",
20 | "values = [ \n",
21 | " 112., 118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126.,\n",
22 | " 141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178., 163.,\n",
23 | " 172., 178., 199., 199., 184., 162., 146., 166., 171., 180., 193., 181., 183., 218.,\n",
24 | " 230., 242., 209., 191., 172., 194., 196., 196., 236., 235., 229., 243., 264., 272.,\n",
25 | " 237., 211., 180., 201., 204., 188., 235., 227., 234., 264., 302., 293., 259., 229.,\n",
26 | " 203., 229., 242., 233., 267., 269., 270., 315., 364., 347., 312., 274., 237., 278.,\n",
27 | " 284., 277., 317., 313., 318., 374., 413., 405., 355., 306., 271., 306., 315., 301.,\n",
28 | " 356., 348., 355., 422., 465., 467., 404., 347., 305., 336., 340., 318., 362., 348.,\n",
29 | " 363., 435., 491., 505., 404., 359., 310., 337., 360., 342., 406., 396., 420., 472.,\n",
30 | " 548., 559., 463., 407., 362., 405., 417., 391., 419., 461., 472., 535., 622., 606.,\n",
31 | " 508., 461., 390., 432.,\n",
32 | " ]\n",
33 | "idx = pd.date_range(\"1949-01-01\", periods=len(values), freq=\"M\")\n",
34 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {
41 | "colab": {
42 | "base_uri": "https://localhost:8080/",
43 | "height": 419
44 | },
45 | "id": "BNOhWQC46fJU",
46 | "outputId": "1687dde6-4e92-45f5-f40c-e4668cd7ba32"
47 | },
48 | "outputs": [],
49 | "source": [
50 | "passengers"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "id": "1Fhde-Qg6pPj"
58 | },
59 | "outputs": [],
60 | "source": [
61 | "LOOKBACK = 10\n",
62 | "\n",
63 | "def wrap_data(df, lookback):\n",
64 | "    dataset = []\n",
65 | "    for index in range(lookback, len(df)+1):\n",
66 | "        features = {\n",
67 | "            f\"col_{i}\": float(val) for i, val in enumerate(\n",
68 | "                df.iloc[index-lookback:index].to_numpy().ravel()\n",
69 | "            )\n",
70 | "        }\n",
71 | "        row = pd.DataFrame.from_dict([features])\n",
72 | "        row.index = [df.index[index-1]]  # label each window with its last timestamp\n",
73 | "        dataset.append(row)\n",
74 | "    return pd.concat(dataset, axis=0)\n",
75 | "\n",
76 | "dataset = wrap_data(passengers, lookback=LOOKBACK)\n",
77 | "dataset = dataset.join(passengers.shift(-1))"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {
84 | "id": "xkD2twYx7lxK"
85 | },
86 | "outputs": [],
87 | "source": [
88 | "import tensorflow.keras as keras\n",
89 | "from tensorflow.keras.layers import Input, Bidirectional, LSTM, Dense\n",
90 | "import tensorflow as tf\n",
91 | "\n",
92 | "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n",
93 | "\n",
94 | "def create_model(passengers):\n",
95 | "    input_layer = Input(shape=(LOOKBACK, 1))\n",
96 | "    recurrent = Bidirectional(LSTM(20, activation=\"tanh\"))(input_layer)\n",
97 | "    output_layer = Dense(1)(recurrent)\n",
98 | "    model = keras.models.Model(inputs=input_layer, outputs=output_layer)\n",
99 | "    model.compile(\n",
100 | "        loss='mse', optimizer=keras.optimizers.Adagrad(),\n",
101 | "        metrics=[\n",
102 | "            keras.metrics.RootMeanSquaredError(),\n",
103 | "            keras.metrics.MeanAbsoluteError()\n",
104 | "        ])\n",
105 | "    return model\n",
106 | "\n",
107 | "model = create_model(passengers)"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {
114 | "id": "uv0e7MJNBIrl"
115 | },
116 | "outputs": [],
117 | "source": [
118 | "from sklearn.model_selection import train_test_split\n",
119 | "\n",
120 | "X_train, X_test, y_train, y_test = train_test_split(\n",
121 | " dataset.drop(columns=\"passengers\"),\n",
122 | " dataset[\"passengers\"],\n",
123 | " shuffle=False\n",
124 | ")"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {
131 | "colab": {
132 | "base_uri": "https://localhost:8080/"
133 | },
134 | "id": "bN4gzfAhAK4q",
135 | "outputId": "4c812a58-5a93-421a-fd20-e47bea56d52b"
136 | },
137 | "outputs": [],
138 | "source": [
139 | "model.fit(X_train, y_train, epochs=1000, callbacks=[callback])"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": null,
145 | "metadata": {
146 | "id": "o6lhB_7VApOJ"
147 | },
148 | "outputs": [],
149 | "source": [
150 | "import pandas as pd\n",
151 | "import matplotlib.pyplot as plt\n",
152 | "import seaborn as sns\n",
153 | "\n",
154 | "def show_result(y_test, predicted):\n",
155 | "    plt.figure(figsize=(16, 6))\n",
156 | "    plt.plot(y_test.index, predicted, 'o-', label=\"predicted\")\n",
157 | "    plt.plot(y_test.index, y_test, '.-', label=\"actual\")\n",
158 | "\n",
159 | "    plt.ylabel(\"Passengers\")\n",
160 | "    plt.legend()"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {
167 | "colab": {
168 | "base_uri": "https://localhost:8080/",
169 | "height": 379
170 | },
171 | "id": "9TN2UDG_Bu0c",
172 | "outputId": "4c84996c-8906-4d53-b84b-2204811b017c"
173 | },
174 | "outputs": [],
175 | "source": [
176 | "predicted = model.predict(X_test)\n",
177 | "show_result(y_test, predicted)"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "id": "V668DSiQBzV1"
185 | },
186 | "outputs": [],
187 | "source": []
188 | }
189 | ],
190 | "metadata": {
191 | "colab": {
192 | "name": "RNN",
193 | "provenance": []
194 | },
195 | "kernelspec": {
196 | "display_name": "Python 3",
197 | "language": "python",
198 | "name": "python3"
199 | },
200 | "language_info": {
201 | "codemirror_mode": {
202 | "name": "ipython",
203 | "version": 3
204 | },
205 | "file_extension": ".py",
206 | "mimetype": "text/x-python",
207 | "name": "python",
208 | "nbconvert_exporter": "python",
209 | "pygments_lexer": "ipython3",
210 | "version": "3.8.8"
211 | }
212 | },
213 | "nbformat": 4,
214 | "nbformat_minor": 4
215 | }
216 |
--------------------------------------------------------------------------------
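`wrap_data` in chapter10/RNN.ipynb builds one single-row DataFrame per window and concatenates them. That is easy to follow but slow on long series. Below is a vectorized alternative that produces the same windows and next-step targets as arrays. It assumes NumPy 1.20 or later, which provides `sliding_window_view`. The arrays are already shaped `(samples, LOOKBACK, 1)` for the `LSTM` input. `make_windows` is a hypothetical helper, not part of the notebook.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def make_windows(series, lookback):
    """Return (X, y) for one-step-ahead forecasting.

    X[i] holds series[i : i + lookback] and y[i] is series[i + lookback],
    the value right after the window.
    """
    values = np.asarray(series, dtype=float).ravel()
    windows = sliding_window_view(values, lookback)  # shape (n - lookback + 1, lookback)
    X = windows[:-1]          # the final window has no next value, so drop it
    y = values[lookback:]     # one target per remaining window
    # Add a trailing feature axis: (samples, time steps, features), as the LSTM expects.
    return X[..., np.newaxis], y
```

Unlike `wrap_data(...).join(passengers.shift(-1))`, the last window is dropped instead of being kept with a `NaN` target. The arrays can then be split chronologically and passed to `model.fit`.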
/chapter10/Time_Series_with_Deep_Learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# from https://github.com/FinYang/tsdl/blob/56e091544cb81e573ee6db20c6f9cd39c70e6243/data-raw/boxjenk/seriesg.dat"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {
23 | "id": "jEiBBpYjLCj3"
24 | },
25 | "outputs": [],
26 | "source": [
27 | "values = [ \n",
28 | " 112., 118., 132., 129., 121., 135., 148., 148., 136., 119., 104., 118., 115., 126.,\n",
29 | " 141., 135., 125., 149., 170., 170., 158., 133., 114., 140., 145., 150., 178., 163.,\n",
30 | " 172., 178., 199., 199., 184., 162., 146., 166., 171., 180., 193., 181., 183., 218.,\n",
31 | " 230., 242., 209., 191., 172., 194., 196., 196., 236., 235., 229., 243., 264., 272.,\n",
32 | " 237., 211., 180., 201., 204., 188., 235., 227., 234., 264., 302., 293., 259., 229.,\n",
33 | " 203., 229., 242., 233., 267., 269., 270., 315., 364., 347., 312., 274., 237., 278.,\n",
34 | " 284., 277., 317., 313., 318., 374., 413., 405., 355., 306., 271., 306., 315., 301.,\n",
35 | " 356., 348., 355., 422., 465., 467., 404., 347., 305., 336., 340., 318., 362., 348.,\n",
36 | " 363., 435., 491., 505., 404., 359., 310., 337., 360., 342., 406., 396., 420., 472.,\n",
37 | " 548., 559., 463., 407., 362., 405., 417., 391., 419., 461., 472., 535., 622., 606.,\n",
38 | " 508., 461., 390., 432.,\n",
39 | " ]"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {
46 | "id": "9bBfTTdMK9AE"
47 | },
48 | "outputs": [],
49 | "source": [
50 | "import pandas as pd\n",
51 | "idx = pd.date_range(\"1949-01-01\", periods=len(values), freq=\"M\")"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {
58 | "id": "z3sG4R4VKnWT"
59 | },
60 | "outputs": [],
61 | "source": [
62 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {
69 | "id": "PBlu9I2p8uib"
70 | },
71 | "outputs": [],
72 | "source": [
73 | "from sklearn.model_selection import train_test_split\n",
74 | "\n",
75 | "X_train, X_test, y_train, y_test = train_test_split(passengers, passengers.passengers.shift(-1), shuffle=False)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {
82 | "id": "DkejrcL8aFt1"
83 | },
84 | "outputs": [],
85 | "source": [
86 | "import tensorflow.keras as keras\n",
87 | "import tensorflow as tf\n",
88 | "\n",
89 | "DROPOUT_RATIO = 0.2\n",
90 | "HIDDEN_NEURONS = 10\n",
91 | "\n",
92 | "callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n",
93 | "\n",
94 | "def create_model(passengers):\n",
95 | "    input_layer = keras.layers.Input(shape=(len(passengers.columns),))\n",
96 | "\n",
97 | "    hidden_layer = keras.layers.Dropout(DROPOUT_RATIO)(input_layer)\n",
98 | "    hidden_layer = keras.layers.Dense(HIDDEN_NEURONS, activation='relu')(hidden_layer)\n",
99 | "\n",
100 | "    output_layer = keras.layers.Dropout(DROPOUT_RATIO)(hidden_layer)\n",
101 | "    output_layer = keras.layers.Dense(1)(output_layer)\n",
102 | "\n",
103 | "    model = keras.models.Model(inputs=input_layer, outputs=output_layer)\n",
104 | "\n",
105 | "    model.compile(loss='mse', optimizer=keras.optimizers.Adagrad(),\n",
106 | "                  metrics=[keras.metrics.RootMeanSquaredError(), keras.metrics.MeanAbsoluteError()])\n",
107 | "    return model\n",
108 | "\n",
109 | "model = create_model(passengers)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "colab": {
117 | "base_uri": "https://localhost:8080/"
118 | },
119 | "id": "qJCCbRTib8Up",
120 | "outputId": "63136aae-e096-4ff3-f4fb-2d728054d099"
121 | },
122 | "outputs": [],
123 | "source": [
124 | "model.fit(X_train, y_train, epochs=1000, callbacks=[callback])"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {
131 | "id": "BgHreXsUdc6a"
132 | },
133 | "outputs": [],
134 | "source": [
135 | "predicted = model.predict(X_test)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {
142 | "id": "mJC6mB74dkh9"
143 | },
144 | "outputs": [],
145 | "source": [
146 | "import matplotlib.pyplot as plt\n",
147 | "\n",
148 | "def show_result(y_test, predicted):\n",
149 | "    plt.figure(figsize=(16, 6))\n",
150 | "    plt.plot(y_test.index, predicted, 'o-', label=\"predicted\")\n",
151 | "    plt.plot(y_test.index, y_test, '.-', label=\"actual\")\n",
152 | "\n",
153 | "    plt.ylabel(\"Passengers\")\n",
154 | "    plt.legend()"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "colab": {
162 | "base_uri": "https://localhost:8080/",
163 | "height": 396
164 | },
165 | "id": "4spy29-UdzRS",
166 | "outputId": "86904074-d6ec-48a2-ab70-d5085ff21535"
167 | },
168 | "outputs": [],
169 | "source": [
170 | "show_result(y_test, predicted)"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "metadata": {
177 | "colab": {
178 | "base_uri": "https://localhost:8080/"
179 | },
180 | "id": "Mh4E54N_egps",
181 | "outputId": "3930edda-49da-4ba5-ba5b-2d9c09241a94"
182 | },
183 | "outputs": [],
184 | "source": [
185 | "passengers[\"month\"] = passengers.index.month.values\n",
186 | "passengers[\"year\"] = passengers.index.year.values\n",
187 | "\n",
188 | "model = create_model(passengers)\n",
189 | "X_train, X_test, y_train, y_test = train_test_split(passengers, passengers.passengers.shift(-1), shuffle=False)\n",
190 | "model.fit(X_train, y_train, epochs=100, callbacks=[callback])\n",
191 | "predicted = model.predict(X_test)"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {
198 | "colab": {
199 | "base_uri": "https://localhost:8080/",
200 | "height": 379
201 | },
202 | "id": "Y9_1dwAofTM7",
203 | "outputId": "6f8515de-b3d6-4aee-a65b-cde8dc442631"
204 | },
205 | "outputs": [],
206 | "source": [
207 | "show_result(y_test, predicted)"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "metadata": {
214 | "id": "3eSnjPMYrcbN"
215 | },
216 | "outputs": [],
217 | "source": [
218 | "import tensorflow.keras as keras\n",
219 | "import tensorflow as tf\n",
220 | "\n",
221 | "\n",
222 | "DROPOUT_RATIO = 0.1\n",
223 | "HIDDEN_NEURONS = 5\n",
224 | "\n",
225 | "\n",
226 | "def create_model(passengers):\n",
227 | "    scale = float(passengers.passengers.std())  # plain float, so it adopts the layer's float32 dtype\n",
228 | "\n",
229 | "    continuous_input_layer = keras.layers.Input(shape=(1,))  # current passenger count\n",
230 | "\n",
231 | "    categorical_input_layer = keras.layers.Input(shape=(1,))  # month index, 0-11\n",
232 | "    embedded = keras.layers.Embedding(12, 5)(categorical_input_layer)\n",
233 | "    embedded_flattened = keras.layers.Flatten()(embedded)\n",
234 | "\n",
235 | "    year_input = keras.layers.Input(shape=(1,))  # years since the first year\n",
236 | "    year_layer = keras.layers.Dense(1)(year_input)\n",
237 | "\n",
238 | "    hidden_output = keras.layers.Concatenate(-1)([embedded_flattened, year_layer, continuous_input_layer])\n",
239 | "    output_layer = keras.layers.Dense(1)(hidden_output)\n",
240 | "    output = output_layer * scale + continuous_input_layer  # scaled change added to the current value\n",
241 | "\n",
242 | "    model = keras.models.Model(inputs=[\n",
243 | "        continuous_input_layer, categorical_input_layer, year_input\n",
244 | "    ], outputs=output)\n",
245 | "\n",
246 | "    model.compile(loss='mse', optimizer=keras.optimizers.Adam(),\n",
247 | "                  metrics=[keras.metrics.RootMeanSquaredError(), keras.metrics.MeanAbsoluteError()])\n",
248 | "    return model"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {
255 | "colab": {
256 | "base_uri": "https://localhost:8080/"
257 | },
258 | "id": "zSyGPARMt9zx",
259 | "outputId": "df8b02fb-fc3a-49da-ac2f-d6e951eca334"
260 | },
261 | "outputs": [],
262 | "source": [
263 | "passengers = pd.Series(values, index=idx, name=\"passengers\").to_frame()\n",
264 | "passengers[\"year\"] = passengers.index.year.values - passengers.index.year.values.min()\n",
265 | "passengers[\"month\"] = passengers.index.month.values - 1\n",
266 | "\n",
267 | "X_train, X_test, y_train, y_test = train_test_split(passengers, passengers.passengers.shift(-1), shuffle=False)\n",
268 | "model = create_model(X_train)\n",
269 | "model.fit(\n",
270 | "    [X_train[\"passengers\"], X_train[\"month\"], X_train[\"year\"]],  # same order as the model inputs\n",
271 | "    y_train, epochs=1000,\n",
272 | "    callbacks=[callback]\n",
273 | ")\n",
274 | "predicted = model.predict([X_test[\"passengers\"], X_test[\"month\"], X_test[\"year\"]])"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {
281 | "colab": {
282 | "base_uri": "https://localhost:8080/",
283 | "height": 379
284 | },
285 | "id": "D2fH073LAYYc",
286 | "outputId": "c3bc19a2-7bf6-47dc-f753-75fdfb215aa2"
287 | },
288 | "outputs": [],
289 | "source": [
290 | "show_result(y_test, predicted)"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {
297 | "id": "SMfFWDb8Vs6x"
298 | },
299 | "outputs": [],
300 | "source": []
301 | }
302 | ],
303 | "metadata": {
304 | "colab": {
305 | "collapsed_sections": [],
306 | "name": "Time-Series with Deep Learning",
307 | "provenance": []
308 | },
309 | "kernelspec": {
310 | "display_name": "Python 3",
311 | "language": "python",
312 | "name": "python3"
313 | },
314 | "language_info": {
315 | "codemirror_mode": {
316 | "name": "ipython",
317 | "version": 3
318 | },
319 | "file_extension": ".py",
320 | "mimetype": "text/x-python",
321 | "name": "python",
322 | "nbconvert_exporter": "python",
323 | "pygments_lexer": "ipython3",
324 | "version": "3.8.8"
325 | }
326 | },
327 | "nbformat": 4,
328 | "nbformat_minor": 4
329 | }
330 |
--------------------------------------------------------------------------------
/chapter10/passengers.csv:
--------------------------------------------------------------------------------
1 | date,passengers
2 | 1949-01-31,112.0
3 | 1949-02-28,118.0
4 | 1949-03-31,132.0
5 | 1949-04-30,129.0
6 | 1949-05-31,121.0
7 | 1949-06-30,135.0
8 | 1949-07-31,148.0
9 | 1949-08-31,148.0
10 | 1949-09-30,136.0
11 | 1949-10-31,119.0
12 | 1949-11-30,104.0
13 | 1949-12-31,118.0
14 | 1950-01-31,115.0
15 | 1950-02-28,126.0
16 | 1950-03-31,141.0
17 | 1950-04-30,135.0
18 | 1950-05-31,125.0
19 | 1950-06-30,149.0
20 | 1950-07-31,170.0
21 | 1950-08-31,170.0
22 | 1950-09-30,158.0
23 | 1950-10-31,133.0
24 | 1950-11-30,114.0
25 | 1950-12-31,140.0
26 | 1951-01-31,145.0
27 | 1951-02-28,150.0
28 | 1951-03-31,178.0
29 | 1951-04-30,163.0
30 | 1951-05-31,172.0
31 | 1951-06-30,178.0
32 | 1951-07-31,199.0
33 | 1951-08-31,199.0
34 | 1951-09-30,184.0
35 | 1951-10-31,162.0
36 | 1951-11-30,146.0
37 | 1951-12-31,166.0
38 | 1952-01-31,171.0
39 | 1952-02-29,180.0
40 | 1952-03-31,193.0
41 | 1952-04-30,181.0
42 | 1952-05-31,183.0
43 | 1952-06-30,218.0
44 | 1952-07-31,230.0
45 | 1952-08-31,242.0
46 | 1952-09-30,209.0
47 | 1952-10-31,191.0
48 | 1952-11-30,172.0
49 | 1952-12-31,194.0
50 | 1953-01-31,196.0
51 | 1953-02-28,196.0
52 | 1953-03-31,236.0
53 | 1953-04-30,235.0
54 | 1953-05-31,229.0
55 | 1953-06-30,243.0
56 | 1953-07-31,264.0
57 | 1953-08-31,272.0
58 | 1953-09-30,237.0
59 | 1953-10-31,211.0
60 | 1953-11-30,180.0
61 | 1953-12-31,201.0
62 | 1954-01-31,204.0
63 | 1954-02-28,188.0
64 | 1954-03-31,235.0
65 | 1954-04-30,227.0
66 | 1954-05-31,234.0
67 | 1954-06-30,264.0
68 | 1954-07-31,302.0
69 | 1954-08-31,293.0
70 | 1954-09-30,259.0
71 | 1954-10-31,229.0
72 | 1954-11-30,203.0
73 | 1954-12-31,229.0
74 | 1955-01-31,242.0
75 | 1955-02-28,233.0
76 | 1955-03-31,267.0
77 | 1955-04-30,269.0
78 | 1955-05-31,270.0
79 | 1955-06-30,315.0
80 | 1955-07-31,364.0
81 | 1955-08-31,347.0
82 | 1955-09-30,312.0
83 | 1955-10-31,274.0
84 | 1955-11-30,237.0
85 | 1955-12-31,278.0
86 | 1956-01-31,284.0
87 | 1956-02-29,277.0
88 | 1956-03-31,317.0
89 | 1956-04-30,313.0
90 | 1956-05-31,318.0
91 | 1956-06-30,374.0
92 | 1956-07-31,413.0
93 | 1956-08-31,405.0
94 | 1956-09-30,355.0
95 | 1956-10-31,306.0
96 | 1956-11-30,271.0
97 | 1956-12-31,306.0
98 | 1957-01-31,315.0
99 | 1957-02-28,301.0
100 | 1957-03-31,356.0
101 | 1957-04-30,348.0
102 | 1957-05-31,355.0
103 | 1957-06-30,422.0
104 | 1957-07-31,465.0
105 | 1957-08-31,467.0
106 | 1957-09-30,404.0
107 | 1957-10-31,347.0
108 | 1957-11-30,305.0
109 | 1957-12-31,336.0
110 | 1958-01-31,340.0
111 | 1958-02-28,318.0
112 | 1958-03-31,362.0
113 | 1958-04-30,348.0
114 | 1958-05-31,363.0
115 | 1958-06-30,435.0
116 | 1958-07-31,491.0
117 | 1958-08-31,505.0
118 | 1958-09-30,404.0
119 | 1958-10-31,359.0
120 | 1958-11-30,310.0
121 | 1958-12-31,337.0
122 | 1959-01-31,360.0
123 | 1959-02-28,342.0
124 | 1959-03-31,406.0
125 | 1959-04-30,396.0
126 | 1959-05-31,420.0
127 | 1959-06-30,472.0
128 | 1959-07-31,548.0
129 | 1959-08-31,559.0
130 | 1959-09-30,463.0
131 | 1959-10-31,407.0
132 | 1959-11-30,362.0
133 | 1959-12-31,405.0
134 | 1960-01-31,417.0
135 | 1960-02-29,391.0
136 | 1960-03-31,419.0
137 | 1960-04-30,461.0
138 | 1960-05-31,472.0
139 | 1960-06-30,535.0
140 | 1960-07-31,622.0
141 | 1960-08-31,606.0
142 | 1960-09-30,508.0
143 | 1960-10-31,461.0
144 | 1960-11-30,390.0
145 | 1960-12-31,432.0
146 |
--------------------------------------------------------------------------------
/chapter11/Ranking_with_Bandits.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "708ef2ed-176f-42e3-a47a-bcc324221daa",
6 | "metadata": {},
7 | "source": [
8 | ""
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "55093d0c-238e-4ac3-ad08-2fe7ff52ca13",
14 | "metadata": {},
15 | "source": [
16 | "Adapted from https://github.com/Kenza-AI/mab-ranking/blob/master/examples/jester/example.ipynb"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 3,
22 | "id": "c5fcc1c0-2e0c-48b3-9f10-5263e1f7ae33",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "URL = 'https://raw.githubusercontent.com/PacktPublishing/Machine-Learning-for-Time-Series-with-Python/main/chapter11/jesterfinal151cols.csv'"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 4,
32 | "id": "533e06c6-62dd-48f7-be50-3356666d364b",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "import pandas as pd\n",
37 | "jester_data = pd.read_csv(URL, header=None)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "id": "a9a7e5a7-bdfa-4a11-ad83-f319735b44cc",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "# jester_data.columns = [f\"joke_{col}\" for col in jester_data.columns]\n",
48 | "jester_data.index.name = \"users\""
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "id": "7490c980-6f0f-453f-8b02-8ca5f101aecb",
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "jester_data.head()"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "5990c5cc-66ce-482c-b7f4-1f4911470d6e",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "for col in jester_data.columns:\n",
69 | " jester_data[col] = jester_data[col].apply(lambda x: 0.0 if x>=99 or x<7.0 else 1.0)"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "f4b220de-0657-4e4f-a01f-33c879a20aa5",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "jester_data"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "578d30fa-7e4f-41a9-bbdd-86c1f3898fbe",
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "# keep users who liked (rated >= 7) at least one joke, so the hit-rate denominator is never zero\n",
90 | "jester_data = jester_data[jester_data.sum(axis=1) > 0]"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "id": "9432df2f-4500-46c7-a47c-1a1f0e592b38",
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "jester_data"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "id": "85cf3f6c-c823-4883-b250-8bf7a7c2cf3e",
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "%pip install git+https://github.com/benman1/mab-ranking"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "id": "f5302c81-3478-4ac3-b1f9-96f721f75515",
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "# setting up the bandits:\n",
121 | "from mab_ranking.bandits.rank_bandits import IndependentBandits\n",
122 | "from mab_ranking.bandits.bandits import BetaThompsonSampling, DirichletThompsonSampling\n",
123 | "\n",
124 | "independent_bandits = IndependentBandits(\n",
125 | " num_arms=jester_data.shape[1],\n",
126 | " num_ranks=10,\n",
127 | " bandit_class=DirichletThompsonSampling\n",
128 | ")"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "id": "329ba955-7490-402a-a007-428121d130f4",
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "from tqdm import trange\n",
139 | "\n",
140 | "num_steps = 7000\n",
141 | "hit_rates = []\n",
142 | "for _ in trange(1, num_steps + 1):\n",
143 | " selected_items = set(independent_bandits.choose())\n",
144 | " # pick a random user's choices\n",
145 | " random_user = jester_data.sample().iloc[0, :]\n",
146 | " ground_truth = set(random_user[random_user == 1].index)\n",
147 | " hit_rate = len(ground_truth.intersection(selected_items)) / len(ground_truth)\n",
148 | " feedback_list = [1.0 if item in ground_truth else 0.0 for item in selected_items]\n",
149 | " independent_bandits.update(selected_items, feedback_list)\n",
150 | " hit_rates.append(hit_rate)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "6c1acf3d-b2ca-4f00-9b22-3d703c22713a",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "import matplotlib.pyplot as plt\n",
161 | "\n",
162 | "stats = pd.Series(hit_rates)\n",
163 | "plt.figure(figsize=(12, 6))\n",
164 | "plt.plot(stats.index, stats.rolling(200).mean(), \"--\")\n",
165 | "plt.xlabel('Iteration')\n",
166 | "plt.ylabel('Hit rate')"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "id": "d1535111-fa30-45b1-b723-5bfaa24eba8e",
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "from sklearn.cluster import KMeans\n",
177 | "from sklearn.preprocessing import StandardScaler\n",
178 | "\n",
179 | "scaler = StandardScaler().fit(jester_data)\n",
180 | "kmeans = KMeans(n_clusters=5, random_state=0).fit(scaler.transform(jester_data))\n",
181 | "contexts = pd.Series(kmeans.labels_, index=jester_data.index)"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "ac755bbe-306e-4086-a24f-a8ef55792287",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "contexts.value_counts()"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "id": "128d0fe1-5edb-47e1-aeab-62531a620a56",
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "independent_bandits = IndependentBandits(\n",
202 | " num_arms=jester_data.shape[1],\n",
203 | " num_ranks=10,\n",
204 | " bandit_class=DirichletThompsonSampling\n",
205 | ")\n",
206 | "\n",
207 | "num_steps = 7000\n",
208 | "hit_rates = []\n",
209 | "for _ in trange(1, num_steps + 1):\n",
210 | " # pick a random user's choices\n",
211 | " random_user = jester_data.sample().iloc[0, :]\n",
212 | " context = {\"previous_action\": contexts.loc[random_user.name]}\n",
213 | " selected_items = set(independent_bandits.choose(\n",
214 | " context=context\n",
215 | " ))\n",
216 | " ground_truth = set(random_user[random_user == 1].index)\n",
217 | " hit_rate = len(ground_truth.intersection(selected_items)) / len(ground_truth)\n",
218 | " feedback_list = [1.0 if item in ground_truth else 0.0 for item in selected_items]\n",
219 | " independent_bandits.update(selected_items, feedback_list, context=context)\n",
220 | " hit_rates.append(hit_rate)"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "id": "daefac60-eba6-4429-8a33-067c26483402",
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "import matplotlib.pyplot as plt\n",
231 | "\n",
232 | "stats = pd.Series(hit_rates)\n",
233 | "plt.figure(figsize=(12, 6))\n",
234 | "plt.plot(stats.index, stats.rolling(200).mean(), \"--\")\n",
235 | "plt.xlabel('Iteration')\n",
236 | "plt.ylabel('Hit rate')"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "id": "016e11fb-b2e7-4705-8aae-251468762bc4",
243 | "metadata": {},
244 | "outputs": [],
245 | "source": []
246 | }
247 | ],
248 | "metadata": {
249 | "kernelspec": {
250 | "display_name": "Python 3",
251 | "language": "python",
252 | "name": "python3"
253 | },
254 | "language_info": {
255 | "codemirror_mode": {
256 | "name": "ipython",
257 | "version": 3
258 | },
259 | "file_extension": ".py",
260 | "mimetype": "text/x-python",
261 | "name": "python",
262 | "nbconvert_exporter": "python",
263 | "pygments_lexer": "ipython3",
264 | "version": "3.8.8"
265 | }
266 | },
267 | "nbformat": 4,
268 | "nbformat_minor": 5
269 | }
270 |
--------------------------------------------------------------------------------
/chapter11/Trading_with_DQN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "id": "P6xXgYonWE8n"
15 | },
16 | "outputs": [],
17 | "source": [
18 | "# based on https://github.com/tensortrade-org/tensortrade/blob/master/examples/train_and_evaluate.ipynb"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {
25 | "colab": {
26 | "base_uri": "https://localhost:8080/"
27 | },
28 | "id": "hp1Fk0dGJXpT",
29 | "outputId": "c7215ab3-26d0-45a9-8c47-6973f527ace6"
30 | },
31 | "outputs": [],
32 | "source": [
33 | "%pip install git+https://github.com/tensortrade-org/tensortrade.git"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "metadata": {
40 | "id": "UkrQ9KU3JXpU"
41 | },
42 | "outputs": [],
43 | "source": [
44 | "# all imports:\n",
45 | "import pandas as pd\n",
46 | "import tensortrade.env.default as default\n",
47 | "\n",
48 | "from tensortrade.data.cdd import CryptoDataDownload\n",
49 | "from tensortrade.feed.core import Stream, DataFeed\n",
50 | "from tensortrade.oms.exchanges import Exchange\n",
51 | "from tensortrade.oms.services.execution.simulated import execute_order\n",
52 | "from tensortrade.oms.instruments import USD, BTC, ETH\n",
53 | "from tensortrade.oms.wallets import Wallet, Portfolio\n",
54 | "from tensortrade.agents import DQNAgent\n",
55 | "\n",
56 | "\n",
57 | "%matplotlib inline"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "id": "QpPq3bKwJXpU"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "cdd = CryptoDataDownload()\n",
69 | "\n",
70 | "data = cdd.fetch(\"Bitstamp\", \"USD\", \"BTC\", \"1h\")"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {
77 | "colab": {
78 | "base_uri": "https://localhost:8080/",
79 | "height": 204
80 | },
81 | "id": "yepqF5x7JXpV",
82 | "outputId": "166d22d5-e050-4f44-9b27-96d5c4f7b6e1"
83 | },
84 | "outputs": [],
85 | "source": [
86 | "data.head()"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {
93 | "id": "8PVhnDBvJXpW"
94 | },
95 | "outputs": [],
96 | "source": [
97 | "# we'll create a couple of indicators:\n",
98 | "def rsi(price: Stream[float], period: float) -> Stream[float]:\n",
99 | " r = price.diff()\n",
100 | " upside = r.clamp_min(0).abs()\n",
101 | " downside = r.clamp_max(0).abs()\n",
102 | " rs = upside.ewm(alpha=1 / period).mean() / downside.ewm(alpha=1 / period).mean()\n",
103 | " return 100*(1 - (1 + rs) ** -1)\n",
104 | "\n",
105 | "\n",
106 | "def macd(price: Stream[float], fast: float, slow: float, signal: float) -> Stream[float]:\n",
107 | " fm = price.ewm(span=fast, adjust=False).mean()\n",
108 | " sm = price.ewm(span=slow, adjust=False).mean()\n",
109 | " md = fm - sm\n",
110 | " histogram = md - md.ewm(span=signal, adjust=False).mean()\n",
111 | " return histogram"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "# choosing the closing price:\n",
121 | "features = []\n",
122 | "for c in data.columns[1:]:\n",
123 | " s = Stream.source(list(data[c]), dtype=\"float\").rename(data[c].name)\n",
124 | " features += [s]\n",
125 | "\n",
126 | "cp = Stream.select(features, lambda s: s.name == \"close\")"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {
133 | "id": "hYjM8JoJJXpW"
134 | },
135 | "outputs": [],
136 | "source": [
137 | "# adding the three features (log return, RSI, MACD):\n",
138 | "features = [\n",
139 | " cp.log().diff().rename(\"lr\"),\n",
140 | " rsi(cp, period=20).rename(\"rsi\"),\n",
141 | " macd(cp, fast=10, slow=50, signal=5).rename(\"macd\")\n",
142 | "]\n",
143 | "\n",
144 | "feed = DataFeed(features)\n",
145 | "feed.compile()"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {
152 | "colab": {
153 | "base_uri": "https://localhost:8080/"
154 | },
155 | "id": "uqXBeRwrJXpX",
156 | "outputId": "69b29f71-03e6-4c70-a6e3-3d8aa3b23717"
157 | },
158 | "outputs": [],
159 | "source": [
160 | "for i in range(5):\n",
161 | " print(feed.next())"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "# setting up the exchange and the portfolio:\n",
171 | "bitstamp = Exchange(\"bitstamp\", service=execute_order)(\n",
172 | " Stream.source(list(data[\"close\"]), dtype=\"float\").rename(\"USD-BTC\")\n",
173 | ")\n",
174 | "\n",
175 | "portfolio = Portfolio(USD, [\n",
176 | " Wallet(bitstamp, 10000 * USD),\n",
177 | " Wallet(bitstamp, 10 * BTC)\n",
178 | "])"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "# renderer:\n",
188 | "renderer_feed = DataFeed([\n",
189 | " Stream.source(list(data[\"date\"])).rename(\"date\"),\n",
190 | " Stream.source(list(data[\"open\"]), dtype=\"float\").rename(\"open\"),\n",
191 | " Stream.source(list(data[\"high\"]), dtype=\"float\").rename(\"high\"),\n",
192 | " Stream.source(list(data[\"low\"]), dtype=\"float\").rename(\"low\"),\n",
193 | " Stream.source(list(data[\"close\"]), dtype=\"float\").rename(\"close\"), \n",
194 | " Stream.source(list(data[\"volume\"]), dtype=\"float\").rename(\"volume\") \n",
195 | "])"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "metadata": {
202 | "id": "IzWYGEyZJXpY"
203 | },
204 | "outputs": [],
205 | "source": [
206 | "# the trading environment:\n",
207 | "env = default.create(\n",
208 | " portfolio=portfolio,\n",
209 | " action_scheme=\"managed-risk\",\n",
210 | " reward_scheme=\"risk-adjusted\",\n",
211 | " feed=feed,\n",
212 | " renderer_feed=renderer_feed,\n",
213 | " renderer=default.renderers.PlotlyTradingChart(),\n",
214 | " window_size=20\n",
215 | ")"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": null,
221 | "metadata": {
222 | "colab": {
223 | "base_uri": "https://localhost:8080/"
224 | },
225 | "id": "hYUd95vjJXpY",
226 | "outputId": "45d3f699-abe3-49d4-fd87-d7bba3e953f9"
227 | },
228 | "outputs": [],
229 | "source": [
230 | "env.observer.feed.next()"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "metadata": {
237 | "colab": {
238 | "base_uri": "https://localhost:8080/",
239 | "height": 1000,
240 | "referenced_widgets": [
241 | "9bcf0bbf1c1c49859a2dfa40acce1038"
242 | ]
243 | },
244 | "id": "YoxUET9ZJXpZ",
245 | "outputId": "a7672ac6-9c9c-42ff-b9d3-8a128fd882ae"
246 | },
247 | "outputs": [],
248 | "source": [
249 | "# training a DQN trading agent\n",
250 | "agent = DQNAgent(env)\n",
251 | "\n",
252 | "agent.train(n_steps=200, n_episodes=2, save_path=\"agents/\")"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {
259 | "colab": {
260 | "base_uri": "https://localhost:8080/",
261 | "height": 282
262 | },
263 | "id": "dtoYz5bOKGdo",
264 | "outputId": "fc4372f2-1d3f-4072-e452-2d91d41a2abe"
265 | },
266 | "outputs": [],
267 | "source": [
268 | "%matplotlib inline\n",
269 | "\n",
270 | "performance = pd.DataFrame.from_dict(env.action_scheme.portfolio.performance, orient='index')\n",
271 | "performance.plot()"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "metadata": {
278 | "colab": {
279 | "base_uri": "https://localhost:8080/",
280 | "height": 282
281 | },
282 | "id": "URT6-E9wVtIt",
283 | "outputId": "0b293bc2-46a6-4031-ea0c-9c8415f60d77"
284 | },
285 | "outputs": [],
286 | "source": [
287 | "performance[\"net_worth\"].plot()"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "metadata": {
294 | "id": "GymCxiqfV0VM"
295 | },
296 | "outputs": [],
297 | "source": []
298 | }
299 | ],
300 | "metadata": {
301 | "colab": {
302 | "collapsed_sections": [],
303 | "name": "train_and_evaluate.ipynb",
304 | "provenance": []
305 | },
306 | "kernelspec": {
307 | "display_name": "Python 3",
308 | "language": "python",
309 | "name": "python3"
310 | },
311 | "language_info": {
312 | "codemirror_mode": {
313 | "name": "ipython",
314 | "version": 3
315 | },
316 | "file_extension": ".py",
317 | "mimetype": "text/x-python",
318 | "name": "python",
319 | "nbconvert_exporter": "python",
320 | "pygments_lexer": "ipython3",
321 | "version": "3.8.8"
322 | }
323 | },
324 | "nbformat": 4,
325 | "nbformat_minor": 4
326 | }
327 |
--------------------------------------------------------------------------------
/chapter2/README.md:
--------------------------------------------------------------------------------
1 | # Chapter 2 - Time-Series Analysis with Python
2 |
3 | The CSV file `spm.csv` was downloaded from the [Our World in Data GitHub repository](https://github.com/owid/owid-datasets/blob/master/datasets) (Air pollution by city - Fouquet and DPCC).
4 |
5 | The global temperatures were downloaded from [Datahub](https://datahub.io/core/global-temp).
6 |
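The `spm.csv` file described above is in long format: one row per city and year, with a Smoke column and a Suspended Particulate Matter (SPM) column, where Smoke is empty for Delhi and for London after 2000. The following is a minimal sketch, not code from the chapter's notebooks, of one way to load it. It reads a few rows copied verbatim from the file through `io.StringIO` so it runs on its own, renames the long column headers to short names chosen here for illustration (`city`, `year`, `smoke`, `spm`), and pivots the SPM values into one column per city. In practice you would pass the path to `spm.csv` to `pd.read_csv` instead.

```python
import io

import pandas as pd

# A few rows copied from chapter2/spm.csv so the sketch is self-contained;
# replace io.StringIO(SAMPLE) with the path to spm.csv for the full data.
SAMPLE = """Entity,Year,Smoke (Fouquet and DPCC (2011)),Suspended Particulate Matter (SPM) (Fouquet and DPCC (2011))
Delhi,1997,,363
Delhi,1998,,378
London,1997,13.75,25
London,1998,13.75,25
London,2001,,23.22
"""


def load_spm(source) -> pd.DataFrame:
    """Read the long-format SPM table and return one SPM column per city."""
    df = pd.read_csv(source)
    # Short illustrative names instead of the verbose source headers.
    df.columns = ["city", "year", "smoke", "spm"]
    # Long -> wide: rows indexed by year, one column per city.
    return df.pivot(index="year", columns="city", values="spm")


spm = load_spm(io.StringIO(SAMPLE))
print(spm)
```

Pivoting to a wide table puts both cities on a shared year axis, which is convenient for plotting; years a city does not cover show up as `NaN` instead of being silently dropped.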
--------------------------------------------------------------------------------
/chapter2/spm.csv:
--------------------------------------------------------------------------------
1 | Entity,Year,Smoke (Fouquet and DPCC (2011)),Suspended Particulate Matter (SPM) (Fouquet and DPCC (2011))
2 | Delhi,1997,,363
3 | Delhi,1998,,378
4 | Delhi,1999,,375
5 | Delhi,2000,,431
6 | Delhi,2001,,382
7 | Delhi,2002,,456
8 | Delhi,2003,,391
9 | Delhi,2004,,390
10 | Delhi,2005,,373
11 | Delhi,2006,,433
12 | Delhi,2007,,365
13 | Delhi,2008,,416
14 | Delhi,2009,,492
15 | Delhi,2010,,481
16 | London,1700,142.8571429,259.7402597
17 | London,1701,144.2857143,262.3376623
18 | London,1702,145.7142857,264.9350649
19 | London,1703,147.1428571,267.5324675
20 | London,1704,148.5714286,270.1298701
21 | London,1705,150,272.7272727
22 | London,1706,151.4285714,275.3246753
23 | London,1707,152.8571429,277.9220779
24 | London,1708,154.2857143,280.5194805
25 | London,1709,155.7142857,283.1168831
26 | London,1710,157.1428571,285.7142857
27 | London,1711,158.5714286,288.3116883
28 | London,1712,160,290.9090909
29 | London,1713,161.4285714,293.5064935
30 | London,1714,162.8571429,296.1038961
31 | London,1715,164.2857143,298.7012987
32 | London,1716,165.7142857,301.2987013
33 | London,1717,167.1428571,303.8961039
34 | London,1718,168.5714286,306.4935065
35 | London,1719,170,309.0909091
36 | London,1720,171.4285714,311.6883117
37 | London,1721,172.8571429,314.2857143
38 | London,1722,174.2857143,316.8831169
39 | London,1723,175.7142857,319.4805195
40 | London,1724,177.1428571,322.0779221
41 | London,1725,178.5714286,324.6753247
42 | London,1726,179.2857143,325.974026
43 | London,1727,180,327.2727273
44 | London,1728,180.7142857,328.5714286
45 | London,1729,181.4285714,329.8701299
46 | London,1730,182.1428571,331.1688312
47 | London,1731,182.8571429,332.4675325
48 | London,1732,183.5714286,333.7662338
49 | London,1733,184.2857143,335.0649351
50 | London,1734,185,336.3636364
51 | London,1735,185.7142857,337.6623377
52 | London,1736,186.4285714,338.961039
53 | London,1737,187.1428571,340.2597403
54 | London,1738,187.8571429,341.5584416
55 | London,1739,188.5714286,342.8571429
56 | London,1740,189.2857143,344.1558442
57 | London,1741,190,345.4545455
58 | London,1742,190.7142857,346.7532468
59 | London,1743,191.4285714,348.0519481
60 | London,1744,192.1428571,349.3506494
61 | London,1745,192.8571429,350.6493506
62 | London,1746,193.5714286,351.9480519
63 | London,1747,194.2857143,353.2467532
64 | London,1748,195,354.5454545
65 | London,1749,195.7142857,355.8441558
66 | London,1750,196.4285714,357.1428571
67 | London,1751,196.7857143,357.7922078
68 | London,1752,197.1428571,358.4415584
69 | London,1753,197.5,359.0909091
70 | London,1754,197.8571429,359.7402597
71 | London,1755,198.2142857,360.3896104
72 | London,1756,198.5714286,361.038961
73 | London,1757,198.9285714,361.6883117
74 | London,1758,199.2857143,362.3376623
75 | London,1759,199.6428571,362.987013
76 | London,1760,200,363.6363636
77 | London,1761,200.3571429,364.2857143
78 | London,1762,200.7142857,364.9350649
79 | London,1763,201.0714286,365.5844156
80 | London,1764,201.4285714,366.2337662
81 | London,1765,201.7857143,366.8831169
82 | London,1766,202.1428571,367.5324675
83 | London,1767,202.5,368.1818182
84 | London,1768,202.8571429,368.8311688
85 | London,1769,203.2142857,369.4805195
86 | London,1770,203.5714286,370.1298701
87 | London,1771,203.9285714,370.7792208
88 | London,1772,204.2857143,371.4285714
89 | London,1773,204.6428571,372.0779221
90 | London,1774,205,372.7272727
91 | London,1775,205.3571429,373.3766234
92 | London,1776,205.7142857,374.025974
93 | London,1777,206.0714286,374.6753247
94 | London,1778,206.4285714,375.3246753
95 | London,1779,206.7857143,375.974026
96 | London,1780,207.1428571,376.6233766
97 | London,1781,207.5,377.2727273
98 | London,1782,207.8571429,377.9220779
99 | London,1783,208.2142857,378.5714286
100 | London,1784,208.5714286,379.2207792
101 | London,1785,208.9285714,379.8701299
102 | London,1786,209.2857143,380.5194805
103 | London,1787,209.6428571,381.1688312
104 | London,1788,210,381.8181818
105 | London,1789,210.3571429,382.4675325
106 | London,1790,210.7142857,383.1168831
107 | London,1791,211.0714286,383.7662338
108 | London,1792,211.4285714,384.4155844
109 | London,1793,211.7857143,385.0649351
110 | London,1794,212.1428571,385.7142857
111 | London,1795,212.5,386.3636364
112 | London,1796,212.8571429,387.012987
113 | London,1797,213.2142857,387.6623377
114 | London,1798,213.5714286,388.3116883
115 | London,1799,213.9285714,388.961039
116 | London,1800,214.2857143,389.6103896
117 | London,1801,216.1428571,392.987013
118 | London,1802,218,396.3636364
119 | London,1803,219.8571429,399.7402597
120 | London,1804,221.7142857,403.1168831
121 | London,1805,223.5714286,406.4935065
122 | London,1806,225.4285714,409.8701299
123 | London,1807,227.2857143,413.2467532
124 | London,1808,229.1428571,416.6233766
125 | London,1809,231,420
126 | London,1810,232.8571429,423.3766234
127 | London,1811,234.7142857,426.7532468
128 | London,1812,236.5714286,430.1298701
129 | London,1813,238.4285714,433.5064935
130 | London,1814,240.2857143,436.8831169
131 | London,1815,242.1428571,440.2597403
132 | London,1816,244,443.6363636
133 | London,1817,245.8571429,447.012987
134 | London,1818,247.7142857,450.3896104
135 | London,1819,249.5714286,453.7662338
136 | London,1820,251.4285714,457.1428571
137 | London,1821,253.2857143,460.5194805
138 | London,1822,255.1428571,463.8961039
139 | London,1823,257,467.2727273
140 | London,1824,258.8571429,470.6493506
141 | London,1825,260.7142857,474.025974
142 | London,1826,262.5714286,477.4025974
143 | London,1827,264.4285714,480.7792208
144 | London,1828,266.2857143,484.1558442
145 | London,1829,268.1428571,487.5324675
146 | London,1830,270,490.9090909
147 | London,1831,271.8571429,494.2857143
148 | London,1832,273.7142857,497.6623377
149 | London,1833,275.5714286,501.038961
150 | London,1834,277.4285714,504.4155844
151 | London,1835,279.2857143,507.7922078
152 | London,1836,281.1428571,511.1688312
153 | London,1837,283,514.5454545
154 | London,1838,284.8571429,517.9220779
155 | London,1839,286.7142857,521.2987013
156 | London,1840,288.5714286,524.6753247
157 | London,1841,290.4285714,528.0519481
158 | London,1842,292.2857143,531.4285714
159 | London,1843,294.1428571,534.8051948
160 | London,1844,296,538.1818182
161 | London,1845,297.8571429,541.5584416
162 | London,1846,299.7142857,544.9350649
163 | London,1847,301.5714286,548.3116883
164 | London,1848,303.4285714,551.6883117
165 | London,1849,305.2857143,555.0649351
166 | London,1850,307.1428571,558.4415584
167 | London,1851,307.7142857,559.4805195
168 | London,1852,308.2857143,560.5194805
169 | London,1853,308.8571429,561.5584416
170 | London,1854,309.4285714,562.5974026
171 | London,1855,310,563.6363636
172 | London,1856,310.5714286,564.6753247
173 | London,1857,311.1428571,565.7142857
174 | London,1858,311.7142857,566.7532468
175 | London,1859,312.2857143,567.7922078
176 | London,1860,312.8571429,568.8311688
177 | London,1861,313.4285714,569.8701299
178 | London,1862,314,570.9090909
179 | London,1863,314.5714286,571.9480519
180 | London,1864,315.1428571,572.987013
181 | London,1865,315.7142857,574.025974
182 | London,1866,316.2857143,575.0649351
183 | London,1867,316.8571429,576.1038961
184 | London,1868,317.4285714,577.1428571
185 | London,1869,318,578.1818182
186 | London,1870,318.5714286,579.2207792
187 | London,1871,319.1428571,580.2597403
188 | London,1872,319.7142857,581.2987013
189 | London,1873,320.2857143,582.3376623
190 | London,1874,320.8571429,583.3766234
191 | London,1875,321.4285714,584.4155844
192 | London,1876,322.7678571,586.8506494
193 | London,1877,324.1071429,589.2857143
194 | London,1878,325.4464286,591.7207792
195 | London,1879,326.7857143,594.1558442
196 | London,1880,328.125,596.5909091
197 | London,1881,329.4642857,599.025974
198 | London,1882,330.8035714,601.461039
199 | London,1883,332.1428571,603.8961039
200 | London,1884,333.4821429,606.3311688
201 | London,1885,334.8214286,608.7662338
202 | London,1886,336.1607143,611.2012987
203 | London,1887,337.5,613.6363636
204 | London,1888,338.8392857,616.0714286
205 | London,1889,340.1785714,618.5064935
206 | London,1890,341.5178571,620.9415584
207 | London,1891,342.8571429,623.3766234
208 | London,1892,339.6825397,617.6046176
209 | London,1893,336.5079365,611.8326118
210 | London,1894,333.3333333,606.0606061
211 | London,1895,330.1587302,600.2886003
212 | London,1896,326.984127,594.5165945
213 | London,1897,323.8095238,588.7445887
214 | London,1898,320.6349206,582.972583
215 | London,1899,317.4603175,577.2005772
216 | London,1900,314.2857143,571.4285714
217 | London,1901,314.1142857,571.1168831
218 | London,1902,313.9428571,570.8051948
219 | London,1903,313.7714286,570.4935065
220 | London,1904,313.6,570.1818182
221 | London,1905,313.4285714,569.8701299
222 | London,1906,313.2571429,569.5584416
223 | London,1907,313.0857143,569.2467532
224 | London,1908,312.9142857,568.9350649
225 | London,1909,312.7428571,568.6233766
226 | London,1910,312.5714286,568.3116883
227 | London,1911,312.4,568
228 | London,1912,312.2285714,567.6883117
229 | London,1913,310,563.6363636
230 | London,1914,305.4545455,555.3719008
231 | London,1915,300.9090909,547.107438
232 | London,1916,296.3636364,538.8429752
233 | London,1917,291.8181818,530.5785124
234 | London,1918,287.2727273,522.3140496
235 | London,1919,282.7272727,514.0495868
236 | London,1920,278.1818182,505.785124
237 | London,1921,273.6363636,497.5206612
238 | London,1922,269.0909091,489.2561983
239 | London,1923,264.5454545,480.9917355
240 | London,1924,260,472.7272727
241 | London,1925,225,409.0909091
242 | London,1926,200,363.6363636
243 | London,1927,198,360
244 | London,1928,197,358.1818182
245 | London,1929,198,360
246 | London,1930,199,361.8181818
247 | London,1931,200,363.6363636
248 | London,1932,210,381.8181818
249 | London,1933,220,400
250 | London,1934,222,403.6363636
251 | London,1935,225,409.0909091
252 | London,1936,223,405.4545455
253 | London,1937,221,401.8181818
254 | London,1938,220,400
255 | London,1939,211.25,384.0909091
256 | London,1940,202.5,368.1818182
257 | London,1941,193.75,352.2727273
258 | London,1942,185,336.3636364
259 | London,1943,176.25,320.4545455
260 | London,1944,167.5,304.5454545
261 | London,1945,158.75,288.6363636
262 | London,1946,150,272.7272727
263 | London,1947,141.25,256.8181818
264 | London,1948,132.5,240.9090909
265 | London,1949,123.75,225
266 | London,1950,115,209.0909091
267 | London,1951,112.6709302,204.8562368
268 | London,1952,110.3418605,200.6215645
269 | London,1953,108.0127907,196.3868922
270 | London,1954,105.6837209,192.1522199
271 | London,1955,103.3546512,187.9175476
272 | London,1956,101.0255814,183.6828753
273 | London,1957,98.69651163,179.448203
274 | London,1958,96.36744186,175.2135307
275 | London,1959,94.03837209,170.9788584
276 | London,1960,91.70930233,166.744186
277 | London,1961,89.38023256,162.5095137
278 | London,1962,87.05116279,158.2748414
279 | London,1963,84.72209302,154.0401691
280 | London,1964,82.39302326,149.8054968
281 | London,1965,80.06395349,145.5708245
282 | London,1966,77.73488372,141.3361522
283 | London,1967,75.40581395,137.1014799
284 | London,1968,73.07674419,132.8668076
285 | London,1969,70.74767442,128.6321353
286 | London,1970,68.41860465,124.397463
287 | London,1971,66.08953488,120.1627907
288 | London,1972,63.76046512,115.9281184
289 | London,1973,61.43139535,111.6934461
290 | London,1974,59.10232558,107.4587738
291 | London,1975,56.77325581,103.2241015
292 | London,1976,54.44418605,98.98942918
293 | London,1977,52.11511628,94.75475687
294 | London,1978,49.78604651,90.52008457
295 | London,1979,47.45697674,86.28541226
296 | London,1980,45.12790698,82.05073996
297 | London,1981,42.79883721,77.81606765
298 | London,1982,40.46976744,73.58139535
299 | London,1983,38.14069767,69.34672304
300 | London,1984,35.81162791,65.11205074
301 | London,1985,33.48255814,60.87737844
302 | London,1986,31.15348837,56.64270613
303 | London,1987,28.8244186,52.40803383
304 | London,1988,26.49534884,48.17336152
305 | London,1989,24.16627907,43.93868922
306 | London,1990,21.8372093,39.70401691
307 | London,1991,19.50813953,35.46934461
308 | London,1992,17.17906977,31.2346723
309 | London,1993,14.85,27
310 | London,1994,14.3,26
311 | London,1995,14.3,26
312 | London,1996,14.3,26
313 | London,1997,13.75,25
314 | London,1998,13.75,25
315 | London,1999,13.75,25
316 | London,2000,13.75,25
317 | London,2001,,23.22
318 | London,2002,,23
319 | London,2003,,24.79
320 | London,2004,,22.03
321 | London,2005,,22.35
322 | London,2006,,23.82
323 | London,2007,,21.56
324 | London,2008,,18.81
325 | London,2009,,19
326 | London,2010,,20
327 | London,2011,,20
328 | London,2012,,17
329 | London,2013,,17
330 | London,2014,,17
331 | London,2015,,15
332 | London,2016,,16
333 |
--------------------------------------------------------------------------------
/chapter3/Preprocessing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "6a06a10a-ee82-40ce-8821-3da680af4b95",
6 | "metadata": {},
7 | "source": [
8 | ""
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "614a3348-214c-4472-969d-e7db10664075",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "%pip install -U tsfresh workalendar astral \"featuretools[tsfresh]\" sktime"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "9a82dfb1-f90b-4f38-91fb-34756dece1e5",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import datetime\n",
29 | "import pandas as pd\n",
30 | "import matplotlib.pyplot as plt\n",
31 | "import seaborn as sns"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "id": "96db671d-bfdb-48f2-b572-f1d35e2b14a1",
37 | "metadata": {},
38 | "source": [
39 | "# Transformations"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "0bd9698d-3905-42a6-92b4-1d3d6c80f79f",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "import numpy as np\n",
50 | "\n",
51 | "np.random.seed(0)\n",
52 | "pts = 10000\n",
53 | "vals = np.random.lognormal(0, 1.0, pts)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "id": "55097e4a-c0d8-4cd9-a7e5-5e59738c0475",
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "plt.hist(vals, bins=20, density=True)\n",
64 | "plt.yscale(\"log\")\n",
65 | "plt.ylabel(\"frequency\")\n",
66 | "plt.xlabel(\"value range\");"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "id": "1a0306b4-c03d-46e4-9ed1-1c8e30c3f7bb",
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "from sklearn.preprocessing import minmax_scale\n",
77 | "from sklearn.preprocessing import StandardScaler\n",
78 | "from scipy.stats import normaltest\n",
79 | "\n",
80 | "vals_mm = minmax_scale(vals)\n",
81 | "scaler = StandardScaler()\n",
82 | "vals_ss = scaler.fit_transform(vals.reshape(-1, 1))\n",
83 | "_, p = normaltest(vals_ss.squeeze())\n",
84 | "print(f\"significance: {p:.2f}\")"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "id": "f32df59b-5443-48b2-b2ef-07a4f79bb869",
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "_, p = normaltest(vals_mm.squeeze())\n",
95 | "print(f\"significance: {p:.2f}\")"
96 | ]
97 | },
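Both scaled versions fail the normality test just like the raw data. The reason is that min-max and standard scaling are affine maps (x -> a*x + b with a > 0): they shift and stretch the values but keep the shape of the distribution. The sketch below checks this with plain NumPy instead of scikit-learn. It draws its own log-normal sample with `np.random.default_rng(0)`, so the numbers differ from the notebook's, and `skewness` is a hypothetical helper written for illustration.

```python
import numpy as np

def skewness(x: np.ndarray) -> float:
    """Sample skewness (third standardized moment); hypothetical helper."""
    centred = x - x.mean()
    return float(np.mean(centred ** 3) / np.mean(centred ** 2) ** 1.5)

rng = np.random.default_rng(0)
vals = rng.lognormal(mean=0.0, sigma=1.0, size=10_000)

# Min-max scaling: (x - min) / (max - min), an affine map onto [0, 1].
vals_mm = (vals - vals.min()) / (vals.max() - vals.min())
# Standard scaling: (x - mean) / std, an affine map to zero mean and unit variance.
vals_ss = (vals - vals.mean()) / vals.std()

# The shape statistic is identical for all three arrays.
for name, x in [("original", vals), ("min-max", vals_mm), ("standard", vals_ss)]:
    print(f"{name:>8}: skewness = {skewness(x):.3f}")
```

Only a non-linear transform, such as the log or Box-Cox used in the following cells, changes the shape of the distribution.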
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "9fc25e7e-8bca-44ad-8d14-100cd62ee1ec",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "plt.scatter(vals, vals_ss, alpha=0.3)\n",
106 | "plt.ylabel(\"standard scaled\")\n",
107 | "plt.xlabel(\"original\");"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "e88633cd-5542-4366-84bd-ef59e46f253a",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "from statsmodels.stats.diagnostic import kstest_normal\n",
118 | "\n",
119 | "log_transformed = np.log(vals)\n",
120 | "_, p = kstest_normal(log_transformed)  # Lilliefors (KS) test; scipy.stats.normaltest works too\n",
121 | "print(f\"significance: {p:.2f}\")"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "id": "8ed96d42-2023-463d-82cf-250b5bb69733",
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "np.std(log_transformed)"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "id": "0f9faf67-1496-4383-a93e-abfb67e4dc9c",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "np.mean(log_transformed)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "id": "ae3a9a4b-f856-468e-8f7a-07178544f52d",
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "plt.hist(log_transformed, bins=20, density=True)\n",
152 | "#plt.yscale(\"log\")\n",
153 | "plt.ylabel(\"frequency\")\n",
154 | "plt.xlabel(\"value range\");"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "id": "89f9f296-456d-4789-88c0-2e2eb35bc8a8",
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "from scipy.stats import boxcox\n",
165 | "vals_bc = boxcox(vals, 0.0)\n",
166 | "_, p = normaltest(vals_bc)\n",
167 | "print(f\"significance: {p:.2f}\")"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "id": "40697ad8-fbc7-4c62-ab4a-64f813c13ec6",
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "plt.hist(vals_bc, bins=20, density=True)\n",
178 | "plt.ylabel(\"frequency\")\n",
179 | "plt.xlabel(\"value range\");"
180 | ]
181 | },
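The cell above fixes lambda = 0, which makes Box-Cox identical to the natural log. When `lmbda` is omitted, `scipy.stats.boxcox` instead estimates lambda by maximum likelihood. The NumPy sketch below illustrates that idea with a simple grid search over the Box-Cox profile log-likelihood. It is not SciPy's implementation (SciPy uses a numerical optimizer), and `boxcox_transform`, `boxcox_llf`, and the grid bounds are assumptions made for illustration. For log-normal data the estimate should land close to 0.

```python
import numpy as np

def boxcox_transform(x: np.ndarray, lmbda: float) -> np.ndarray:
    """Box-Cox transform of strictly positive data; lambda = 0 is the natural log."""
    if lmbda == 0:
        return np.log(x)
    return (x ** lmbda - 1) / lmbda

def boxcox_llf(x: np.ndarray, lmbda: float) -> float:
    """Profile log-likelihood of lambda: (lambda - 1) * sum(log x) - n/2 * log(var)."""
    y = boxcox_transform(x, lmbda)
    return (lmbda - 1) * np.sum(np.log(x)) - len(x) / 2 * np.log(np.var(y))

rng = np.random.default_rng(0)
vals = rng.lognormal(mean=0.0, sigma=1.0, size=10_000)

# Grid search for the maximum-likelihood lambda (hypothetical grid: -2 to 2, step 0.01).
grid = np.round(np.arange(-2.0, 2.0 + 1e-9, 0.01), 2)
llfs = [boxcox_llf(vals, lam) for lam in grid]
best_lambda = float(grid[int(np.argmax(llfs))])
print(f"estimated lambda: {best_lambda:.2f}")
```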
182 | {
183 | "cell_type": "markdown",
184 | "id": "da2aa62a-fe40-49e2-b2d9-ebcb1a91cda8",
185 | "metadata": {},
186 | "source": [
187 | "# Imputation"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "id": "40bddaee-38d3-4d40-9773-e6a4721e4b62",
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "import numpy as np\n",
198 | "from sklearn.impute import SimpleImputer\n",
199 | "imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')\n",
200 | "imp_mean.fit([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]])\n",
201 | "# fit() learned the column means [7, 3.5, 6]; transform() fills NaNs with them\n",
202 | "df = [[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]]\n",
203 | "print(imp_mean.transform(df))\n"
204 | ]
205 | },
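The printed result can be reproduced by hand: `fit` stores each column's mean over the non-missing training values (7, 3.5, and 6 here), and `transform` substitutes those means for NaNs in new data. The NumPy sketch below does exactly that, then shows forward filling, a common alternative for time series because each gap is filled only from earlier observations. The `ffill_1d` helper is hypothetical, written for this example.

```python
import numpy as np

train = np.array([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]], dtype=float)
test = np.array([[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]], dtype=float)

# "fit": per-column means of the training data, ignoring NaNs -> [7.0, 3.5, 6.0]
col_means = np.nanmean(train, axis=0)
# "transform": replace every NaN with the training mean of its column
imputed = np.where(np.isnan(test), col_means, test)
print(imputed)

def ffill_1d(x: np.ndarray) -> np.ndarray:
    """Forward fill a 1-D array; leading NaNs stay NaN (hypothetical helper)."""
    positions = np.where(np.isnan(x), 0, np.arange(len(x)))
    np.maximum.accumulate(positions, out=positions)  # index of the last valid value
    return x[positions]

series = np.array([1.0, np.nan, 3.0, np.nan, np.nan])
print(ffill_1d(series))  # [1. 1. 3. 3. 3.]
```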
206 | {
207 | "cell_type": "markdown",
208 | "id": "503b0b9a-71a9-475c-9906-e59fb9853932",
209 | "metadata": {},
210 | "source": [
211 | "# Derived Date Features"
212 | ]
213 | },
214 | {
215 | "cell_type": "markdown",
216 | "id": "5aa0e489-638d-46e7-8850-518c616e22e5",
217 | "metadata": {},
218 | "source": [
219 | "## Holidays"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "id": "9efee10c-2d3d-4265-886a-50e87e380a97",
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "from workalendar.europe.united_kingdom import UnitedKingdom\n",
230 | "UnitedKingdom().holidays()"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "id": "1539b652-a1f8-41a6-8410-cccad3d5edb6",
237 | "metadata": {},
238 | "outputs": [],
239 | "source": [
240 | "from typing import Dict\n",
241 | "from dateutil.relativedelta import relativedelta, TH\n",
242 | "import datetime\n",
243 | "from workalendar.usa import California\n",
244 | "\n",
245 | "def create_custom_holidays(year: int) -> Dict[datetime.date, str]:\n",
246 | " custom_holidays = California().holidays(year)\n",
247 | " custom_holidays.append((\n",
248 | " (datetime.datetime(year, 11, 1) + relativedelta(weekday=TH(+4)) + datetime.timedelta(days=1)).date(),\n",
249 | " \"Black Friday\"\n",
250 | " ))\n",
251 | " return {k: v for (k, v) in custom_holidays}\n",
252 | "\n",
253 | "custom_holidays = create_custom_holidays(2021)"
254 | ]
255 | },
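The `relativedelta(weekday=TH(+4))` expression finds the fourth Thursday on or after 1 November, which is US Thanksgiving; one day later is Black Friday. As a cross-check, the same date can be computed with the standard library alone. `black_friday` is a hypothetical helper, and the dates in the tests come from applying the fourth-Thursday rule by hand.

```python
import datetime

def black_friday(year: int) -> datetime.date:
    """Day after the fourth Thursday of November (hypothetical helper)."""
    nov_first = datetime.date(year, 11, 1)
    # weekday(): Monday == 0 ... Thursday == 3
    days_to_thursday = (3 - nov_first.weekday()) % 7
    first_thursday = nov_first + datetime.timedelta(days=days_to_thursday)
    thanksgiving = first_thursday + datetime.timedelta(weeks=3)
    return thanksgiving + datetime.timedelta(days=1)

for year in (2018, 2019, 2020, 2021):
    print(year, black_friday(year))
```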
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "id": "bec50434-e7b4-409b-8f37-9b400ce0a2bf",
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "custom_holidays"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "id": "3e8c9a76-0ec7-40c8-8eda-e7e480281501",
270 | "metadata": {},
271 | "outputs": [],
272 | "source": [
273 | "def is_holiday(current_date: datetime.date):\n",
274 | " \"\"\"Determine if we have a holiday.\"\"\"\n",
275 | " return custom_holidays.get(current_date, False)\n",
276 | "\n",
277 | "today = datetime.date(2021, 4, 11)\n",
278 | "is_holiday(today)"
279 | ]
280 | },
281 | {
282 | "cell_type": "markdown",
283 | "id": "ddb63805-b782-4944-a88b-4996783177f7",
284 | "metadata": {},
285 | "source": [
286 | "## Date Annotations"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "id": "eab731b5-fa14-4c8a-bd39-0a805c36aa80",
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "import calendar\n",
297 | "\n",
298 | "calendar.monthrange(2021, 1)"
299 | ]
300 | },
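`calendar.monthrange(year, month)` returns a pair: the weekday of the month's first day (Monday is 0) and the number of days in the month. For January 2021 that is `(4, 31)`, since 1 January 2021 was a Friday. Date anchors such as "days until the end of the month" need the second element; `days_to_month_end` below is a hypothetical helper that shows this.

```python
import calendar
import datetime

first_weekday, n_days = calendar.monthrange(2021, 1)
print(calendar.day_name[first_weekday], n_days)  # Friday 31

def days_to_month_end(current_date: datetime.date) -> int:
    """Days remaining until the last day of the month (hypothetical helper)."""
    n_days = calendar.monthrange(current_date.year, current_date.month)[1]
    return n_days - current_date.day

print(days_to_month_end(datetime.date(2021, 4, 11)))  # 19
```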
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "id": "f9d4a8b3-fea7-4e75-86e4-fff5cc3ffa9e",
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "from datetime import date\n",
309 | "def year_anchor(current_date: datetime.date):\n",
310 | " return (\n",
311 | " (current_date - date(current_date.year, 1, 1)).days,\n",
312 | " (date(current_date.year, 12, 31) - current_date).days,\n",
313 | " )\n",
314 | "\n",
315 | "year_anchor(today)\n"
316 | ]
317 | },
318 | {
319 | "cell_type": "code",
320 | "execution_count": null,
321 | "id": "d6a4e6d4-3a68-48a5-b564-29e9f7b0b72d",
322 | "metadata": {},
323 | "outputs": [],
324 | "source": [
325 | "def month_anchor(current_date: datetime.date):\n",
326 | " last_day = calendar.monthrange(current_date.year, current_date.month)[1]  # [1] = number of days\n",
327 | " \n",
328 | " return (\n",
329 | " (current_date - datetime.date(current_date.year, current_date.month, 1)).days,\n",
330 | " (datetime.date(current_date.year, current_date.month, last_day) - current_date).days,\n",
331 | " )\n",
332 | "\n",
333 | "month_anchor(today)\n"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "id": "5453766c-e8d5-4fd8-88a8-4b87602c506f",
339 | "metadata": {},
340 | "source": [
341 | "## Paydays"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "id": "5fa16210-e016-4cb6-85b9-84f964255f81",
348 | "metadata": {},
349 | "outputs": [],
350 | "source": [
351 | "def get_last_friday(current_date: datetime.date, weekday=calendar.FRIDAY):\n",
352 | " return max(week[weekday]\n",
353 | " for week in calendar.monthcalendar(\n",
354 | " current_date.year, current_date.month\n",
355 | " ))\n",
356 | "\n",
357 | "get_last_friday(today)\n"
358 | ]
359 | },
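`get_last_friday` returns only the day of the month (for April 2021, `30`). If a full date is more useful as a feature, or if paydays fall on the last business day instead, the same `calendar` tools cover both cases. Both helpers below are hypothetical and treat Monday to Friday as business days, ignoring public holidays.

```python
import calendar
import datetime

def last_weekday_of_month(year: int, month: int, weekday: int = calendar.FRIDAY) -> datetime.date:
    """Date of the last given weekday in a month (hypothetical helper)."""
    # monthcalendar pads days outside the month with 0, so max() skips them
    day = max(week[weekday] for week in calendar.monthcalendar(year, month))
    return datetime.date(year, month, day)

def last_business_day(year: int, month: int) -> datetime.date:
    """Last Monday-to-Friday date of a month, ignoring holidays (hypothetical helper)."""
    last = datetime.date(year, month, calendar.monthrange(year, month)[1])
    # Saturday (5) steps back 1 day, Sunday (6) steps back 2 days
    return last - datetime.timedelta(days=max(0, last.weekday() - 4))

print(last_weekday_of_month(2021, 4))  # 2021-04-30
print(last_business_day(2021, 1))      # 2021-01-29
```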
360 | {
361 | "cell_type": "markdown",
362 | "id": "2a2cca85-1e9f-4e3c-ad34-a6b1122e217a",
363 | "metadata": {},
364 | "source": [
365 | "## Seasons"
366 | ]
367 | },
368 | {
369 | "cell_type": "code",
370 | "execution_count": null,
371 | "id": "6abd71ed-2478-4c8a-a578-1b961e8a1de4",
372 | "metadata": {},
373 | "outputs": [],
374 | "source": [
375 | "YEAR = 2021\n",
376 | "seasons = [\n",
377 | " ('winter', (date(YEAR, 1, 1), date(YEAR, 3, 20))),\n",
378 | " ('spring', (date(YEAR, 3, 21), date(YEAR, 6, 20))),\n",
379 | " ('summer', (date(YEAR, 6, 21), date(YEAR, 9, 22))),\n",
380 | " ('autumn', (date(YEAR, 9, 23), date(YEAR, 12, 20))),\n",
381 | " ('winter', (date(YEAR, 12, 21), date(YEAR, 12, 31)))\n",
382 | "]\n",
383 | "\n",
384 | "def is_in_interval(current_date: datetime.date, seasons):\n",
385 | " return next(season for season, (start, end) in seasons\n",
386 | " if start <= current_date.replace(year=YEAR) <= end)\n",
387 | " \n",
388 | "is_in_interval(today, seasons)\n"
389 | ]
390 | },
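One caveat with `is_in_interval`: mapping every date onto 2021 with `replace(year=YEAR)` raises `ValueError` for 29 February, because 2021 is not a leap year. If approximate boundaries are acceptable, meteorological seasons (whole months, with December to February as winter in the Northern Hemisphere) avoid the year entirely. This is a swapped-in convention, not the astronomical dates used above, and `meteorological_season` is a hypothetical helper.

```python
import datetime

# Northern-Hemisphere meteorological seasons, keyed by month number.
SEASON_BY_MONTH = {
    12: "winter", 1: "winter", 2: "winter",
    3: "spring", 4: "spring", 5: "spring",
    6: "summer", 7: "summer", 8: "summer",
    9: "autumn", 10: "autumn", 11: "autumn",
}

def meteorological_season(current_date: datetime.date) -> str:
    """Season by calendar month; works for any year, including leap days."""
    return SEASON_BY_MONTH[current_date.month]

print(meteorological_season(datetime.date(2021, 4, 11)))  # spring
print(meteorological_season(datetime.date(2020, 2, 29)))  # winter
```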
391 | {
392 | "cell_type": "markdown",
393 | "id": "bff1e4f1-03a6-4710-8fcc-93ca10e529c1",
394 | "metadata": {},
395 | "source": [
396 | "## Sun and Moon"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": null,
402 | "id": "2e79e8c5-a2c9-4ce0-bc55-b15c8564a877",
403 | "metadata": {},
404 | "outputs": [],
405 | "source": [
406 | "!pip install astral"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": null,
412 | "id": "92c55bb2-caf3-4a70-b51a-84bfc073b2ce",
413 | "metadata": {},
414 | "outputs": [],
415 | "source": [
416 | "from astral.sun import sun\n",
417 | "from astral import LocationInfo\n",
418 | "CITY = LocationInfo(\"London\", \"England\", \"Europe/London\", 51.5, -0.116)\n",
419 | "def get_sunrise_dusk(current_date: datetime.date, city_name='London'):\n",
420 | " s = sun(CITY.observer, date=current_date)\n",
421 | " sunrise = s['sunrise']\n",
422 | " dusk = s['dusk']\n",
423 | " return (dusk - sunrise).total_seconds() / 3600  # hours from sunrise to dusk\n",
424 | "\n",
425 | "get_sunrise_dusk(today)\n"
426 | ]
427 | },
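Computing `sunrise - dusk` gives a negative `timedelta`, and its `.seconds` attribute is not the negative duration. Python normalizes a span of minus 15 hours 15 minutes to `-1 day, 8:45:00`, so `.seconds` reports 8.75 hours, the complement to 24 hours. This is why the hours between sunrise and dusk should be computed as `dusk - sunrise` with `total_seconds()`. The times below are made-up values for illustration, not actual London sunrise and dusk times.

```python
import datetime

# Hypothetical sunrise and dusk times on one day.
sunrise = datetime.datetime(2021, 4, 11, 5, 15)
dusk = datetime.datetime(2021, 4, 11, 20, 30)

negative = sunrise - dusk
print(negative)                                  # -1 day, 8:45:00
print(negative.seconds / 3600)                   # 8.75, the complement of the true span
print((dusk - sunrise).total_seconds() / 3600)   # 15.25 hours from sunrise to dusk
```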
428 | {
429 | "cell_type": "markdown",
430 | "id": "98f041ac-5c80-408e-909a-3364975f6be0",
431 | "metadata": {},
432 | "source": [
433 | "## Business Days"
434 | ]
435 | },
436 | {
437 | "cell_type": "code",
438 | "execution_count": null,
439 | "id": "aac6eeef-2d1b-4bf2-a2cb-75a9c4648a93",
440 | "metadata": {},
441 | "outputs": [],
442 | "source": [
443 | "def get_business_days(current_date: datetime.date):\n",
444 | " last_day = calendar.monthrange(current_date.year, current_date.month)[1]\n",
445 | " rng = pd.date_range(current_date.replace(day=1), periods=last_day, freq='D')\n",
446 | " business_days = pd.bdate_range(rng[0], rng[-1])\n",
447 | " return len(business_days), last_day - len(business_days)\n",
448 | "\n",
449 | "get_business_days(date.today())\n"
450 | ]
451 | },
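`get_business_days` returns a pair: the number of weekdays (Monday to Friday) in the month and the number of remaining weekend days. The same count needs only the standard library. `business_days_in_month` is a hypothetical equivalent that, like `pd.bdate_range` with its default frequency, ignores public holidays; passing a fixed month instead of `date.today()` also makes the result reproducible.

```python
import calendar
from typing import Tuple

def business_days_in_month(year: int, month: int) -> Tuple[int, int]:
    """(weekdays, weekend days) in a month; public holidays are ignored."""
    first_weekday, n_days = calendar.monthrange(year, month)
    # day index d = 0, 1, ... has weekday (first_weekday + d) % 7; Monday..Friday are 0..4
    weekdays = sum(1 for d in range(n_days) if (first_weekday + d) % 7 < 5)
    return weekdays, n_days - weekdays

print(business_days_in_month(2021, 4))  # (22, 8)
```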
452 | {
453 | "cell_type": "markdown",
454 | "id": "91f21e9b-c326-472b-a59a-85a908502fe4",
455 | "metadata": {},
456 | "source": [
457 | "# Automated Feature Extraction"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": null,
463 | "id": "7df98e51-5f84-49fd-b13c-a7a488bcc4ea",
464 | "metadata": {},
465 | "outputs": [],
466 | "source": [
467 | "import featuretools as ft\n",
468 | "from featuretools.primitives import Minute, Hour, Day, Month, Year, Weekday\n",
469 | "\n",
470 | "data = pd.DataFrame(\n",
471 | " {'Time': ['2014-01-01 01:41:50',\n",
472 | " '2014-01-01 02:06:50',\n",
473 | " '2014-01-01 02:31:50',\n",
474 | " '2014-01-01 02:56:50',\n",
475 | " '2014-01-01 03:21:50'],\n",
476 | " 'Target': [0, 0, 0, 0, 1]}\n",
477 | ") \n",
478 | "data['index'] = data.index\n",
479 | "es = ft.EntitySet('My EntitySet')  # featuretools < 1.0 API; 1.x renamed entity_from_dataframe to add_dataframe\n",
480 | "es.entity_from_dataframe(\n",
481 | " entity_id='main_data_table',\n",
482 | " index='index',\n",
483 | " dataframe=data,\n",
484 | " time_index='Time'\n",
485 | ")\n",
486 | "fm, features = ft.dfs(\n",
487 | " entityset=es, \n",
488 | " target_entity='main_data_table', \n",
489 | " trans_primitives=[Minute, Hour, Day, Month, Year, Weekday]\n",
490 | ")\n"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "id": "39b85175-58d9-44bc-8374-09af58f245d3",
497 | "metadata": {},
498 | "outputs": [],
499 | "source": [
500 | "fm"
501 | ]
502 | },
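For a plain timestamp column, these six transform primitives amount to reading off date parts. The standard-library sketch below computes the same values for the five sample timestamps, which can help when sanity-checking `fm`. It assumes that featuretools' `Weekday` counts Monday as 0, as Python's `date.weekday()` does; `date_parts` is a hypothetical helper.

```python
import datetime

timestamps = [
    "2014-01-01 01:41:50",
    "2014-01-01 02:06:50",
    "2014-01-01 02:31:50",
    "2014-01-01 02:56:50",
    "2014-01-01 03:21:50",
]

def date_parts(ts: str) -> dict:
    """Minute, hour, day, month, year, and weekday (Monday == 0) of a timestamp."""
    t = datetime.datetime.fromisoformat(ts)
    return {
        "minute": t.minute, "hour": t.hour, "day": t.day,
        "month": t.month, "year": t.year, "weekday": t.weekday(),
    }

rows = [date_parts(ts) for ts in timestamps]
for row in rows:
    print(row)
```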
503 | {
504 | "cell_type": "code",
505 | "execution_count": null,
506 | "id": "11f5581e-0ce2-4a90-9f0c-5aeefbaa2096",
507 | "metadata": {},
508 | "outputs": [],
509 | "source": [
510 | "from tsfresh.feature_extraction import extract_features\n",
511 | "from tsfresh.feature_extraction import ComprehensiveFCParameters\n",
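512 | "# one series (id=1) sorted by Time; the helper 'index' column is dropped\n",
513 | "settings = ComprehensiveFCParameters()\n",
514 | "extract_features(data.drop(columns='index').assign(id=1), column_id='id', column_sort='Time', default_fc_parameters=settings)\n"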
515 | ]
516 | },
517 | {
518 | "cell_type": "markdown",
519 | "id": "d78b955f-04a1-4751-8142-e3b7c8002278",
520 | "metadata": {},
521 | "source": [
522 | "## ROCKET"
523 | ]
524 | },
525 | {
526 | "cell_type": "code",
527 | "execution_count": null,
528 | "id": "cc173b6d-be20-424a-b8b9-3ae617cf5f39",
529 | "metadata": {},
530 | "outputs": [],
531 | "source": [
532 | "from sktime.datasets import load_arrow_head\n",
533 | "from sktime.utils.data_processing import from_nested_to_2d_array\n",
534 | "# please note that this import changes in version 0.8:\n",
535 | "# from sktime.datatypes._panel._convert import from_nested_to_2d_array\n",
536 | "\n",
537 | "X_train, y_train = load_arrow_head(split=\"train\", return_X_y=True)\n",
538 | "from_nested_to_2d_array(X_train).head()"
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": null,
544 | "id": "b865bd92-8ba9-46f7-a15d-4745ad8bf5ea",
545 | "metadata": {},
546 | "outputs": [],
547 | "source": [
548 | "from sktime.transformations.panel.rocket import Rocket\n",
549 | "rocket = Rocket(num_kernels=1000)\n",
550 | "rocket.fit(X_train)\n",
551 | "X_train_transform = rocket.transform(X_train)\n"
552 | ]
553 | },
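ROCKET transforms each series with many random convolution kernels and keeps two summary statistics per kernel: the maximum of the convolution output and the proportion of positive values (PPV), so 1,000 kernels yield 2,000 features. The NumPy sketch below illustrates that idea, following the kernel sampling described in the ROCKET paper: lengths 7, 9, or 11; mean-centred normal weights; a uniform bias; exponentially scaled dilation; and random padding. It is a slow, simplified illustration, not sktime's implementation, and all function names are hypothetical.

```python
import numpy as np

def random_kernel(series_length: int, rng: np.random.Generator):
    """Sample one random kernel: weights, bias, dilation, padding."""
    length = int(rng.choice([7, 9, 11]))
    weights = rng.normal(0.0, 1.0, length)
    weights -= weights.mean()                       # mean-centred weights
    bias = rng.uniform(-1.0, 1.0)
    # exponential dilation scale so the dilated kernel still fits the series
    max_exp = np.log2((series_length - 1) / (length - 1))
    dilation = int(2 ** rng.uniform(0, max_exp))
    padding = ((length - 1) * dilation) // 2 if rng.integers(2) else 0
    return weights, bias, dilation, padding

def apply_kernel(x, weights, bias, dilation, padding):
    """Dilated convolution, summarized as (max, proportion of positive values)."""
    if padding:
        x = np.pad(x, padding)                      # zero padding on both ends
    span = (len(weights) - 1) * dilation
    out = np.array([
        bias + np.dot(weights, x[i:i + span + 1:dilation])
        for i in range(len(x) - span)
    ])
    return out.max(), (out > 0).mean()

def rocket_features(X: np.ndarray, num_kernels: int = 100, seed: int = 0) -> np.ndarray:
    """Map (n_series, length) to (n_series, 2 * num_kernels) features."""
    rng = np.random.default_rng(seed)
    kernels = [random_kernel(X.shape[1], rng) for _ in range(num_kernels)]
    feats = np.empty((X.shape[0], 2 * num_kernels))
    for i, x in enumerate(X):
        for k, kernel in enumerate(kernels):
            feats[i, 2 * k:2 * k + 2] = apply_kernel(x, *kernel)
    return feats

X_demo = np.random.default_rng(42).normal(size=(5, 50))  # 5 random series, length 50
features = rocket_features(X_demo, num_kernels=20)
print(features.shape)  # (5, 40)
```

The resulting feature matrix is typically fed to a linear classifier (the ROCKET paper uses ridge regression), which is what keeps the method fast despite the large number of features.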
554 | {
555 | "cell_type": "markdown",
556 | "id": "474f4e40-d2b1-43da-98d9-767d922c6f91",
557 | "metadata": {},
558 | "source": [
559 | "## Shapelets"
560 | ]
561 | },
562 | {
563 | "cell_type": "code",
564 | "execution_count": null,
565 | "id": "0d99e98f-17ec-4b01-b8e0-01b062e1dfbb",
566 | "metadata": {},
567 | "outputs": [],
568 | "source": [
569 | "from sktime.transformations.panel.shapelets import ContractedShapeletTransform\n",
570 | "shapelets_transform = ContractedShapeletTransform(\n",
571 | " time_contract_in_mins=1,\n",
572 | " num_candidates_to_sample_per_case=10,\n",
573 | " verbose=0,\n",
574 | ")\n",
575 | "shapelets_transform.fit(X_train, y_train)\n"
576 | ]
577 | },
578 | {
579 | "cell_type": "code",
580 | "execution_count": null,
581 | "id": "a5468a48-7669-49e9-87e3-827f25cc8b2b",
582 | "metadata": {},
583 | "outputs": [],
584 | "source": [
585 | "X_train_transform = shapelets_transform.transform(X_train)"
586 | ]
587 | },
588 | {
589 | "cell_type": "code",
590 | "execution_count": null,
591 | "id": "6812d1d2-7de2-48ee-91c5-70741ff8a320",
592 | "metadata": {},
593 | "outputs": [],
594 | "source": [
595 | "X_train_transform"
596 | ]
597 | },
598 | {
599 | "cell_type": "code",
600 | "execution_count": null,
601 | "id": "6d3410b6-a2ab-40fb-a942-b1fcfd423512",
602 | "metadata": {},
603 | "outputs": [],
604 | "source": []
605 | }
606 | ],
607 | "metadata": {
608 | "kernelspec": {
609 | "display_name": "Python 3",
610 | "language": "python",
611 | "name": "python3"
612 | },
613 | "language_info": {
614 | "codemirror_mode": {
615 | "name": "ipython",
616 | "version": 3
617 | },
618 | "file_extension": ".py",
619 | "mimetype": "text/x-python",
620 | "name": "python",
621 | "nbconvert_exporter": "python",
622 | "pygments_lexer": "ipython3",
623 | "version": "3.8.8"
624 | }
625 | },
626 | "nbformat": 4,
627 | "nbformat_minor": 5
628 | }
629 |
--------------------------------------------------------------------------------
/chapter5/Forecasting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "6a06a10a-ee82-40ce-8821-3da680af4b95",
6 | "metadata": {},
7 | "source": [
8 | ""
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "614a3348-214c-4472-969d-e7db10664075",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "!pip install statsmodels yfinance pmdarima arch"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "cb11cb6d-0e34-418b-9085-2b5d88ed343d",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import matplotlib.pyplot as plt\n",
29 | "import seaborn as sns\n",
30 | "import matplotlib as mpl\n",
31 | "import pandas as pd\n",
32 | "\n",
33 | "plt.style.use('seaborn-whitegrid')\n",
34 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
35 | "plt.rcParams[\"font.size\"] = \"17\"\n",
36 | "mpl.rcParams['lines.linewidth'] = 2\n",
37 | "mpl.rcParams['lines.markersize'] = 1\n",
38 | "# plt.style.use('.matplotlibrc')"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "d90f30a7-9ad6-4773-bf95-d6a07298757d",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "from datetime import datetime\n",
49 | "import yfinance as yf\n",
50 | " \n",
51 | "start_date = datetime(2005, 1, 1)\n",
52 | "end_date = datetime(2021, 1, 1)\n",
53 | "\n",
54 | "df = yf.download(\n",
55 | " 'SPY',\n",
56 | " start=start_date,\n",
57 | " end = end_date\n",
58 | ")"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "f11d3381-2b1c-400a-98d9-b41d53061e8b",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "plt.figure(figsize = (12, 6))\n",
69 | "plt.title('Opening Prices between {} and {}'.format(\n",
70 | " start_date.date().isoformat(),\n",
71 | " end_date.date().isoformat()\n",
72 | "))\n",
73 | "df['Open'].plot()\n",
74 | "plt.ylabel('Price')\n",
75 | "plt.xlabel('Date');"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "id": "c6b1629a-92eb-4c52-a523-24427d120786",
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "df.head()"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "id": "6289d877-f934-4543-9849-87bc3ef8a3c5",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "df1 = df.reset_index().resample('W', on=\"Date\")['Open'].mean()"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "6db93897-2993-4490-9558-dec619e30f85",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "# some years have 53 ISO weeks, but the seasonal models below use a 52-week period, so drop week 53\n",
106 | "df1 = df1[df1.index.isocalendar().week < 53]"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "id": "567145da-61fc-4bc2-859c-3e36ad5c7421",
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "# final check: \n",
117 | "df1.index.isocalendar().week.value_counts().plot.bar()"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "id": "3df5304c-7c73-45bb-8d4d-94af87ac8dc3",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "df1 = df1[~df1.isnull()]"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "id": "171c9318-b8cd-46d0-8018-0cac48888d96",
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "# fix the frequency: asfreq('W') re-inserts any missing weeks, which are then forward-filled\n",
138 | "df1 = df1.asfreq('W').ffill()"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "id": "38a4ba79-453a-4e18-bbb5-652838270b90",
145 | "metadata": {},
146 | "outputs": [],
147 | "source": [
148 | "df1.index.freq"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "id": "c70e30be-4481-4aa4-9142-b3614dfb39e9",
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "import statsmodels.api as sm\n",
159 | "\n",
160 | "fig, axs = plt.subplots(2)\n",
161 | "sm.graphics.tsa.plot_pacf(df1, lags=20, ax=axs[0])\n",
162 | "axs[0].set_ylabel('R')\n",
163 | "axs[0].set_xlabel('Lag')\n",
164 | "sm.graphics.tsa.plot_acf(df1, lags=20, ax=axs[1]);\n",
165 | "axs[1].set_ylabel('R')\n",
166 | "axs[1].set_xlabel('Lag')\n",
167 | "fig.tight_layout()"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "id": "e88633cd-5542-4366-84bd-ef59e46f253a",
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "from statsmodels.tsa.seasonal import seasonal_decompose\n",
178 | "result = seasonal_decompose(df1, model='additive', period=52)\n",
179 | "result.plot();"
180 | ]
181 | },
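`seasonal_decompose` performs a classical moving-average decomposition. Roughly: the trend is a centred moving average over one period (a 2xm average when the period m is even), the seasonal component is the average detrended value at each position in the cycle, normalized to sum to zero, and the residual is what remains. The NumPy sketch below follows those steps on a synthetic series with period 4. It is an illustration, not statsmodels' code (which, for example, can also extrapolate the trend to the ends of the series), and `classical_additive_decompose` is a hypothetical helper.

```python
import numpy as np

def classical_additive_decompose(y, period: int):
    """Trend, seasonal, and residual of y = trend + seasonal + residual."""
    y = np.asarray(y, dtype=float)
    n = len(y)
    if period % 2 == 0:
        w = np.r_[0.5, np.ones(period - 1), 0.5] / period  # 2 x m centred average
    else:
        w = np.ones(period) / period
    half = len(w) // 2
    trend = np.full(n, np.nan)                             # undefined at both ends
    trend[half:n - half] = np.convolve(y, w, mode="valid")
    detrended = y - trend
    # mean detrended value at each position in the cycle, centred to sum to zero
    season = np.array([np.nanmean(detrended[i::period]) for i in range(period)])
    season -= season.mean()
    seasonal = np.tile(season, n // period + 1)[:n]
    return trend, seasonal, y - trend - seasonal

t = np.arange(48)
true_seasonal = np.tile([1.0, -1.0, 2.0, -2.0], 12)
y = 0.5 * t + true_seasonal                                # linear trend + season
trend, seasonal, resid = classical_additive_decompose(y, period=4)
print(seasonal[:4])
```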
182 | {
183 | "cell_type": "markdown",
184 | "id": "da2aa62a-fe40-49e2-b2d9-ebcb1a91cda8",
185 | "metadata": {},
186 | "source": [
187 | "# Finding a value for d"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "id": "fbeefe1e-67ce-403d-a990-a7cf19aa9318",
193 | "metadata": {},
194 | "source": [
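195 | "We use the `arch` package because its ADF and KPSS test objects are more convenient than the statsmodels functions: displaying a test object prints the statistic, p-value, critical values, and the null and alternative hypotheses."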
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "id": "18eb1bc7-85a8-4380-8346-4c28aeb16116",
202 | "metadata": {},
203 | "outputs": [],
204 | "source": [
205 | "from arch.unitroot import KPSS, ADF\n",
206 | "\n",
207 | "ADF(df1)"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "id": "40bddaee-38d3-4d40-9773-e6a4721e4b62",
214 | "metadata": {},
215 | "outputs": [],
216 | "source": [
217 | "from pmdarima.arima.utils import ndiffs\n",
218 | "\n",
219 | "# ADF Test:\n",
220 | "ndiffs(df1, test='adf') # 1; same values for the KPSS and the PP test"
221 | ]
222 | },
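Why does `ndiffs` suggest d = 1 for a price series? Prices behave much like a random walk, whose level is highly autocorrelated and non-stationary, while its first difference is uncorrelated noise. The simulated NumPy example below shows this with the lag-1 sample autocorrelation. The random walk and the `lag1_autocorr` helper are assumptions for illustration, not the SPY data.

```python
import numpy as np

def lag1_autocorr(x: np.ndarray) -> float:
    """Sample autocorrelation at lag 1 (hypothetical helper)."""
    centred = x - x.mean()
    return float(np.sum(centred[1:] * centred[:-1]) / np.sum(centred ** 2))

rng = np.random.default_rng(1)
steps = rng.normal(0.0, 1.0, 2_000)
random_walk = np.cumsum(steps)        # level: y_t = y_{t-1} + e_t
differenced = np.diff(random_walk)    # first difference recovers e_t (d = 1)

print(f"lag-1 autocorrelation, levels:      {lag1_autocorr(random_walk):.3f}")
print(f"lag-1 autocorrelation, differences: {lag1_autocorr(differenced):.3f}")
```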
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "id": "6d3410b6-a2ab-40fb-a942-b1fcfd423512",
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "# what happens if we forget differencing?\n",
231 | "# We get a helpful warning: 'Non-stationary starting autoregressive parameters'\n",
232 | "mod = sm.tsa.arima.ARIMA(endog=df1, order=(1, 0, 0))\n",
233 | "res = mod.fit()\n",
234 | "print(res.summary())"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "id": "12cf8ad9-bcd8-42c0-ad6f-164ff81d7dec",
241 | "metadata": {},
242 | "outputs": [],
243 | "source": [
244 | "# let's try again and this time, we'll take into account the seasonality:\n",
245 | "from statsmodels.tsa.forecasting.stl import STLForecast\n",
246 | "\n",
247 | "mod = STLForecast(df1, sm.tsa.arima.ARIMA, model_kwargs=dict(order=(1, 1, 0), trend=\"t\"))\n",
248 | "res = mod.fit().model_result\n",
249 | "print(res.summary())"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": null,
255 | "id": "454a5679-45ac-45b5-860a-8a4685d7e9eb",
256 | "metadata": {},
257 | "outputs": [],
258 | "source": [
259 | "# forecast from the ARIMA part only: res was fit to the STL-deseasonalized series\n",
260 | "STEPS = 20\n",
261 | "forecasts_df = res.get_forecast(steps=STEPS).summary_frame() \n",
262 | "ax = df1.plot(figsize=(12, 6))\n",
263 | "plt.ylabel('SPY')\n",
264 | "forecasts_df['mean'].plot(style='k--')\n",
265 | "ax.fill_between(\n",
266 | " forecasts_df.index,\n",
267 | " forecasts_df['mean_ci_lower'],\n",
268 | " forecasts_df['mean_ci_upper'],\n",
269 | " color='k',\n",
270 | " alpha=0.1\n",
271 | ")"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "id": "90cc47f4-1444-40af-9d84-e64e4518e0e0",
278 | "metadata": {},
279 | "outputs": [],
280 | "source": [
281 | "forecasts = []\n",
282 | "qs = []\n",
283 | "for q in range(0, 30, 10):\n",
284 | " mod = STLForecast(df1, sm.tsa.arima.ARIMA, model_kwargs=dict(order=(0, 1, q), trend=\"t\"))\n",
285 | " res = mod.fit().model_result\n",
286 | " print(f\"aic ({q}): {res.aic}\")\n",
287 | " forecasts.append(\n",
288 | " res.get_forecast(steps=STEPS).summary_frame()['mean']\n",
289 | " )\n",
290 | " qs.append(q)\n",
291 | "\n",
292 | "forecasts_df = pd.concat(forecasts, axis=1)\n",
293 | "forecasts_df.columns = qs"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": null,
299 | "id": "1afc0557-d41d-4540-95c5-9244b1ceca24",
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "# plotting the three forecasts:\n",
304 | "ax = df1.plot(figsize=(12, 6))\n",
305 | "plt.ylabel('SPY')\n",
306 | "forecasts_df.plot(ax=ax)"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": null,
312 | "id": "f9b65061-b5ca-453f-92a9-b5c7341f8932",
313 | "metadata": {},
314 | "outputs": [],
315 | "source": [
316 | "mod = sm.tsa.ExponentialSmoothing(\n",
317 | " endog=df1, trend='add', seasonal_periods=52, use_boxcox=True, initialization_method=\"heuristic\"  # seasonal_periods only takes effect if seasonal= is also set\n",
318 | " )\n",
319 | "res = mod.fit()\n",
320 | "print(res.summary())"
321 | ]
322 | },
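With `trend='add'` and no `seasonal` argument, this model is Holt's linear trend method (here with a Box-Cox transform on top). The recursion keeps a level and a slope, each updated by exponential smoothing, and forecasts by extending the last slope. The NumPy sketch below shows the bare recursion with fixed, hypothetical smoothing parameters alpha = 0.5 and beta = 0.3 and simple initial values; statsmodels instead estimates the smoothing parameters when fitting. `holt_forecast` is a hypothetical helper.

```python
import numpy as np

def holt_forecast(y, alpha: float, beta: float, steps: int) -> np.ndarray:
    """Holt's linear trend method with an additive trend (hypothetical helper)."""
    y = np.asarray(y, dtype=float)
    level, trend = y[0], y[1] - y[0]          # simple initial states
    for value in y[1:]:
        previous_level = level
        level = alpha * value + (1 - alpha) * (level + trend)
        trend = beta * (level - previous_level) + (1 - beta) * trend
    # h-step-ahead forecast: last level plus h times the last slope
    return level + trend * np.arange(1, steps + 1)

print(holt_forecast([5, 7, 9, 11, 13], alpha=0.5, beta=0.3, steps=3))
```

On a perfectly linear series the recursion reproduces the line exactly, so the forecast continues it: 15, 17, 19.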
323 | {
324 | "cell_type": "code",
325 | "execution_count": null,
326 | "id": "92209cc0-b046-4c5b-a9ee-4a3afb5d615a",
327 | "metadata": {},
328 | "outputs": [],
329 | "source": [
330 | "forecasts = pd.Series(res.forecast(steps=STEPS))\n",
331 | "ax = df1.plot(figsize=(12, 6))\n",
332 | "plt.ylabel('SPY')\n",
333 | "forecasts.plot(style='k--')"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "id": "c79a4000-663b-4264-a0f2-c00427ccfc11",
340 | "metadata": {},
341 | "outputs": [],
342 | "source": [
343 | "from statsmodels.tsa.forecasting.theta import ThetaModel\n",
344 | "\n",
345 | "train_length = int(len(df1) * 0.8)\n",
346 | "tm = ThetaModel(df1[:train_length], method=\"auto\", deseasonalize=True)\n",
347 | "res = tm.fit()\n",
348 | "forecasts = res.forecast(steps=len(df1)-train_length)\n",
349 | "ax = df1.plot(figsize=(12, 6))\n",
350 | "plt.ylabel('SPY')\n",
351 | "forecasts.plot(style='k--')"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": null,
357 | "id": "949b7c6c-1326-44a6-8e5d-c58d0f4191e7",
358 | "metadata": {},
359 | "outputs": [],
360 | "source": [
361 | "from sklearn import metrics\n",
362 | "\n",
363 | "metrics.mean_squared_error(df1[train_length:], forecasts) ** 0.5  # RMSE; sklearn expects (y_true, y_pred)"
364 | ]
365 | },
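The root mean squared error is the square root of the average squared forecast error, so it is in the same units as the series (here, the SPY price). On its own an RMSE is hard to judge; one common reference point is a naive forecast that repeats the last training value. The NumPy sketch below computes both on a tiny made-up example; `rmse` and all numbers are illustrative assumptions, not results from the notebook.

```python
import numpy as np

def rmse(y_true, y_pred) -> float:
    """Root mean squared error."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

train = np.array([100.0, 102.0, 101.0, 104.0])     # made-up training values
test = np.array([105.0, 107.0, 106.0])             # made-up hold-out values
model_forecast = np.array([104.5, 106.0, 107.0])   # made-up model forecast
naive_forecast = np.full(len(test), train[-1])     # repeat the last training value

print(f"model RMSE: {rmse(test, model_forecast):.3f}")
print(f"naive RMSE: {rmse(test, naive_forecast):.3f}")
```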
366 | {
367 | "cell_type": "code",
368 | "execution_count": null,
369 | "id": "9733e78f-ee79-44c4-a26e-496670a68cad",
370 | "metadata": {},
371 | "outputs": [],
372 | "source": []
373 | }
374 | ],
375 | "metadata": {
376 | "kernelspec": {
377 | "display_name": "Python 3",
378 | "language": "python",
379 | "name": "python3"
380 | },
381 | "language_info": {
382 | "codemirror_mode": {
383 | "name": "ipython",
384 | "version": 3
385 | },
386 | "file_extension": ".py",
387 | "mimetype": "text/x-python",
388 | "name": "python",
389 | "nbconvert_exporter": "python",
390 | "pygments_lexer": "ipython3",
391 | "version": "3.8.8"
392 | }
393 | },
394 | "nbformat": 4,
395 | "nbformat_minor": 5
396 | }
397 |
--------------------------------------------------------------------------------
/chapter6/Change-Points_Anomalies.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "# Anomaly detection"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "!pip install alibi_detect"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "from alibi_detect.datasets import fetch_kdd\n",
33 | "\n",
34 | "intrusions = fetch_kdd()"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "intrusions[\"target\"].sum() / len(intrusions[\"target\"])"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "intrusions[\"feature_names\"]"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "import pandas as pd\n",
62 | "\n",
63 | "pd.Series(intrusions[\"data\"][:, 0]).plot();"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "from alibi_detect.od import SpectralResidual\n",
73 | "\n",
74 | "od = SpectralResidual(\n",
75 | " threshold=1.0, window_amp=20, window_local=20, n_est_points=10, n_grad_points=5\n",
76 | ")\n",
77 | "intrusion_outliers = od.predict(intrusions[\"data\"][:, 0])"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "scores = od.score(intrusions[\"data\"][:, 0])"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "import matplotlib\n",
96 | "\n",
97 | "ax = pd.Series(intrusions[\"data\"][:, 0], name=\"data\").plot(\n",
98 | " legend=False, figsize=(12, 6)\n",
99 | ")\n",
100 | "ax2 = ax.twinx()\n",
101 | "ax = pd.Series(scores, name=\"scores\").plot(\n",
102 | " ax=ax2, legend=False, color=\"r\", marker=matplotlib.markers.CARETDOWNBASE\n",
103 | ")\n",
104 | "ax.figure.legend(bbox_to_anchor=(1, 1), loc=\"upper left\");"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "# Change point detection"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "!pip install ruptures"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "import matplotlib.pyplot as plt\n",
130 | "import numpy as np\n",
131 | "import ruptures as rpt\n",
132 | "\n",
133 | "plt.style.use(\"seaborn-whitegrid\")\n",
134 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
135 | "plt.rcParams[\"font.size\"] = \"17\""
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "signal, bkps = rpt.pw_constant(\n",
145 | " n_samples=500, n_features=3, n_bkps=2, noise_std=5.0, delta=(1, 20)\n",
146 | ")"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "rpt.display(signal, bkps)"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "signal.shape"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "bkps"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "# cost model \"l1\"; ruptures also provides e.g. \"l2\", \"rbf\", \"linear\", \"normal\", \"ar\"\n",
183 | "algo = rpt.Binseg(model=\"l1\").fit(signal)\n",
184 | "my_bkps = algo.predict(n_bkps=3)\n",
185 | "\n",
186 | "# show results\n",
187 | "rpt.show.display(signal, bkps, my_bkps, figsize=(10, 6))"
188 | ]
189 | }
190 | ],
191 | "metadata": {
192 | "kernelspec": {
193 | "display_name": "Python 3",
194 | "language": "python",
195 | "name": "python3"
196 | },
197 | "language_info": {
198 | "codemirror_mode": {
199 | "name": "ipython",
200 | "version": 3
201 | },
202 | "file_extension": ".py",
203 | "mimetype": "text/x-python",
204 | "name": "python",
205 | "nbconvert_exporter": "python",
206 | "pygments_lexer": "ipython3",
207 | "version": "3.8.8"
208 | }
209 | },
210 | "nbformat": 4,
211 | "nbformat_minor": 4
212 | }
213 |
--------------------------------------------------------------------------------
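The `SpectralResidual` detector used above scores points by how much they stand out in the frequency domain. The sketch below follows the general spectral residual recipe in plain NumPy: take the log amplitude spectrum, subtract its moving average, transform back with the original phase, and compare the resulting saliency with a local average. It is an illustration under simplifying assumptions, not alibi-detect's implementation: it uses centered moving averages with zero padding at the edges and omits the extrapolated points that `n_est_points` and `n_grad_points` control. The function names and window sizes are our own.

```python
import numpy as np


def spectral_residual_saliency(x, window_amp: int = 3) -> np.ndarray:
    """Saliency map of a 1-D series via the spectral residual (sketch)."""
    x = np.asarray(x, dtype=float)
    spectrum = np.fft.fft(x)
    log_amp = np.log(np.abs(spectrum) + 1e-8)  # epsilon guards against log(0)
    # moving average of the log-amplitude spectrum
    avg_log_amp = np.convolve(log_amp, np.ones(window_amp) / window_amp, mode="same")
    residual = log_amp - avg_log_amp
    # keep the original phase, replace the amplitude by exp(residual)
    phase = np.angle(spectrum)
    return np.abs(np.fft.ifft(np.exp(residual + 1j * phase)))


def spectral_residual_scores(x, window_amp: int = 3, window_local: int = 21) -> np.ndarray:
    """Relative deviation of each saliency value from its local (centered) mean."""
    saliency = spectral_residual_saliency(x, window_amp)
    kernel = np.ones(window_local) / window_local
    local_mean = np.convolve(saliency, kernel, mode="same")
    return (saliency - local_mean) / (local_mean + 1e-8)


# usage: a noisy flat series with one injected spike
rng = np.random.default_rng(0)
x = rng.normal(0.0, 0.1, 200)
x[120] += 10.0
scores = spectral_residual_scores(x)
```

A threshold on these scores (like the `threshold=1.0` passed to the detector above) then turns them into outlier flags.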
/chapter7/KNN_with_dynamic_DTW.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "colab": {
15 | "base_uri": "https://localhost:8080/"
16 | },
17 | "id": "r1kjFHY6xlbo",
18 | "outputId": "4b3f00df-2b21-4b3f-809f-7df77ada4790"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "!pip install tsfresh \"statsmodels<=0.12\""
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {
29 | "colab": {
30 | "base_uri": "https://localhost:8080/"
31 | },
32 | "id": "Pg0UkW57-VlF",
33 | "outputId": "8275dc45-501c-40cd-ec9e-19d7d42935ad"
34 | },
35 | "outputs": [],
36 | "source": [
37 | "!pip install tslearn"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {
44 | "id": "5VtsgQxuzMUt"
45 | },
46 | "outputs": [],
47 | "source": [
48 | "import pandas as pd\n",
49 | "import matplotlib.pyplot as plt\n",
50 | "\n",
51 | "plt.style.use('seaborn-whitegrid')\n",
52 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
53 | "plt.rcParams[\"font.size\"] = \"17\""
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/"
62 | },
63 | "id": "iPQPBG849jB2",
64 | "outputId": "13f96036-b472-4445-d0ca-75d02a48016a"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "from tsfresh.examples import load_robot_execution_failures\n",
69 | "from tsfresh.examples.robot_execution_failures import download_robot_execution_failures\n",
70 | "\n",
71 | "download_robot_execution_failures()\n",
72 | "df_ts, y = load_robot_execution_failures()"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {
79 | "colab": {
80 | "base_uri": "https://localhost:8080/"
81 | },
82 | "id": "UOHxdjmAWKZc",
83 | "outputId": "67d575b7-d649-40ab-8edd-2d0c7950a6a2"
84 | },
85 | "outputs": [],
86 | "source": [
87 | "from tsfresh import extract_features\n",
88 | "from tsfresh import select_features\n",
89 | "from tsfresh.utilities.dataframe_functions import impute\n",
90 | "\n",
91 | "extracted_features = impute(extract_features(df_ts, column_id=\"id\", column_sort=\"time\"))\n",
92 | "features_filtered = select_features(extracted_features, y)"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {
99 | "colab": {
100 | "base_uri": "https://localhost:8080/"
101 | },
102 | "id": "7D0JxrCpQL6x",
103 | "outputId": "b75b4054-afab-4f55-8c61-9a6c7a7f0d8d"
104 | },
105 | "outputs": [],
106 | "source": [
107 | "from tsfresh.transformers import RelevantFeatureAugmenter\n",
108 | "import pandas as pd\n",
109 | "\n",
110 | "X = pd.DataFrame(index=y.index)\n",
111 | "TRAINING_SIZE = (len(X) // 10) * 8\n",
112 | "augmenter = RelevantFeatureAugmenter(column_id='id', column_sort='time')\n",
113 | "augmenter.set_timeseries_container(df_ts)  # all rows; the augmenter matches them to X by id\n",
114 | "augmenter.fit(X.iloc[:TRAINING_SIZE], y.iloc[:TRAINING_SIZE])\n",
115 | "X_transformed = augmenter.transform(X)"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {
122 | "colab": {
123 | "base_uri": "https://localhost:8080/"
124 | },
125 | "id": "jzaRzQc63Iiv",
126 | "outputId": "384b8608-39d3-4b52-8e23-4baefd4c05cc"
127 | },
128 | "outputs": [],
129 | "source": [
130 | "from sklearn.model_selection import TimeSeriesSplit, GridSearchCV\n",
131 | "from tslearn.neighbors import KNeighborsTimeSeriesClassifier\n",
132 | "\n",
133 | "knn = KNeighborsTimeSeriesClassifier()\n",
134 | "param_search = {\n",
135 | " 'metric': ['dtw'],  # tslearn also accepts other metrics, e.g. 'softdtw' or 'ctw'\n",
136 | " 'n_neighbors': [1, 2, 3]\n",
137 | "}\n",
138 | "# time-ordered folds: each validation fold follows its training fold\n",
139 | "tscv = TimeSeriesSplit(n_splits=2)\n",
140 | "\n",
141 | "gsearch = GridSearchCV(\n",
142 | " estimator=knn,\n",
143 | " cv=tscv,\n",
144 | " param_grid=param_search\n",
145 | ")\n",
146 | "gsearch.fit(\n",
147 | " features_filtered,\n",
148 | " y\n",
149 | ")"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {
156 | "id": "j7XePgIYK1WQ"
157 | },
158 | "outputs": [],
159 | "source": [
160 | "# Adapted from comments on this discussion thread on stackoverflow: https://stackoverflow.com/questions/37161563/how-to-graph-grid-scores-from-gridsearchcv\n",
161 | "import seaborn as sns\n",
162 | "import pandas as pd\n",
163 | "\n",
164 | "def plot_cv_results(cv_results, param_x, param_z, metric='mean_test_score'):\n",
165 | " \"\"\"\n",
166 | " cv_results - cv_results_ attribute of a GridSearchCV instance (or similar)\n",
167 | " param_x - name of grid search parameter to plot on x axis\n",
168 | " param_z - name of grid search parameter to plot by line color\n",
169 | " \"\"\"\n",
170 | " cv_results = pd.DataFrame(cv_results)\n",
171 | " col_x = 'param_' + param_x\n",
172 | " col_z = 'param_' + param_z\n",
173 | " fig, ax = plt.subplots(1, 1, figsize=(11, 8))\n",
174 | " sns.pointplot(x=col_x, y=metric, hue=col_z, data=cv_results, ci=99, n_boot=64, ax=ax)\n",
175 | " ax.set_title(\"CV Grid Search Results\")\n",
176 | " ax.set_xlabel(param_x)\n",
177 | " ax.set_ylabel(metric)\n",
178 | " ax.legend(title=param_z)\n",
179 | " return fig\n"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {
186 | "colab": {
187 | "base_uri": "https://localhost:8080/",
188 | "height": 564
189 | },
190 | "id": "-mHYRnizXwsz",
191 | "outputId": "41bb85fe-a2e4-4f51-911d-54630be0645a"
192 | },
193 | "outputs": [],
194 | "source": [
195 | "fig = plot_cv_results(gsearch.cv_results_, 'metric', 'n_neighbors')"
196 | ]
197 | }
198 | ],
199 | "metadata": {
200 | "colab": {
201 | "collapsed_sections": [],
202 | "name": "KNN with dynamic DTW",
203 | "provenance": []
204 | },
205 | "kernelspec": {
206 | "display_name": "Python 3",
207 | "language": "python",
208 | "name": "python3"
209 | },
210 | "language_info": {
211 | "codemirror_mode": {
212 | "name": "ipython",
213 | "version": 3
214 | },
215 | "file_extension": ".py",
216 | "mimetype": "text/x-python",
217 | "name": "python",
218 | "nbconvert_exporter": "python",
219 | "pygments_lexer": "ipython3",
220 | "version": "3.8.8"
221 | }
222 | },
223 | "nbformat": 4,
224 | "nbformat_minor": 4
225 | }
226 |
--------------------------------------------------------------------------------
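The grid search above tunes `KNeighborsTimeSeriesClassifier` with `metric='dtw'`. To make that metric concrete, here is a minimal dynamic-programming sketch of dynamic time warping and a k-nearest-neighbour vote built on it. It assumes univariate sequences and a squared-difference local cost with a square root at the end (tslearn's DTW uses a similar convention, but this is not its implementation). `dtw_distance` and `knn_dtw_predict` are hypothetical helper names, and the quadratic double loop is only practical for short series.

```python
import numpy as np


def dtw_distance(a, b) -> float:
    """DTW distance between two 1-D sequences (O(len(a) * len(b)) dynamic program)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n, m = len(a), len(b)
    # cost[i, j] = cheapest cumulative cost of aligning a[:i] with b[:j]
    cost = np.full((n + 1, m + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            local = (a[i - 1] - b[j - 1]) ** 2
            # extend the cheapest of the three predecessor alignments
            cost[i, j] = local + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return float(np.sqrt(cost[n, m]))


def knn_dtw_predict(X_train, y_train, X_query, k: int = 1) -> np.ndarray:
    """Majority vote among the k training series closest to each query under DTW."""
    y_train = np.asarray(y_train)
    predictions = []
    for query in X_query:
        distances = np.array([dtw_distance(query, series) for series in X_train])
        nearest = np.argsort(distances)[:k]
        labels, counts = np.unique(y_train[nearest], return_counts=True)
        predictions.append(labels[np.argmax(counts)])
    return np.array(predictions)


# usage: the query is a time-shifted "bump", which DTW aligns at zero cost
pred = knn_dtw_predict(
    [[0, 0, 1, 1, 0, 0], [0, 1, 0, 1, 0, 1]], ["bump", "zigzag"], [[0, 0, 0, 1, 1, 0]]
)
```

Unlike the Euclidean distance, DTW also compares sequences of different lengths, e.g. `dtw_distance([0, 1, 2], [0, 0, 1, 2])` is 0 because the repeated value is absorbed by the warping path.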
/chapter7/Kats.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "7ea3881f-ddfc-475b-bfa9-bb46456ffee4",
6 | "metadata": {},
7 | "source": [
8 | ""
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "bde1b68e-20b6-42e1-a5a8-81824e365545",
15 | "metadata": {
16 | "colab": {
17 | "base_uri": "https://localhost:8080/"
18 | },
19 | "id": "bde1b68e-20b6-42e1-a5a8-81824e365545",
20 | "outputId": "f290afb2-8026-4792-e11f-2b5572336c3a"
21 | },
22 | "outputs": [],
23 | "source": [
24 | "!MINIMAL=1 pip install kats"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "Jb7RFf_D-YzN",
31 | "metadata": {
32 | "colab": {
33 | "base_uri": "https://localhost:8080/"
34 | },
35 | "id": "Jb7RFf_D-YzN",
36 | "outputId": "681437c3-35c7-461a-f512-170d929c1a38"
37 | },
38 | "outputs": [],
39 | "source": [
40 | "!pip install \"numpy==1.20\""
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "id": "e9f62717-49c6-4a08-bd69-67a4c3ebe2d5",
47 | "metadata": {
48 | "id": "e9f62717-49c6-4a08-bd69-67a4c3ebe2d5"
49 | },
50 | "outputs": [],
51 | "source": [
52 | "import matplotlib.pyplot as plt\n",
53 | "import seaborn as sns\n",
54 | "%matplotlib inline\n",
55 | "\n",
56 | "plt.style.use('seaborn-whitegrid')\n",
57 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
58 | "plt.rcParams[\"font.size\"] = \"17\""
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "2045898b-51b2-446a-bf30-4102725d36d9",
65 | "metadata": {
66 | "id": "2045898b-51b2-446a-bf30-4102725d36d9"
67 | },
68 | "outputs": [],
69 | "source": [
70 | "import pandas as pd\n",
71 | "\n",
72 | "owid_covid = pd.read_csv(\"https://covid.ourworldindata.org/data/owid-covid-data.csv\")\n",
73 | "owid_covid[\"date\"] = pd.to_datetime(owid_covid[\"date\"])\n",
74 | "df = owid_covid[owid_covid.location == \"France\"].set_index(\"date\", drop=True).resample('D').interpolate(method='linear').reset_index()"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "d196768c-ccd4-487b-b88a-825f3280f8b8",
81 | "metadata": {
82 | "id": "d196768c-ccd4-487b-b88a-825f3280f8b8"
83 | },
84 | "outputs": [],
85 | "source": [
86 | "from kats.models.ensemble.ensemble import EnsembleParams, BaseModelParams\n",
87 | "from kats.models.ensemble.kats_ensemble import KatsEnsemble\n",
88 | "from kats.models import (\n",
89 | " linear_model,\n",
90 | " quadratic_model\n",
91 | ")\n",
92 | "\n",
93 | "\n",
94 | "model_params = EnsembleParams(\n",
95 | " [\n",
96 | " BaseModelParams(\"linear\", linear_model.LinearModelParams()),\n",
97 | " BaseModelParams(\"quadratic\", quadratic_model.QuadraticModelParams()),\n",
98 | " ]\n",
99 | " )\n",
100 | "\n",
101 | "# create `KatsEnsembleParam` with detailed configurations \n",
102 | "KatsEnsembleParam = {\n",
103 | " \"models\": model_params,\n",
104 | " \"aggregation\": \"weightedavg\",\n",
105 | " \"seasonality_length\": 30,\n",
106 | " \"decomposition_method\": \"additive\",\n",
107 | "}"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "8379225b-c620-491f-8357-b4ac822deb50",
114 | "metadata": {
115 | "id": "8379225b-c620-491f-8357-b4ac822deb50"
116 | },
117 | "outputs": [],
118 | "source": [
119 | "from kats.consts import TimeSeriesData\n",
120 | "TARGET_COL = \"new_cases\"\n",
121 | "\n",
122 | "df_ts = TimeSeriesData(\n",
123 | " value=df[TARGET_COL], time=df[\"date\"]\n",
124 | ")"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "id": "b03eca14-47e3-4280-a178-3b78e9f4d58f",
131 | "metadata": {
132 | "colab": {
133 | "base_uri": "https://localhost:8080/"
134 | },
135 | "id": "b03eca14-47e3-4280-a178-3b78e9f4d58f",
136 | "outputId": "be5d4373-935c-497c-f706-d2522e37c5c5"
137 | },
138 | "outputs": [],
139 | "source": [
140 | "m = KatsEnsemble(\n",
141 | " data=df_ts, \n",
142 | " params=KatsEnsembleParam\n",
143 | ").fit()"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "id": "I-QkTf2kBiRE",
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/",
153 | "height": 419
154 | },
155 | "id": "I-QkTf2kBiRE",
156 | "outputId": "a42a5154-4f23-466b-cf00-9d9616ab84f4"
157 | },
158 | "outputs": [],
159 | "source": [
160 | "m.predict(steps=90).aggregate()"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "c6c05f39-8e78-40cd-a43c-0ae2eed3c200",
167 | "metadata": {
168 | "colab": {
169 | "base_uri": "https://localhost:8080/",
170 | "height": 426
171 | },
172 | "id": "c6c05f39-8e78-40cd-a43c-0ae2eed3c200",
173 | "outputId": "f9454627-6e0e-4217-e789-80802cd87d22"
174 | },
175 | "outputs": [],
176 | "source": [
177 | "m.predict(steps=90)\n",
178 | "m.aggregate()\n",
179 | "m.plot()\n",
180 | "plt.ylabel(TARGET_COL)"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "id": "22489efc-329a-43d7-8524-ebc46a7d0494",
187 | "metadata": {
188 | "id": "22489efc-329a-43d7-8524-ebc46a7d0494"
189 | },
190 | "outputs": [],
191 | "source": []
192 | }
193 | ],
194 | "metadata": {
195 | "colab": {
196 | "collapsed_sections": [],
197 | "name": "Kats",
198 | "provenance": []
199 | },
200 | "kernelspec": {
201 | "display_name": "Python 3",
202 | "language": "python",
203 | "name": "python3"
204 | },
205 | "language_info": {
206 | "codemirror_mode": {
207 | "name": "ipython",
208 | "version": 3
209 | },
210 | "file_extension": ".py",
211 | "mimetype": "text/x-python",
212 | "name": "python",
213 | "nbconvert_exporter": "python",
214 | "pygments_lexer": "ipython3",
215 | "version": "3.8.8"
216 | }
217 | },
218 | "nbformat": 4,
219 | "nbformat_minor": 5
220 | }
221 |
--------------------------------------------------------------------------------
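The Kats ensemble above combines a linear and a quadratic model with `"aggregation": "weightedavg"`. One common way to weight such an ensemble is by the inverse of each member's error on a holdout window. The NumPy sketch below does that for polynomial trend models; it illustrates the general idea under our own assumptions and is not a description of Kats' internal weighting. It also ignores the seasonality and decomposition options set in `KatsEnsembleParam`. `fit_trend` and `weighted_average_forecast` are hypothetical names.

```python
import numpy as np


def fit_trend(y, degree: int, horizon: int) -> np.ndarray:
    """Fit a polynomial trend in time to y and extrapolate it `horizon` steps."""
    t = np.arange(len(y))
    coefs = np.polyfit(t, y, degree)
    t_future = np.arange(len(y), len(y) + horizon)
    return np.polyval(coefs, t_future)


def weighted_average_forecast(y, degrees=(1, 2), horizon: int = 90, holdout: int = 30):
    """Inverse-error weighted average of polynomial trend forecasts."""
    y = np.asarray(y, dtype=float)
    train, valid = y[:-holdout], y[-holdout:]
    errors, forecasts = [], []
    for degree in degrees:
        # score each member on a holdout window it was not fitted on
        valid_pred = fit_trend(train, degree, holdout)
        errors.append(np.mean(np.abs(valid_pred - valid)))
        # refit on the full history for the actual forecast
        forecasts.append(fit_trend(y, degree, horizon))
    weights = 1.0 / (np.array(errors) + 1e-8)  # smaller error -> larger weight
    weights /= weights.sum()
    return weights, np.average(np.vstack(forecasts), axis=0, weights=weights)


# usage: on a purely quadratic series almost all weight goes to the quadratic member
t = np.arange(100)
weights, forecast = weighted_average_forecast(0.05 * t ** 2, horizon=90, holdout=30)
```

Weighting by holdout error rather than in-sample fit matters here: the quadratic member always fits the training data at least as well as the linear one, so only out-of-sample error can reward the simpler model when it generalizes better.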
/chapter7/Silverkite.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "colab": {
15 | "base_uri": "https://localhost:8080/"
16 | },
17 | "id": "ZQ3Ym1YDe2FQ",
18 | "outputId": "0be39c08-a455-4902-b876-9228b6a0eb42"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "!pip install greykite"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {
29 | "id": "1gEv7oq8e2FY"
30 | },
31 | "outputs": [],
32 | "source": [
33 | "import matplotlib.pyplot as plt\n",
34 | "import seaborn as sns\n",
35 | "%matplotlib inline\n",
36 | "\n",
37 | "plt.style.use('seaborn-whitegrid')\n",
38 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
39 | "plt.rcParams[\"font.size\"] = \"17\""
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {
46 | "id": "TECK_SIde2FX"
47 | },
48 | "outputs": [],
49 | "source": [
50 | "import pandas as pd\n",
51 | "\n",
52 | "owid_covid = pd.read_csv(\"https://covid.ourworldindata.org/data/owid-covid-data.csv\")\n",
53 | "owid_covid[\"date\"] = pd.to_datetime(owid_covid[\"date\"])"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "colab": {
61 | "base_uri": "https://localhost:8080/"
62 | },
63 | "id": "ezjWoCAxfZfX",
64 | "outputId": "60fb7d72-bed1-481d-96f9-97e86bd79f84"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "owid_covid.location.unique()"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {
75 | "colab": {
76 | "base_uri": "https://localhost:8080/",
77 | "height": 309
78 | },
79 | "id": "mtiFkfn9hEgV",
80 | "outputId": "6f198322-c183-4936-d2af-688dfb9316a0"
81 | },
82 | "outputs": [],
83 | "source": [
84 | "owid_covid.head()"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {
91 | "id": "JslH5iW-gthQ"
92 | },
93 | "outputs": [],
94 | "source": [
95 | "df = owid_covid[owid_covid.location == \"France\"].set_index(\"date\", drop=True).resample('D').interpolate(method='linear')"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {
102 | "id": "lPBBQ8eLe2Fb"
103 | },
104 | "outputs": [],
105 | "source": [
106 | "from greykite.framework.templates.autogen.forecast_config import (\n",
107 | " ForecastConfig, MetadataParam\n",
108 | ")\n",
109 | "\n",
110 | "metadata = MetadataParam(\n",
111 | " time_col=\"date\",\n",
112 | " value_col=\"new_cases\",\n",
113 | " freq=\"D\"\n",
114 | ")"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {
121 | "colab": {
122 | "base_uri": "https://localhost:8080/"
123 | },
124 | "id": "Bv40xaJde2Fc",
125 | "outputId": "8305ef45-89f0-4067-e2c5-96f3d9b7dfc6"
126 | },
127 | "outputs": [],
128 | "source": [
129 | "import warnings\n",
130 | "from greykite.framework.templates.forecaster import Forecaster\n",
131 | "from greykite.framework.templates.model_templates import ModelTemplateEnum\n",
132 | "\n",
133 | "forecaster = Forecaster()\n",
134 | "\n",
135 | "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
136 | "result = forecaster.run_forecast_config(\n",
137 | " df=df.reset_index(),\n",
138 | " config=ForecastConfig(\n",
139 | " model_template=ModelTemplateEnum.SILVERKITE_DAILY_90.name,\n",
140 | " forecast_horizon=90,\n",
141 | " coverage=0.95,\n",
142 | " metadata_param=metadata,\n",
143 | " )\n",
144 | ")"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {
151 | "colab": {
152 | "base_uri": "https://localhost:8080/",
153 | "height": 542
154 | },
155 | "id": "Djt_cYAPqdES",
156 | "outputId": "e22d6805-d3a5-4500-d923-3ea1e9778657"
157 | },
158 | "outputs": [],
159 | "source": [
160 | "forecast = result.forecast\n",
161 | "forecast.plot().show(renderer=\"colab\")  # drop the renderer argument when not running on Google Colab"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {
168 | "colab": {
169 | "base_uri": "https://localhost:8080/",
170 | "height": 204
171 | },
172 | "id": "6MGAevXyvEyA",
173 | "outputId": "132fb3e9-be00-485b-badb-05547f65a13b"
174 | },
175 | "outputs": [],
176 | "source": [
177 | "forecast.df.head().round(2)"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "colab": {
185 | "base_uri": "https://localhost:8080/",
186 | "height": 204
187 | },
188 | "id": "El0vJRace2Fe",
189 | "outputId": "816da56c-de9b-458b-bcd1-3956e3331169"
190 | },
191 | "outputs": [],
192 | "source": [
193 | "from collections import defaultdict\n",
194 | "\n",
195 | "backtest = result.backtest\n",
196 | "backtest_eval = defaultdict(list)\n",
197 | "for metric, value in backtest.train_evaluation.items():\n",
198 | " backtest_eval[metric].append(value)\n",
199 | " backtest_eval[metric].append(backtest.test_evaluation[metric])\n",
200 | "metrics = pd.DataFrame(backtest_eval, index=[\"train\", \"test\"]).T\n",
201 | "metrics.head()"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {
208 | "colab": {
209 | "base_uri": "https://localhost:8080/",
210 | "height": 241
211 | },
212 | "id": "oL9GMtJgtHGO",
213 | "outputId": "f7555d36-66b9-4f7d-a4f8-fdced9397806"
214 | },
215 | "outputs": [],
216 | "source": [
217 | "model = result.model\n",
218 | "future_df = result.timeseries.make_future_dataframe(\n",
219 | " periods=4,\n",
220 | " include_history=False\n",
221 | ")\n",
222 | "model.predict(future_df)"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {
229 | "colab": {
230 | "base_uri": "https://localhost:8080/",
231 | "height": 241
232 | },
233 | "id": "uIclGJzatLl9",
234 | "outputId": "ae1832c4-10df-4cae-c5f3-3ac9066f8745"
235 | },
236 | "outputs": [],
237 | "source": [
238 | "model.predict(future_df)"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": null,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": []
247 | }
248 | ],
249 | "metadata": {
250 | "colab": {
251 | "collapsed_sections": [
252 | "JHuPMCuAe2Ff",
253 | "xJe34gyie2Fi",
254 | "B9u12G4ne2Fm"
255 | ],
256 | "name": "Silverkite",
257 | "provenance": []
258 | },
259 | "kernelspec": {
260 | "display_name": "Python 3",
261 | "language": "python",
262 | "name": "python3"
263 | },
264 | "language_info": {
265 | "codemirror_mode": {
266 | "name": "ipython",
267 | "version": 3
268 | },
269 | "file_extension": ".py",
270 | "mimetype": "text/x-python",
271 | "name": "python",
272 | "nbconvert_exporter": "python",
273 | "pygments_lexer": "ipython3",
274 | "version": "3.8.8"
275 | }
276 | },
277 | "nbformat": 4,
278 | "nbformat_minor": 4
279 | }
280 |
--------------------------------------------------------------------------------
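The Silverkite run above requests `coverage=0.95`, i.e. prediction intervals meant to contain about 95% of future observations. A simple way to build such bands is to add empirical quantiles of the in-sample residuals to the point forecast. The sketch below is a simplified stand-in chosen for illustration, not Greykite's interval method: it assumes the residuals are roughly stationary and ignores how uncertainty grows with the horizon. `residual_interval` is a hypothetical helper.

```python
import numpy as np


def residual_interval(y_true, y_fitted, y_forecast, coverage: float = 0.95):
    """Prediction band from empirical quantiles of in-sample residuals."""
    residuals = np.asarray(y_true, dtype=float) - np.asarray(y_fitted, dtype=float)
    alpha = (1.0 - coverage) / 2.0
    # e.g. the 2.5% and 97.5% residual quantiles for 95% coverage
    lower_q, upper_q = np.quantile(residuals, [alpha, 1.0 - alpha])
    y_forecast = np.asarray(y_forecast, dtype=float)
    return y_forecast + lower_q, y_forecast + upper_q


# usage: residuals spread evenly from -50 to 50 around a zero fit
lower, upper = residual_interval(np.arange(-50, 51), np.zeros(101), [10.0, 20.0])
```

Because the band width is constant, it will usually be too narrow far into a 90-day horizon; the backtest's test-period metrics are a better guide to how reliable such intervals are.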
/chapter7/XGBoost.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | ""
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "id": "cjiXyuy5e2Fj",
15 | "outputId": "26e9d719-07e0-4a60-8005-c05a24e1e009"
16 | },
17 | "outputs": [],
18 | "source": [
19 | "!pip install xgboost"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {
26 | "id": "3CaAZqSGe2Fi"
27 | },
28 | "outputs": [],
29 | "source": [
30 | "import matplotlib.pyplot as plt\n",
31 | "import seaborn as sns\n",
32 | "%matplotlib inline\n",
33 | "\n",
34 | "plt.style.use('seaborn-whitegrid')\n",
35 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
36 | "plt.rcParams[\"font.size\"] = \"17\""
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {
43 | "id": "UpEJozDdyizN"
44 | },
45 | "outputs": [],
46 | "source": [
47 | "import pandas as pd\n",
48 | "\n",
49 | "owid_covid = pd.read_csv(\"https://covid.ourworldindata.org/data/owid-covid-data.csv\")\n",
50 | "owid_covid[\"date\"] = pd.to_datetime(owid_covid[\"date\"])\n",
51 | "df = owid_covid[owid_covid.location == \"France\"].set_index(\"date\", drop=True).resample('D').interpolate(method='linear').reset_index()"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {
58 | "id": "WqB6Vwn3e2Fj"
59 | },
60 | "outputs": [],
61 | "source": [
62 | "from sklearn.base import TransformerMixin, BaseEstimator\n",
63 | "from typing import List\n",
64 | "\n",
65 | "class DateFeatures(TransformerMixin, BaseEstimator):\n",
66 | " \"\"\"DateFeatures transformer.\"\"\"\n",
67 | " features = [\n",
68 | " \"hour\",\n",
69 | " \"year\",\n",
70 | " \"day\",\n",
71 | " \"weekday\",\n",
72 | " \"month\",\n",
73 | " \"quarter\",\n",
74 | " ]\n",
75 | " \n",
76 | " def __init__(self):\n",
77 | " \"\"\"Nothing much to do.\"\"\"\n",
78 | " super().__init__()\n",
79 | " self.feature_names: List[str] = []\n",
80 | "\n",
81 | " def get_feature_names(self):\n",
82 | " \"\"\"Feature names.\"\"\"\n",
83 | " return self.feature_names\n",
84 | " \n",
85 | " def transform(self, df: pd.DataFrame):\n",
86 | " \"\"\"Annotate date features.\"\"\"\n",
87 | " Xt = []\n",
88 | " for col in df.columns:\n",
89 | " for feature in self.features:\n",
90 | " date_feature = getattr(\n",
91 | " getattr(\n",
92 | " df[col], \"dt\"\n",
93 | " ), feature\n",
94 | " )\n",
95 | " date_feature.name = f\"{col}_{feature}\"\n",
96 | " Xt.append(date_feature)\n",
97 | " \n",
98 | " df2 = pd.concat(Xt, axis=1)\n",
99 | " self.feature_names = list(df2.columns)\n",
100 | " return df2\n",
101 | "\n",
102 | " def fit(self, df: pd.DataFrame, y=None, **fit_params):\n",
103 | " \"\"\"No fitting needed.\"\"\"\n",
104 | " return self"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {
111 | "id": "IhEkSi5azUrM"
112 | },
113 | "outputs": [],
114 | "source": [
115 | "import numpy as np\n",
116 | "from sklearn.base import TransformerMixin, BaseEstimator\n",
117 | "from typing import Dict, List\n",
118 | "\n",
119 | "class CyclicalFeatures(TransformerMixin, BaseEstimator):\n",
120 | " \"\"\"CyclicalFeatures transformer.\"\"\"\n",
121 | " \n",
122 | " def __init__(self, max_vals: Dict[str, float] = {}):\n",
123 | " \"\"\"Nothing much to do.\"\"\"\n",
124 | " super().__init__()\n",
125 | " self.feature_names: List[str] = []\n",
126 | " self.max_vals = max_vals\n",
127 | "\n",
128 | " def get_feature_names(self):\n",
129 | " \"\"\"Feature names.\"\"\"\n",
130 | " return self.feature_names\n",
131 | " \n",
132 | " def transform(self, df: pd.DataFrame):\n",
133 | " \"\"\"Annotate date features.\"\"\"\n",
134 | " Xt = []\n",
135 | " for col in df.columns:\n",
136 | " if col in self.max_vals:\n",
137 | " max_val = self.max_vals[col]\n",
138 | " else:\n",
139 | " max_val = df[col].max()\n",
140 | " for fun_name, fun in [(\"cos\", np.cos), (\"sin\", np.sin)]:\n",
141 | " date_feature = fun(2 * np.pi * df[col] / max_val)\n",
142 | " date_feature.name = f\"{col}_{fun_name}\"\n",
143 | " Xt.append(date_feature)\n",
144 | " \n",
145 | " df2 = pd.concat(Xt, axis=1)\n",
146 | " self.feature_names = list(df2.columns)\n",
147 | " return df2\n",
148 | "\n",
149 | " def fit(self, df: pd.DataFrame, y=None, **fit_params):\n",
150 | " \"\"\"No fitting needed.\"\"\"\n",
151 | " return self"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {
158 | "id": "87bELgcAziFi"
159 | },
160 | "outputs": [],
161 | "source": [
162 | "from sklearn.compose import ColumnTransformer\n",
163 | "from sklearn.pipeline import Pipeline, make_pipeline\n",
164 | "#from sklearn import linear_model\n",
165 | "\n",
166 | "\n",
167 | "preprocessor = ColumnTransformer(\n",
168 | " transformers=[(\n",
169 | " \"date\",\n",
170 | " make_pipeline(\n",
171 | " DateFeatures(),\n",
172 | " ColumnTransformer(transformers=[\n",
173 | " (\"cyclical\", CyclicalFeatures(max_vals={\"date_weekday\": 7}),  # weekdays 0-6 repeat every 7 days\n",
174 | " [\"date_day\", \"date_weekday\", \"date_month\"]\n",
175 | " )\n",
176 | " ], remainder=\"passthrough\")\n",
177 | " ), [\"date\"],\n",
178 | " ),], remainder=\"passthrough\"\n",
179 | ")\n",
180 | "\n",
181 | "pipeline = Pipeline(\n",
182 | " [\n",
183 | " (\"preprocessing\", preprocessor),\n",
184 | " #(\"clf\", linear_model.LinearRegression(),),\n",
185 | " ]\n",
186 | ")"
187 | ]
188 | },
189 | {
190 | "cell_type": "code",
191 | "execution_count": null,
192 | "metadata": {
193 | "colab": {
194 | "base_uri": "https://localhost:8080/"
195 | },
196 | "id": "UvhfdtVre2Fj",
197 | "outputId": "453afa7a-c3f5-4a51-f693-55d2b7698282"
198 | },
199 | "outputs": [],
200 | "source": [
201 | "FEATURE_COLS = [\"date\"]\n",
202 | "date_features = pipeline.fit_transform(df[FEATURE_COLS])\n",
203 | "date_features"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {
210 | "id": "WmhAZE-Me2Fk"
211 | },
212 | "outputs": [],
213 | "source": [
214 | "TRAIN_SIZE = int(len(df) * 0.9)\n",
215 | "HORIZON = 1\n",
216 | "TARGET_COL = \"new_cases\"\n",
217 | "\n",
218 | "X_train, X_test = df.iloc[HORIZON:TRAIN_SIZE], df.iloc[TRAIN_SIZE+HORIZON:]\n",
219 | "y_train = df.shift(periods=HORIZON).iloc[HORIZON:TRAIN_SIZE][TARGET_COL]\n",
220 | "y_test = df.shift(periods=HORIZON).iloc[TRAIN_SIZE+HORIZON:][TARGET_COL]"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {
227 | "colab": {
228 | "base_uri": "https://localhost:8080/"
229 | },
230 | "id": "aA7GL52je2Fl",
231 | "outputId": "d19d55fd-a683-4eea-a087-5676a4333077"
232 | },
233 | "outputs": [],
234 | "source": [
235 | "from xgboost import XGBRegressor\n",
236 | "\n",
237 | "pipeline = Pipeline(\n",
238 | " [\n",
239 | " (\"preprocessing\", preprocessor),\n",
240 | " (\"xgb\", XGBRegressor(objective=\"reg:squarederror\", n_estimators=1000))\n",
241 | " ]\n",
242 | ")\n",
243 | "pipeline.fit(X_train[FEATURE_COLS], y_train)"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {
250 | "id": "6b3PUD4CH5Gx"
251 | },
252 | "outputs": [],
253 | "source": [
254 | "MAX_HORIZON = 90\n",
255 | "X_test_horizon = pd.Series(pd.date_range(\n",
256 | " start=df.date.min(), \n",
257 | " periods=len(df) + MAX_HORIZON,\n",
258 | " name=\"date\"\n",
259 | ")).reset_index()"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "metadata": {
266 | "id": "t8VoXOzCe2Fl"
267 | },
268 | "outputs": [],
269 | "source": [
270 | "forecasted = pd.concat(\n",
271 | " [pd.Series(pipeline.predict(X_test_horizon[FEATURE_COLS])), pd.Series(X_test_horizon.date)],\n",
272 | " axis=1\n",
273 | ")\n",
274 | "forecasted.columns = [TARGET_COL, \"date\"]"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": null,
280 | "metadata": {
281 | "id": "7FGnoSbre2Fl"
282 | },
283 | "outputs": [],
284 | "source": [
285 | "actual = pd.concat(\n",
286 | " [pd.Series(df[TARGET_COL]), pd.Series(df.date)],\n",
287 | " axis=1\n",
288 | ")\n",
289 | "actual.columns = [TARGET_COL, \"date\"]"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "metadata": {
296 | "colab": {
297 | "base_uri": "https://localhost:8080/",
298 | "height": 443
299 | },
300 | "id": "XtTKGkoTe2Fm",
301 | "outputId": "4c3c830f-3df9-4c5b-ad06-41cd623e5e70"
302 | },
303 | "outputs": [],
304 | "source": [
305 | "fig, ax = plt.subplots(figsize=(12, 6))\n",
306 | "forecasted.set_index(\"date\").plot(linestyle='--', ax=ax)\n",
307 | "actual.set_index(\"date\").plot(linestyle='-.', ax=ax)\n",
308 | "plt.legend([\"forecast\", \"actual\"])"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": null,
314 | "metadata": {
315 | "colab": {
316 | "base_uri": "https://localhost:8080/"
317 | },
318 | "id": "B4oPSOTse2F1",
319 | "outputId": "ad6760a9-e9a8-4cd8-812f-39fce7c4e636"
320 | },
321 | "outputs": [],
322 | "source": [
323 | "import numpy as np\n",
324 | "from sklearn.metrics import mean_squared_error\n",
325 | "test_data = actual.merge(forecasted, on=\"date\", suffixes=(\"_actual\", \"_predicted\"))\n",
326 | "test_data = test_data[test_data.date.isin(df.date.iloc[TRAIN_SIZE:])]  # score held-out dates only\n",
327 | "rmse = np.sqrt(mean_squared_error(test_data.new_cases_actual, test_data.new_cases_predicted))\n",
328 | "print(\"The root mean squared error (RMSE) on test set: {:.2f}\".format(rmse))"
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": null,
334 | "metadata": {
335 | "id": "HU1DQT1o5Vpn"
336 | },
337 | "outputs": [],
338 | "source": []
339 | }
340 | ],
341 | "metadata": {
342 | "colab": {
343 | "collapsed_sections": [],
344 | "name": "XGBoost",
345 | "provenance": []
346 | },
347 | "kernelspec": {
348 | "display_name": "Python 3",
349 | "language": "python",
350 | "name": "python3"
351 | },
352 | "language_info": {
353 | "codemirror_mode": {
354 | "name": "ipython",
355 | "version": 3
356 | },
357 | "file_extension": ".py",
358 | "mimetype": "text/x-python",
359 | "name": "python",
360 | "nbconvert_exporter": "python",
361 | "pygments_lexer": "ipython3",
362 | "version": "3.8.8"
363 | }
364 | },
365 | "nbformat": 4,
366 | "nbformat_minor": 4
367 | }
368 |
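An inner `merge` on `date`, as in the RMSE cell above, drops horizon dates that have no observation but keeps every in-sample date, so a score reported as a test-set error also needs the rows limited to dates after the training window. The sketch below illustrates that restriction in plain Python. `holdout_rmse`, the dictionaries, and `train_size` are hypothetical stand-ins for the notebook's DataFrames and `TRAIN_SIZE`, and the values are synthetic.

```python
import math
from datetime import date, timedelta


def holdout_rmse(actual, forecast, test_dates):
    """Hypothetical helper: RMSE over test_dates only.

    actual, forecast: dicts mapping date -> value.
    Dates missing from either series are skipped, like an inner merge.
    """
    common = [d for d in test_dates if d in actual and d in forecast]
    if not common:
        raise ValueError("no overlapping test dates to score")
    squared_errors = [(actual[d] - forecast[d]) ** 2 for d in common]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Synthetic example: four observed days, and a forecast that runs one day past the data.
start = date(2021, 1, 1)
days = [start + timedelta(days=i) for i in range(5)]
actual = {days[0]: 1.0, days[1]: 2.0, days[2]: 3.0, days[3]: 4.0}
forecast = {days[0]: 1.0, days[1]: 2.0, days[2]: 5.0, days[3]: 4.0, days[4]: 6.0}

train_size = 2                  # hypothetical split: the first two days are "training"
test_dates = days[train_size:]  # days[4] has no actual value and is skipped
print(holdout_rmse(actual, forecast, test_dates))  # prints 1.4142135623730951
```

Scoring all four observed days instead (`holdout_rmse(actual, forecast, days)`) gives 1.0, because the two perfectly fitted training days dilute the single test error.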
--------------------------------------------------------------------------------
/chapter8/Drift_Detection.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "6d73980e-a0c4-4163-8819-1d70e60c8454",
6 | "metadata": {},
7 | "source": [
8 | ""
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "4546ab51-c2cb-4261-ae78-fdbd6d6a07ac",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "!pip install river"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "id": "3caf1ccd-709a-49d2-8f0a-36b54bbc9cf9",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import numpy as np\n",
29 | "\n",
30 | "np.random.seed(12345)\n",
31 | "data_stream = np.concatenate(\n",
32 | " (np.random.randint(2, size=1000), np.random.randint(8, size=1000))\n",
33 | ")"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "id": "c0b2fb45-abc9-4ac4-9e10-78405620c3d3",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "import matplotlib.pyplot as plt\n",
44 | "\n",
45 | "plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available else 'seaborn-whitegrid')\n",
46 | "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
47 | "plt.rcParams[\"font.size\"] = \"17\""
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "d1551a3f-cf55-4269-a1da-c697ccb8a271",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "plt.figure(figsize=(16, 6))\n",
58 | "plt.plot(data_stream)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "034a8574-c193-49f3-ab23-44cd7c2e2c2d",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "def perform_test(drift_detector, data_stream):\n",
69 | " detected_indices = []\n",
70 | " for i, val in enumerate(data_stream):\n",
71 | " drift_detector.update(val)  # recent river releases expose the flag as an attribute\n",
72 | " if drift_detector.drift_detected:\n",
73 | " detected_indices.append(i)\n",
74 | " return detected_indices"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "id": "aa3f4530-72b5-4733-8db7-e6e71de8f292",
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "import matplotlib.pyplot as plt\n",
85 | "from river.drift import ADWIN, PageHinkley\n",
86 | "\n",
87 | "def show_drift(data_stream, indices):\n",
88 | " fig, ax = plt.subplots(figsize=(16, 6))\n",
89 | " ax.plot(data_stream)\n",
90 | " ax.plot(\n",
91 | " indices,\n",
92 | " data_stream[indices],\n",
93 | " \"r\",\n",
94 | " alpha=0.6,\n",
95 | " marker=r'$\\circ$',\n",
96 | " markersize=22,\n",
97 | " linewidth=4\n",
98 | " )\n",
99 | " plt.tight_layout()\n",
100 | "\n",
101 | "\n",
102 | "detected_indices = perform_test(PageHinkley(), data_stream)\n",
103 | "show_drift(data_stream, detected_indices)\n",
104 | "plt.title(\"Page-Hinkley Drift Detection\");"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "id": "2e19844c-b56f-48cf-8cba-75f2b431804f",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "detected_indices = perform_test(ADWIN(), data_stream)\n",
115 | "show_drift(data_stream, detected_indices)\n",
116 | "plt.title(\"ADWIN Drift Detection\");"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "id": "47d3a7da-f866-4192-8c7d-46261312bc15",
123 | "metadata": {},
124 | "outputs": [],
125 | "source": []
126 | }
127 | ],
128 | "metadata": {
129 | "kernelspec": {
130 | "display_name": "Python 3",
131 | "language": "python",
132 | "name": "python3"
133 | },
134 | "language_info": {
135 | "codemirror_mode": {
136 | "name": "ipython",
137 | "version": 3
138 | },
139 | "file_extension": ".py",
140 | "mimetype": "text/x-python",
141 | "name": "python",
142 | "nbconvert_exporter": "python",
143 | "pygments_lexer": "ipython3",
144 | "version": "3.8.8"
145 | }
146 | },
147 | "nbformat": 4,
148 | "nbformat_minor": 5
149 | }
150 |
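Both detectors above are used as black boxes. To make the Page-Hinkley idea concrete, here is a simplified, pure-Python sketch of the classic one-sided test for an increase in the mean: it accumulates m_t = sum of (x_i - mean_i - delta), tracks the running minimum M_t, and raises an alarm when m_t - M_t exceeds a threshold lambda. `SimplePageHinkley` and `detect` are hypothetical names, the defaults (delta = 0.005, lambda = 50, 30 warm-up points) are assumptions, and river's `PageHinkley` may differ in details (for example, it can apply a forgetting factor), so its alarms will not match exactly. The stream is a deterministic stand-in for the notebook's random data: mean 0.5 for 1,000 points, then mean 3.5.

```python
class SimplePageHinkley:
    """Hypothetical, simplified one-sided Page-Hinkley test (upward mean shift)."""

    def __init__(self, delta=0.005, threshold=50.0, min_instances=30):
        self.delta = delta                  # change magnitude tolerated without alarm
        self.threshold = threshold          # lambda: alarm when m_t - M_t exceeds it
        self.min_instances = min_instances  # warm-up before any alarm is allowed
        self.reset()

    def reset(self):
        self.n = 0
        self.mean = 0.0     # running mean of the points seen since the last reset
        self.cum = 0.0      # m_t = sum of (x_i - mean_i - delta)
        self.cum_min = 0.0  # M_t = running minimum of m (starting from 0)
        self.drift_detected = False

    def update(self, x):
        self.n += 1
        self.mean += (x - self.mean) / self.n  # incremental mean update
        self.cum += x - self.mean - self.delta
        self.cum_min = min(self.cum_min, self.cum)
        self.drift_detected = (
            self.n >= self.min_instances
            and self.cum - self.cum_min > self.threshold
        )
        return self.drift_detected


def detect(detector, stream):
    """Return the indices where the detector fires, restarting it after each alarm."""
    hits = []
    for i, x in enumerate(stream):
        if detector.update(x):
            hits.append(i)
            detector.reset()
    return hits


# Deterministic stand-in for the notebook's stream: mean 0.5, then mean 3.5 from index 1000.
stream = [0, 1] * 500 + list(range(8)) * 125
hits = detect(SimplePageHinkley(), stream)
print(hits)  # expected: a single alarm shortly after index 1000
```

Resetting after each alarm lets the detector restart on the new regime; without the reset, m_t would keep rising above M_t for a long stretch and flag many consecutive points.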
--------------------------------------------------------------------------------
/chapter8/Online_Learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "6d0ec63d-ef95-48fa-b83d-eef4e02733e0",
6 | "metadata": {},
7 | "source": [
8 | ""
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "514a7acc-81bf-49cb-8fea-632387aa44bf",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "!pip install river"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "id": "7e34e1a6-2d8d-4347-b96b-6927c210cb85",
24 | "metadata": {},
25 | "source": [
26 | "# Regression"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "id": "d410bd23-82aa-4291-b034-16e9dc0cf606",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "from river import stream\n",
37 | "from river.datasets import base\n",
38 | "\n",
39 | "\n",
40 | "class SolarFlare(base.FileDataset):\n",
41 | " def __init__(self):\n",
42 | " super().__init__(\n",
43 | " n_samples=1066,\n",
44 | " n_features=10,\n",
45 | " n_outputs=1,\n",
46 | " task=base.MO_REG,\n",
47 | " filename=\"solar-flare.csv.zip\",\n",
48 | " )\n",
49 | "\n",
50 | " def __iter__(self):\n",
51 | " return stream.iter_csv(\n",
52 | " self.path,\n",
53 | " target=\"m-class-flares\",\n",
54 | " converters={\n",
55 | " \"zurich-class\": str,\n",
56 | " \"largest-spot-size\": str,\n",
57 | " \"spot-distribution\": str,\n",
58 | " \"activity\": int,\n",
59 | " \"evolution\": int,\n",
60 | " \"previous-24h-flare-activity\": int,\n",
61 | " \"hist-complex\": int,\n",
62 | " \"hist-complex-this-pass\": int,\n",
63 | " \"area\": int,\n",
64 | " \"largest-spot-area\": int,\n",
65 | " \"c-class-flares\": int,\n",
66 | " \"m-class-flares\": int,\n",
67 | " \"x-class-flares\": int,\n",
68 | " },\n",
69 | " )\n",
70 | " "
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "id": "6a8bdf8c-7a6f-4a47-a035-50ba4861dfd2",
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "from pprint import pprint\n",
81 | "from river import datasets\n",
82 | "\n",
83 | "for x, y in SolarFlare():\n",
84 | " pprint(x)\n",
85 | " pprint(y)\n",
86 | " break"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "id": "db7b13ad-4303-443d-840f-00aef7f99f51",
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "import numbers\n",
97 | "from river import compose\n",
98 | "from river import datasets\n",
99 | "from river import evaluate\n",
100 | "from river import linear_model\n",
101 | "from river import metrics\n",
102 | "from river import preprocessing\n",
103 | "from river import tree\n",
104 | "\n",
105 | "dataset = SolarFlare()\n",
106 | "num = compose.SelectType(numbers.Number) | preprocessing.MinMaxScaler()\n",
107 | "cat = compose.SelectType(str) | preprocessing.OneHotEncoder(sparse=False)\n",
108 | "model = tree.HoeffdingTreeRegressor()\n",
109 | "pipeline = (num + cat) | model\n",
110 | "metric = metrics.MAE()\n",
111 | "#evaluate.progressive_val_score(dataset, pipeline, metric)"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "id": "3597bf77-9fdb-4358-bb6b-9543a65881c2",
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "import matplotlib.pyplot as plt\n",
122 | "\n",
123 | "errors = []\n",
124 | "for x, y in SolarFlare():\n",
125 | " y_pred = pipeline.predict_one(x)\n",
126 | " metric.update(y, y_pred)\n",
127 | " errors.append(metric.get())\n",
128 | " pipeline.learn_one(x, y)\n",
129 | "\n",
130 | "fig, ax = plt.subplots(figsize=(16, 6))\n",
131 | "ax.plot(\n",
132 | " errors,\n",
133 | " \"ro\",\n",
134 | " alpha=0.6,\n",
135 | " markersize=2,\n",
136 | " linewidth=4\n",
137 | ")"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "id": "30f435e0-56ed-448e-ba40-b53c5dcf3ef5",
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "fig, ax = plt.subplots(figsize=(16, 6))\n",
148 | "ax.plot(\n",
149 | " errors,\n",
150 | " \"ro\",\n",
151 | " alpha=0.6,\n",
152 | " markersize=2,\n",
153 | " linewidth=4\n",
154 | ")\n",
155 | "ax.set_xlabel(\"number of points\")\n",
156 | "ax.set_ylabel(\"MAE\");"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "id": "7a709167-335a-4968-b8ce-c1faee11af30",
162 | "metadata": {},
163 | "source": [
164 | "# Adaptive Models on a Concept Drift Data Stream"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "id": "b0b6a6f8-48d0-478c-825a-d2ed313e7d6b",
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "from river import (\n",
175 | " synth, ensemble, tree,\n",
176 | " evaluate, metrics\n",
177 | ")\n",
178 | "\n",
179 | "models = [\n",
180 | " tree.HoeffdingTreeRegressor(),\n",
181 | " tree.HoeffdingAdaptiveTreeRegressor(),\n",
182 | " ensemble.AdaptiveRandomForestRegressor(seed=42),\n",
183 | "]\n",
184 | "\n",
185 | "results = {}\n",
186 | "for model in models:\n",
187 | " metric = metrics.MSE()\n",
188 | " errors = []\n",
189 | " dataset = synth.ConceptDriftStream(\n",
190 | " seed=42, position=500, width=40\n",
191 | " ).take(1000) \n",
192 | " for i, (x, y) in enumerate(dataset):\n",
193 | " y_pred = model.predict_one(x)\n",
194 | " metric.update(y, y_pred)\n",
195 | " model.learn_one(x, y)\n",
196 | " if (i % 100) == 0:\n",
197 | " errors.append(dict(step=i, error=metric.get()))\n",
198 | " results[type(model).__name__] = errors"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": null,
204 | "id": "3e3a3135-fc1e-4f38-b2f0-97b945bc3ea6",
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "results"
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "id": "024745ef-b4ef-4ee8-ac08-9c839043ec6d",
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "import pandas as pd\n",
219 | "import seaborn as sns\n",
220 | "\n",
221 | "plt.figure(figsize=(16, 6))\n",
222 | "styles = [\"-\",\"--\",\"-.\",\":\"]\n",
223 | "markers = [\n",
224 | " '.', ',', 'o', 'v', '^', '<', '>',\n",
225 | " '1', '2', '3', '4', '8', 's', 'p',\n",
226 | " '*', 'h', 'H', '+', 'x', 'D', 'd',\n",
227 | " '|', '_', 'P', 'X', 0, 1, 2, 3, 4,\n",
228 | " 5, 6, 7, 8, 9, 10, 11\n",
229 | "]\n",
230 | "\n",
231 | "for i, (model, errors) in enumerate(results.items()):\n",
232 | " df = pd.DataFrame(errors)\n",
233 | " sns.lineplot(\n",
234 | " data=df,\n",
235 | " x=\"step\",\n",
236 | " y=\"error\",\n",
237 | " linestyle=styles[i%len(styles)],\n",
238 | " alpha=0.5,\n",
239 | " markersize=22,\n",
240 | " markers=markers[i%len(markers)], \n",
241 | " label=model,\n",
242 | " linewidth=4\n",
243 | " )\n",
244 | " \n",
245 | "plt.ylabel(\"MSE\")\n",
246 | "plt.xlabel(\"Step\")\n",
247 | "sns.set_style(\"ticks\")"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "id": "ad7c4902-96ae-4abc-b48b-801929473533",
254 | "metadata": {},
255 | "outputs": [],
256 | "source": []
257 | }
258 | ],
259 | "metadata": {
260 | "kernelspec": {
261 | "display_name": "Python 3",
262 | "language": "python",
263 | "name": "python3"
264 | },
265 | "language_info": {
266 | "codemirror_mode": {
267 | "name": "ipython",
268 | "version": 3
269 | },
270 | "file_extension": ".py",
271 | "mimetype": "text/x-python",
272 | "name": "python",
273 | "nbconvert_exporter": "python",
274 | "pygments_lexer": "ipython3",
275 | "version": "3.8.8"
276 | }
277 | },
278 | "nbformat": 4,
279 | "nbformat_minor": 5
280 | }
281 |
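The evaluation loops above follow the test-then-train (progressive validation) protocol: predict on each example, score that prediction, and only then learn from the label. The sketch below reproduces the protocol without river so the order of operations is explicit. `OnlineLinearRegression`, `RunningMAE`, and `progressive_validation` are hypothetical helpers (plain SGD on squared error, not river's `HoeffdingTreeRegressor`), and the stream is synthetic, noise-free data from y = 2*x1 - 3*x2 + 1. Here `learn_one` updates the model in place and returns nothing, as in recent river releases, so its result should not be assigned back to the model variable.

```python
import random


class OnlineLinearRegression:
    """Hypothetical helper: linear regression fitted by plain SGD on dict features."""

    def __init__(self, lr=0.1):
        self.lr = lr       # learning rate (assumed value)
        self.weights = {}  # feature name -> weight
        self.bias = 0.0

    def predict_one(self, x):
        return self.bias + sum(self.weights.get(k, 0.0) * v for k, v in x.items())

    def learn_one(self, x, y):
        # Gradient of 0.5 * (y_hat - y)**2 with respect to y_hat is (y_hat - y).
        err = self.predict_one(x) - y
        self.bias -= self.lr * err
        for k, v in x.items():
            self.weights[k] = self.weights.get(k, 0.0) - self.lr * err * v
        # No return value: the model is updated in place.


class RunningMAE:
    """Hypothetical helper: mean absolute error over all points seen so far."""

    def __init__(self):
        self.n = 0
        self.total = 0.0

    def update(self, y_true, y_pred):
        self.n += 1
        self.total += abs(y_true - y_pred)

    def get(self):
        return self.total / self.n if self.n else 0.0


def progressive_validation(model, stream, metric):
    history = []
    for x, y in stream:
        y_pred = model.predict_one(x)  # 1) predict before the label is revealed
        metric.update(y, y_pred)       # 2) score that prediction
        model.learn_one(x, y)          # 3) only then learn from the example
        history.append(metric.get())
    return history


def synthetic_stream(n, seed=0):
    rng = random.Random(seed)
    for _ in range(n):
        x = {"x1": rng.random(), "x2": rng.random()}
        yield x, 2 * x["x1"] - 3 * x["x2"] + 1  # noise-free linear target


model = OnlineLinearRegression()
history = progressive_validation(model, synthetic_stream(2000), RunningMAE())
print(f"running MAE after 100 points: {history[99]:.3f}, after 2000: {history[-1]:.3f}")
```

Because every prediction is scored before the model sees its label, the running MAE is an honest out-of-sample estimate even though no separate test set is held back.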
--------------------------------------------------------------------------------