├── Econometric_methods.ipynb ├── NeuralNetworks.ipynb ├── README.md ├── SimpleLinearRegression.ipynb ├── XGB-features.ipynb ├── XGB-simple.ipynb └── rolling.py /Econometric_methods.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Econometric models\n", 8 | "This is a collection of classic econometric models that were used for the kaggle competition.\n", 9 | "For more info, read README and our blog post." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%matplotlib inline\n", 19 | "import os, os.path\n", 20 | "import itertools\n", 21 | "import pickle\n", 22 | "import time\n", 23 | "import numpy as np\n", 24 | "import pandas as pd\n", 25 | "from pyramid.arima import auto_arima\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import seaborn as sns\n", 28 | "import statsmodels.formula.api as smf\n", 29 | "import statsmodels.tsa.api as smt\n", 30 | "import statsmodels.api as sm\n", 31 | "from scipy.stats import describe\n", 32 | "import warnings" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "pd.options.display.max_columns = 12\n", 42 | "pd.options.display.max_rows = 24\n", 43 | "\n", 44 | "# disable warnings in Anaconda\n", 45 | "warnings.simplefilter('ignore')\n", 46 | "\n", 47 | "# plots inisde jupyter notebook\n", 48 | "%matplotlib inline\n", 49 | "\n", 50 | "sns.set(style='darkgrid', palette='muted')\n", 51 | "color_scheme = {\n", 52 | " 'red': '#F1637A',\n", 53 | " 'green': '#6ABB3E',\n", 54 | " 'blue': '#3D8DEA',\n", 55 | " 'black': '#000000'\n", 56 | "}" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Seasonality" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | 
"metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "additive = {\"op\": lambda a,b: a+b, \"inv\": lambda a,b: a-b}\n", 73 | "multiplicative = {\"op\": lambda a,b: a*b, \"inv\": lambda a,b: a/b}" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "Choose seasonality model from above:" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "seasonality_model = additive\n", 90 | "#seasonality_model = multiplicative" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "categorize_by_week_of_year = lambda df: df.index.dayofyear // 7\n", 100 | "categorize_by_day_of_week = lambda df: df.index.dayofweek" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "def compute_seasonality(series, categorization):\n", 110 | " \"\"\"\n", 111 | " Computes seasonal component parameters based on provided series.\n", 112 | " \n", 113 | " :type series: pd.Series\n", 114 | " :param categorization: Function used to split values into various periods of the season.\n", 115 | " :type categorization: pd.DataFrame -> some categorical type, eg. 
int\n", 116 | " \"\"\"\n", 117 | " df = pd.DataFrame()\n", 118 | " df[\"values\"] = series\n", 119 | " df.index = series.index\n", 120 | " df[\"cat\"] = categorization(df)\n", 121 | " return df.groupby(by=\"cat\")[\"values\"].mean()" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 7, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "def alter_series_by_season(series, categorization, seasonality, op):\n", 131 | " df = pd.DataFrame()\n", 132 | " df[\"values\"] = series\n", 133 | " df[\"values\"] = series\n", 134 | " df[\"cat\"] = categorization(df)\n", 135 | " return op(df[\"values\"], df[\"cat\"].map(seasonality))" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "def add_seasonal_component(series, categorization, seasonality):\n", 145 | " \"\"\"\n", 146 | " Add previously computed seasonal component back into a deseasonalized series.\n", 147 | " \n", 148 | " :type series: pd.Series\n", 149 | " :param categorization: Function used to split values into various periods of the season.\n", 150 | " :type categorization: pd.DataFrame -> some categorical type, eg. 
int\n", 151 | " :param seasonality: value returned from compute_seasonality\n", 152 | " :returns: Series with added seasonal component.\n", 153 | " :rtype: pd.Series\n", 154 | " \"\"\"\n", 155 | " return alter_series_by_season(series, categorization, seasonality, seasonality_model[\"op\"])\n" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 9, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "def remove_seasonal_component(series, categorization, seasonality):\n", 165 | " \"\"\"\n", 166 | " Try removing a previously computed seasonal component.\n", 167 | " \n", 168 | " :type series: pd.Series\n", 169 | " :param categorization: Function used to split values into various periods of the season.\n", 170 | " :type categorization: pd.DataFrame -> some categorical type, eg. int\n", 171 | " :param seasonality: value returned from compute_seasonality\n", 172 | " :returns: Deseasonalized series.\n", 173 | " :rtype: pd.Series\n", 174 | " \"\"\"\n", 175 | " return alter_series_by_season(series, categorization, seasonality, seasonality_model[\"inv\"])" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "## Holt-Winters and SARIMA Models Utils" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 10, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "def train_and_forecast(data, categorization, trainer, forecaster, deseasonize, steps_to_forecast=90):\n", 192 | " \"\"\"\n", 193 | " Split input data, deseasonalizes train data,\n", 194 | " train using trainer (data -> model),\n", 195 | " forecast using forecaster\n", 196 | " \n", 197 | " predicts values and applies seasonalization and returns predicted vs actual values\n", 198 | " \n", 199 | " :param data: dataset with the training data\n", 200 | " :param categorization: Function used to split values into various periods of the season.\n", 201 | " :type categorization: pd.DataFrame -> some 
categorical type, eg. int\n", 202 | " :param trainer: Function used to train the model\n", 203 | " :type trainer: pd.DataFrame -> model\n", 204 | " :param forecaster: (model, steps) -> prediction\n", 205 | " :param steps_to_forecast: number of steps to forecast\n", 206 | " :returns: a dataframe with:\n", 207 | " date (index)\n", 208 | " real_values - true values\n", 209 | " forecast - forecasted values\n", 210 | " \"\"\"\n", 211 | " \n", 212 | " #prepare training and validation datasets\n", 213 | " df_train = data.iloc[:-365].copy()\n", 214 | " df_validation = data.iloc[-365:].copy()\n", 215 | " df_validation.index = pd.DatetimeIndex(df_validation[\"date\"])\n", 216 | " df_train.index = pd.DatetimeIndex(df_train[\"date\"])\n", 217 | " \n", 218 | " if deseasonize:\n", 219 | " seas = compute_seasonality(df_train[\"sales\"], categorization)\n", 220 | " series = remove_seasonal_component(df_train[\"sales\"], categorization, seas)\n", 221 | " df_train[\"sales\"] = series\n", 222 | " \n", 223 | " df_train = df_train.reset_index(drop=True)\n", 224 | "\n", 225 | " # train\n", 226 | " model = trainer(df_train)\n", 227 | " \n", 228 | " # forecast\n", 229 | " forecast = forecaster(model, steps_to_forecast)\n", 230 | " \n", 231 | " # Create the pandas series from the forecast\n", 232 | " forecast = pd.Series(forecast)\n", 233 | " forecast.name = \"sales\"\n", 234 | " forecast.index = pd.DatetimeIndex(start='2017-01-01', \n", 235 | " freq=\"D\",\n", 236 | " periods = forecast.size)\n", 237 | " \n", 238 | " if deseasonize:\n", 239 | " forecast = add_seasonal_component(forecast, categorization, seas)\n", 240 | " \n", 241 | " final_forecast = pd.DataFrame()\n", 242 | " final_forecast['real_values'] = df_validation[\"sales\"][:steps_to_forecast]\n", 243 | " final_forecast['forecast'] = forecast\n", 244 | " \n", 245 | " return final_forecast" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 11, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "def 
extract_single_df(data, store, item):\n", 255 | " \"\"\"\n", 256 | " Extract single store/item time series from provided \n", 257 | " dataset\n", 258 | " \n", 259 | " :param data: Pandas dataframe with multiple time series\n", 260 | " :param store: number of the store\n", 261 | " :param item: number of the item\n", 262 | " :returns: Pandas dataframe with single store/item time series\n", 263 | " \"\"\"\n", 264 | " \n", 265 | " df_single = data.loc[(data.store == store) & (data.item == item),[\"date\", \"sales\"]].copy()\n", 266 | " df_single.reset_index(drop=True, inplace=True)\n", 267 | " df_single.date = pd.to_datetime(df_single.date)\n", 268 | " return df_single" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 12, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "def compute_all_models(data, ids, categorization, trainer, forecaster, deseasonize, steps_to_forecast=90):\n", 278 | " \"\"\"\n", 279 | " Train the models and use them to make forecasts for all of the individual\n", 280 | " time series separately\n", 281 | " \n", 282 | " :param data: dataframe with multiple time series\n", 283 | " :param ids: list of tuples with stores and items\n", 284 | " :param categorization: Function used to split values into various periods of the season.\n", 285 | " :type categorization: pd.DataFrame -> some categorical type, eg. 
int\n", 286 | " :param trainer: Function used to train the model\n", 287 | " :type trainer: pd.DataFrame -> model\n", 288 | " :param forecaster: (model, steps) -> prediction\n", 289 | " :param steps_to_forecast: number of steps to forecast\n", 290 | " \"\"\"\n", 291 | " \n", 292 | " all_models_forecast = {}\n", 293 | " all_models_smape = np.array([])\n", 294 | " number_of_time_series = 0\n", 295 | " for store, item in ids:\n", 296 | " single_time_series = extract_single_df(data, store, item)\n", 297 | " predictions = train_and_forecast(single_time_series, categorization,\n", 298 | " trainer, forecaster, deseasonize, steps_to_forecast)\n", 299 | " score = smape(predictions['real_values'], predictions['forecast'])\n", 300 | " results = {\n", 301 | " \"item\": item,\n", 302 | " \"store\": store,\n", 303 | " \"predictions\": predictions,\n", 304 | " \"smape\": score\n", 305 | " }\n", 306 | " print(score)\n", 307 | " all_models_smape = np.append(all_models_smape, score)\n", 308 | " all_models_forecast[str(store) + str(item)] = results\n", 309 | " number_of_time_series += 1\n", 310 | " forecast_smape = np.sum(all_models_smape) / number_of_time_series\n", 311 | " \n", 312 | " return results, forecast_smape" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 13, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "def smape(y, y_pred):\n", 322 | " \"\"\"\n", 323 | " Compute the SMAPE metric\n", 324 | " \n", 325 | " :param y: array with true values\n", 326 | " :param y_pred: array with forecasted values\n", 327 | " :returns: average SMAPE metric for the given period\n", 328 | " \"\"\"\n", 329 | " \n", 330 | " div = (abs(y_pred) + abs(y)) / 2\n", 331 | " errors = abs(y_pred - y) / div\n", 332 | " \n", 333 | " smape = np.sum(errors) / len(y)\n", 334 | " return smape" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 14, 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "def compute_avg_smape 
(df_y, df_y_pred):\n", 344 | " \"\"\"\n", 345 | " Compute average SMAPE of multiple forecasts\n", 346 | " \n", 347 | " :param df_y: dataframe with real values (one column per series)\n", 348 | " :param df_y_pred: dataframe with multiple forecasts\n", 349 | " :returns: average SMAPE of all forecasts\n", 350 | " \"\"\"\n", 351 | " \n", 352 | " avg_smape = 0\n", 353 | " for i in range(df_y_pred.shape[1]):\n", 354 | " err = smape(y=df_y.iloc[:,i],\n", 355 | " y_pred=df_y_pred.iloc[:,i])\n", 356 | " avg_smape += err\n", 357 | "\n", 358 | " avg_smape /= df_y_pred.shape[1]\n", 359 | " return avg_smape\n" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": {}, 365 | "source": [ 366 | "### Data import & preparation \n", 367 | "In order for this to work, download all datasets from the Kaggle competition: \n", 368 | "https://www.kaggle.com/c/demand-forecasting-kernels-only \n", 369 | "and place them in the `../data/` folder. \n", 370 | "We could not add the datasets to our repo for copyright reasons." 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": 15, 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "#Import data\n", 380 | "dir_name = \"../data/\" #insert your own path here\n", 381 | "file_name = \"train.csv\"\n", 382 | "\n", 383 | "filepath_train = os.path.abspath(os.path.join(os.getcwd(), dir_name, file_name ))\n", 384 | "\n", 385 | "store_item_data = pd.read_csv(filepath_train)\n", 386 | "store_item_data.index = pd.DatetimeIndex(store_item_data['date'])" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 16, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "number_of_stores = 10\n", 396 | "number_of_items = 50" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 17, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "ids = list(itertools.product(range(1,number_of_stores+1), range(1,number_of_items+1)))" 406 | ] 407 | }, 408 | { 409 | "cell_type": 
"markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "## Holt-Winters Method" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [ 421 | "# create a trainer which can be used to train all of the models\n", 422 | "hw_trainer = lambda df: smt.ExponentialSmoothing(endog=df.loc[-365:,'sales'], damped=False,\n", 423 | " trend='add',seasonal='add', seasonal_periods=7).fit()" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "# create a forecaster function to get the forecasts from all of the models\n", 433 | "hw_forecaster = lambda model, steps: model.predict(steps)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "metadata": { 440 | "scrolled": true 441 | }, 442 | "outputs": [], 443 | "source": [ 444 | "hw_results, hw_smape = compute_all_models(store_item_data, ids, categorize_by_week_of_year,\n", 445 | " hw_trainer, hw_forecaster, False)\n", 446 | "print(\"Holt Winters method SMAPE on the validation set: \", hw_smape)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "markdown", 451 | "metadata": {}, 452 | "source": [ 453 | "## SARIMA" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "# create a trainer which can be used to grid search and train all of the models\n", 463 | "arima_trainer = lambda df: auto_arima(df.loc[-365:,'sales'], m=7, n_jobs=1, max_p=7, max_q=7, max_P=7,\n", 464 | " max_Q=7, max_order=12, trend='c', max_iter=100,\n", 465 | " trace=False, error_action='ignore', suppress_warnings=True,\n", 466 | " stepwise=True, random=False, n_fits=40, random_state=44443)\n" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "# create a 
forecaster function to get the forecasts from all of the models\n", 476 | "arima_forecaster = lambda model, steps: model.predict(steps)" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "metadata": { 483 | "scrolled": true 484 | }, 485 | "outputs": [], 486 | "source": [ 487 | "sarima_results, sarima_smape = compute_all_models(store_item_data, ids, categorize_by_week_of_year,\n", 488 | " arima_trainer, arima_forecaster, True)\n", 489 | "print(\"SARIMA model SMAPE on the validation set: \", sarima_smape)" 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": {}, 495 | "source": [ 496 | "## VAR" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": 18, 502 | "metadata": {}, 503 | "outputs": [], 504 | "source": [ 505 | "def var_train_and_forecast(data_train, data_validation, categorization, max_lags=16,ic='aic',\n", 506 | " plots=False):\n", 507 | " \"\"\"\n", 508 | " Apply deseasonalization and differencing to the time series and\n", 509 | " create VAR model with optimal number of lags, make forecast and compute SMAPE\n", 510 | " \n", 511 | " :param data_train: training multivariate time series\n", 512 | " :param data_validation: validation multivariate time series\n", 513 | " :param max_lags: maximum number of lags in VAR model\n", 514 | " :type max_lags: int\n", 515 | " :param ic: information criterion used to choose best VAR model\n", 516 | " :param plots: if True forecast vs real_values plot will be shown\n", 517 | " \n", 518 | " :returns: SMAPE metric\n", 519 | " \"\"\"\n", 520 | " \n", 521 | " #DESEASONALIZATION\n", 522 | " data_train_deseasonalized = pd.DataFrame()\n", 523 | " seasonalized_all = pd.DataFrame()\n", 524 | "\n", 525 | " for col in data_train.columns.values:\n", 526 | " seasonalized_all[col] = compute_seasonality(data_train[col], categorization)\n", 527 | " series = remove_seasonal_component(data_train[col], categorization, seasonalized_all[col])\n", 528 | " 
data_train_deseasonalized[col] = series\n", 529 | " data_train_deseasonalized.index = pd.DatetimeIndex(start=\"01-01-2013\",\n", 530 | " periods=1461, freq='D')\n", 531 | " \n", 532 | " #TIME SERIES DIFFERENCING\n", 533 | " data_train_deseasonalized_differenced = data_train_deseasonalized.diff().iloc[1:]\n", 534 | " data_train_deseasonalized_differenced = data_train_deseasonalized_differenced.asfreq('D')\n", 535 | " \n", 536 | " #MODELING\n", 537 | " var_model = smt.VAR(data_train_deseasonalized_differenced,\n", 538 | " dates=data_train_deseasonalized_differenced.index)\n", 539 | " \n", 540 | " var_model_results = var_model.fit(maxlags=max_lags, ic=ic, verbose=True)\n", 541 | " lags = var_model_results.k_ar\n", 542 | " \n", 543 | " #FORECASTING\n", 544 | " forecast = pd.DataFrame(var_model_results.forecast(\n", 545 | " data_train_deseasonalized_differenced.values[-lags:], 90))\n", 546 | " forecast.columns +=1\n", 547 | " \n", 548 | " #FORECAST REVERSE DIFFERENCING\n", 549 | " var_forecast_diff = data_train_deseasonalized.tail(1)\n", 550 | "\n", 551 | " var_forecast_diff = var_forecast_diff.append(forecast)\n", 552 | " var_forecast = var_forecast_diff.cumsum().iloc[1:]\n", 553 | " var_forecast.index = pd.DatetimeIndex(start='01-01-2017', periods=90, freq='D')\n", 554 | " \n", 555 | "\n", 556 | " \n", 557 | " # FORECAST - APPLY SEASONALITY\n", 558 | " corrected_forecast = pd.DataFrame()\n", 559 | " for col in var_forecast.columns.values:\n", 560 | " tmp_forecast = pd.DataFrame()\n", 561 | " tmp_forecast = add_seasonal_component(var_forecast[col].copy(),\n", 562 | " categorize_by_week_of_year,\n", 563 | " seasonalized_all[col])\n", 564 | " corrected_forecast[col] = tmp_forecast\n", 565 | " corrected_forecast.index = pd.DatetimeIndex(start='01-01-2017', periods=90,freq='D')\n", 566 | " \n", 567 | " # forecast vs real_values plot\n", 568 | " if plots:\n", 569 | " plt.figure()\n", 570 | " corrected_forecast.iloc[:,0].plot(figsize=(12,16), color='b')\n", 571 | " 
data_validation.iloc[:90,0].plot(color='r')\n", 572 | " plt.show()\n", 573 | " smape = compute_avg_smape(df_y=data_validation.iloc[:90], df_y_pred=corrected_forecast)\n", 574 | " print(smape)\n", 575 | " return smape\n" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 19, 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [ 584 | "def evaluate_multiple_var (data, num_time_series):\n", 585 | " \"\"\"\n", 586 | " Compute and evaluate VAR model for every batch of multivariate time series\n", 587 | " :param data: multiindexed columns dataframe with multivariate time series\n", 588 | " :param num_time_series: number of multivariate time series\n", 589 | " :returns: list with the average SMAPE metric of every batch\n", 590 | " \"\"\"\n", 591 | " smape_all = []\n", 592 | " for i in range(1,num_time_series+1):\n", 593 | " single_time_series = data.loc[:,i]\n", 594 | " \n", 595 | " single_time_series_train = single_time_series.iloc[:-365]\n", 596 | " single_time_series_validation = single_time_series.iloc[-365:]\n", 597 | " \n", 598 | " smape = var_train_and_forecast(single_time_series_train, single_time_series_validation, categorize_by_week_of_year,\n", 599 | " plots=False)\n", 600 | " smape_all.append(smape)\n", 601 | " \n", 602 | " return smape_all" 603 | ] 604 | }, 605 | { 606 | "cell_type": "markdown", 607 | "metadata": {}, 608 | "source": [ 609 | "#### Create a model for each store" 610 | ] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "execution_count": 20, 615 | "metadata": {}, 616 | "outputs": [ 617 | { 618 | "data": { 619 | "text/html": [ 620 | "
\n", 621 | "\n", 638 | "\n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | "
store1...10
item123456...454647484950
date
2013-01-01133315101131...453711251733
2013-01-0211433011636...453313241337
2013-01-031423148818...562816291946
2013-01-0413181019919...504411392351
2013-01-0510342312831...624516342241
\n", 762 | "

5 rows × 500 columns

\n", 763 | "
" 764 | ], 765 | "text/plain": [ 766 | "store 1 ... 10 \n", 767 | "item 1 2 3 4 5 6 ... 45 46 47 48 49 50\n", 768 | "date ... \n", 769 | "2013-01-01 13 33 15 10 11 31 ... 45 37 11 25 17 33\n", 770 | "2013-01-02 11 43 30 11 6 36 ... 45 33 13 24 13 37\n", 771 | "2013-01-03 14 23 14 8 8 18 ... 56 28 16 29 19 46\n", 772 | "2013-01-04 13 18 10 19 9 19 ... 50 44 11 39 23 51\n", 773 | "2013-01-05 10 34 23 12 8 31 ... 62 45 16 34 22 41\n", 774 | "\n", 775 | "[5 rows x 500 columns]" 776 | ] 777 | }, 778 | "execution_count": 20, 779 | "metadata": {}, 780 | "output_type": "execute_result" 781 | } 782 | ], 783 | "source": [ 784 | "stores_time_series = pd.pivot_table(\n", 785 | " columns=['store','item'], values='sales',\n", 786 | " index=store_item_data.index, data=store_item_data).asfreq('D')\n", 787 | "stores_time_series.head()" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": 21, 793 | "metadata": { 794 | "scrolled": true 795 | }, 796 | "outputs": [ 797 | { 798 | "name": "stdout", 799 | "output_type": "stream", 800 | "text": [ 801 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 802 | "Using 16 based on aic criterion\n", 803 | "0.1614194423273425\n", 804 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 805 | "Using 16 based on aic criterion\n", 806 | "0.13758328818926077\n", 807 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 808 | "Using 16 based on aic criterion\n", 809 | "0.1526445758664326\n", 810 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 811 | "Using 16 based on aic criterion\n", 812 | "0.15339676809920877\n", 813 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 814 | "Using 16 based on aic criterion\n", 815 | "0.17406155815250476\n", 816 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 817 | "Using 16 based on aic criterion\n", 818 | "0.17554093888197206\n", 819 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 820 | "Using 16 based on aic criterion\n", 821 | "0.18497158371207476\n", 822 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 823 | "Using 16 based on aic criterion\n", 
824 | "0.1408921398129204\n", 825 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 826 | "Using 16 based on aic criterion\n", 827 | "0.15685467489594204\n", 828 | " 16, BIC -> 1, FPE -> 6, HQIC -> 2>\n", 829 | "Using 16 based on aic criterion\n", 830 | "0.15031806865439024\n" 831 | ] 832 | } 833 | ], 834 | "source": [ 835 | "all_stores_smape = evaluate_multiple_var(stores_time_series, number_of_stores) " 836 | ] 837 | }, 838 | { 839 | "cell_type": "code", 840 | "execution_count": 22, 841 | "metadata": { 842 | "scrolled": true 843 | }, 844 | "outputs": [ 845 | { 846 | "data": { 847 | "image/png": "\n", 848 | "text/plain": [ 849 | "
" 850 | ] 851 | }, 852 | "metadata": {}, 853 | "output_type": "display_data" 854 | } 855 | ], 856 | "source": [ 857 | "sns.distplot(all_stores_smape)\n", 858 | "plt.show()" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": 23, 864 | "metadata": {}, 865 | "outputs": [ 866 | { 867 | "data": { 868 | "text/plain": [ 869 | "DescribeResult(nobs=10, minmax=(0.13758328818926077, 0.18497158371207476), mean=0.15876830385920487, variance=0.0002354032604925673, skewness=0.3271973487275648, kurtosis=-0.9587751430852118)" 870 | ] 871 | }, 872 | "execution_count": 23, 873 | "metadata": {}, 874 | "output_type": "execute_result" 875 | } 876 | ], 877 | "source": [ 878 | "describe(all_stores_smape)" 879 | ] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "metadata": {}, 884 | "source": [ 885 | "#### Create a model for each item" 886 | ] 887 | }, 888 | { 889 | "cell_type": "code", 890 | "execution_count": 24, 891 | "metadata": {}, 892 | "outputs": [ 893 | { 894 | "data": { 895 | "text/html": [ 896 | "
\n", 897 | "\n", 914 | "\n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | "
item1...50
store123456...5678910
date
2013-01-01131219101120...192021453633
2013-01-02111681296...252330544437
2013-01-0314161081211...283820542946
2013-01-041320151587...273327524351
2013-01-05101622191312...313318485341
\n", 1038 | "

5 rows × 500 columns

\n", 1039 | "
" 1040 | ], 1041 | "text/plain": [ 1042 | "item 1 ... 50 \n", 1043 | "store 1 2 3 4 5 6 ... 5 6 7 8 9 10\n", 1044 | "date ... \n", 1045 | "2013-01-01 13 12 19 10 11 20 ... 19 20 21 45 36 33\n", 1046 | "2013-01-02 11 16 8 12 9 6 ... 25 23 30 54 44 37\n", 1047 | "2013-01-03 14 16 10 8 12 11 ... 28 38 20 54 29 46\n", 1048 | "2013-01-04 13 20 15 15 8 7 ... 27 33 27 52 43 51\n", 1049 | "2013-01-05 10 16 22 19 13 12 ... 31 33 18 48 53 41\n", 1050 | "\n", 1051 | "[5 rows x 500 columns]" 1052 | ] 1053 | }, 1054 | "execution_count": 24, 1055 | "metadata": {}, 1056 | "output_type": "execute_result" 1057 | } 1058 | ], 1059 | "source": [ 1060 | "items_time_series = pd.pivot_table(\n", 1061 | " columns=['item','store'], values='sales',\n", 1062 | " index=store_item_data.index, data=store_item_data).asfreq('D')\n", 1063 | "items_time_series.head()" 1064 | ] 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "execution_count": 25, 1069 | "metadata": { 1070 | "scrolled": true 1071 | }, 1072 | "outputs": [ 1073 | { 1074 | "name": "stdout", 1075 | "output_type": "stream", 1076 | "text": [ 1077 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1078 | "Using 13 based on aic criterion\n", 1079 | "0.2043320424047193\n", 1080 | " 15, BIC -> 6, FPE -> 15, HQIC -> 6>\n", 1081 | "Using 15 based on aic criterion\n", 1082 | "0.13445202689996147\n", 1083 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1084 | "Using 13 based on aic criterion\n", 1085 | "0.15741923536156543\n", 1086 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1087 | "Using 13 based on aic criterion\n", 1088 | "0.20120284034167052\n", 1089 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1090 | "Using 14 based on aic criterion\n", 1091 | "0.24970152598833356\n", 1092 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1093 | "Using 13 based on aic criterion\n", 1094 | "0.12942933421786357\n", 1095 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1096 | "Using 13 based on aic criterion\n", 1097 | "0.12863479551418702\n", 1098 | " 13, BIC -> 6, FPE -> 13, 
HQIC -> 6>\n", 1099 | "Using 13 based on aic criterion\n", 1100 | "0.11866561463768463\n", 1101 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1102 | "Using 13 based on aic criterion\n", 1103 | "0.1376883394260311\n", 1104 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1105 | "Using 14 based on aic criterion\n", 1106 | "0.11284699217169543\n", 1107 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1108 | "Using 14 based on aic criterion\n", 1109 | "0.12012753168700632\n", 1110 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1111 | "Using 13 based on aic criterion\n", 1112 | "0.11787332996262459\n", 1113 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1114 | "Using 13 based on aic criterion\n", 1115 | "0.11491647120370825\n", 1116 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1117 | "Using 13 based on aic criterion\n", 1118 | "0.13187597914707128\n", 1119 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1120 | "Using 13 based on aic criterion\n", 1121 | "0.1079554425439\n", 1122 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1123 | "Using 13 based on aic criterion\n", 1124 | "0.19855118250488368\n", 1125 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1126 | "Using 13 based on aic criterion\n", 1127 | "0.1723317493868881\n", 1128 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1129 | "Using 13 based on aic criterion\n", 1130 | "0.11344391017057162\n", 1131 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1132 | "Using 14 based on aic criterion\n", 1133 | "0.15193098277727066\n", 1134 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1135 | "Using 13 based on aic criterion\n", 1136 | "0.14166204101891786\n", 1137 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1138 | "Using 13 based on aic criterion\n", 1139 | "0.16095980004545482\n", 1140 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1141 | "Using 13 based on aic criterion\n", 1142 | "0.11093603642335355\n", 1143 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1144 | "Using 13 based on aic criterion\n", 1145 | "0.1808236193490169\n", 1146 | " 16, BIC -> 6, FPE -> 16, HQIC 
-> 6>\n", 1147 | "Using 16 based on aic criterion\n", 1148 | "0.12227766031454101\n", 1149 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1150 | "Using 13 based on aic criterion\n", 1151 | "0.10351964176964461\n", 1152 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1153 | "Using 13 based on aic criterion\n", 1154 | "0.14792959942507267\n", 1155 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1156 | "Using 14 based on aic criterion\n", 1157 | "0.20894285395589313\n", 1158 | " 13, BIC -> 6, FPE -> 13, HQIC -> 7>\n", 1159 | "Using 13 based on aic criterion\n", 1160 | "0.1029443135882081\n", 1161 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1162 | "Using 14 based on aic criterion\n", 1163 | "0.12582718288542172\n", 1164 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1165 | "Using 13 based on aic criterion\n", 1166 | "0.15963237418646978\n", 1167 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1168 | "Using 14 based on aic criterion\n", 1169 | "0.13219985642766646\n", 1170 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1171 | "Using 13 based on aic criterion\n", 1172 | "0.15052725137817607\n", 1173 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1174 | "Using 13 based on aic criterion\n", 1175 | "0.12255504894529416\n", 1176 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1177 | "Using 13 based on aic criterion\n", 1178 | "0.19002541903165365\n", 1179 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1180 | "Using 13 based on aic criterion\n", 1181 | "0.12379437224104943\n", 1182 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1183 | "Using 14 based on aic criterion\n", 1184 | "0.11272356435197321\n", 1185 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1186 | "Using 13 based on aic criterion\n", 1187 | "0.1868260323414077\n", 1188 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1189 | "Using 13 based on aic criterion\n", 1190 | "0.11319047108255897\n", 1191 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1192 | "Using 13 based on aic criterion\n", 1193 | "0.15519510165703648\n", 1194 | " 13, BIC -> 6, FPE -> 13, HQIC 
-> 6>\n", 1195 | "Using 13 based on aic criterion\n", 1196 | "0.1815004771798632\n", 1197 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1198 | "Using 13 based on aic criterion\n", 1199 | "0.2247066627129281\n", 1200 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1201 | "Using 14 based on aic criterion\n", 1202 | "0.16505144674459196\n", 1203 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1204 | "Using 14 based on aic criterion\n", 1205 | "0.1431511638352087\n", 1206 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1207 | "Using 13 based on aic criterion\n", 1208 | "0.17806507145740066\n", 1209 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1210 | "Using 14 based on aic criterion\n", 1211 | "0.11436285813577104\n", 1212 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1213 | "Using 13 based on aic criterion\n", 1214 | "0.12649667664699574\n", 1215 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1216 | "Using 13 based on aic criterion\n", 1217 | "0.21582488166098793\n" 1218 | ] 1219 | }, 1220 | { 1221 | "name": "stdout", 1222 | "output_type": "stream", 1223 | "text": [ 1224 | " 14, BIC -> 6, FPE -> 14, HQIC -> 6>\n", 1225 | "Using 14 based on aic criterion\n", 1226 | "0.13822003310461012\n", 1227 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1228 | "Using 13 based on aic criterion\n", 1229 | "0.18572490790042442\n", 1230 | " 13, BIC -> 6, FPE -> 13, HQIC -> 6>\n", 1231 | "Using 13 based on aic criterion\n", 1232 | "0.12751262851051898\n" 1233 | ] 1234 | } 1235 | ], 1236 | "source": [ 1237 | "all_items_smape = evaluate_multiple_var(items_time_series, number_of_items)" 1238 | ] 1239 | }, 1240 | { 1241 | "cell_type": "code", 1242 | "execution_count": 26, 1243 | "metadata": { 1244 | "scrolled": true 1245 | }, 1246 | "outputs": [ 1247 | { 1248 | "data": { 1249 | "text/plain": [ 1250 | "" 1251 | ] 1252 | }, 1253 | "execution_count": 26, 1254 | "metadata": {}, 1255 | "output_type": "execute_result" 1256 | }, 1257 | { 1258 | "data": { 1259 | "image/png": "\n", 1260 | "text/plain": [ 1261 | "
" 1262 | ] 1263 | }, 1264 | "metadata": {}, 1265 | "output_type": "display_data" 1266 | } 1267 | ], 1268 | "source": [ 1269 | "sns.distplot(all_items_smape)" 1270 | ] 1271 | }, 1272 | { 1273 | "cell_type": "code", 1274 | "execution_count": 27, 1275 | "metadata": {}, 1276 | "outputs": [ 1277 | { 1278 | "data": { 1279 | "text/plain": [ 1280 | "DescribeResult(nobs=50, minmax=(0.1029443135882081, 0.24970152598833356), mean=0.14912976749311493, variance=0.0013015468133833437, skewness=0.8241553410554039, kurtosis=-0.1881847700143413)" 1281 | ] 1282 | }, 1283 | "execution_count": 27, 1284 | "metadata": {}, 1285 | "output_type": "execute_result" 1286 | } 1287 | ], 1288 | "source": [ 1289 | "describe(all_items_smape)" 1290 | ] 1291 | } 1292 | ], 1293 | "metadata": { 1294 | "kernelspec": { 1295 | "display_name": "Python [conda env:tf]", 1296 | "language": "python", 1297 | "name": "conda-env-tf-py" 1298 | }, 1299 | "language_info": { 1300 | "codemirror_mode": { 1301 | "name": "ipython", 1302 | "version": 3 1303 | }, 1304 | "file_extension": ".py", 1305 | "mimetype": "text/x-python", 1306 | "name": "python", 1307 | "nbconvert_exporter": "python", 1308 | "pygments_lexer": "ipython3", 1309 | "version": "3.6.5" 1310 | } 1311 | }, 1312 | "nbformat": 4, 1313 | "nbformat_minor": 2 1314 | } 1315 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kaggle-Demand-Forecasting-Models 2 | This is a collection of models for a [kaggle demand forecasting competition](https://www.kaggle.com/c/demand-forecasting-kernels-only). 3 | 4 | We wanted to test as many models as possible and share the most interesting ones here. 5 | Make sure to check out [a series of blog posts that describe our exploration in detail](https://semantive.com/long-term-demand-forecasting/). 
6 | 7 | -------------------------------------------------------------------------------- /XGB-features.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Preamble" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 2, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "from scipy.stats import describe\n", 19 | "pd.options.display.max_columns = 12\n", 20 | "pd.options.display.max_rows = 24\n", 21 | "\n", 22 | "# disable warnings in Anaconda\n", 23 | "import warnings\n", 24 | "warnings.simplefilter('ignore')\n", 25 | "\n", 26 | "# plots inisde jupyter notebook\n", 27 | "%matplotlib inline\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "import seaborn as sns\n", 31 | "sns.set(style='darkgrid', palette='muted')\n", 32 | "color_scheme = {\n", 33 | " 'red': '#F1637A',\n", 34 | " 'green': '#6ABB3E',\n", 35 | " 'blue': '#3D8DEA',\n", 36 | " 'black': '#000000'\n", 37 | "}\n", 38 | "\n", 39 | "import datetime as dt\n", 40 | "import xgboost as xgb" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def smape(y_pred, y_true):\n", 50 | " # calculate error\n", 51 | " denom = (abs(y_pred) + abs(y_true)) / 2\n", 52 | " errors = abs(y_pred - y_true) / denom\n", 53 | " return 100 * np.sum(errors) / len(y_true)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "def serie_split(s, fcast_len = 90):\n", 63 | " \"\"\"\n", 64 | " We split our datasets: year 2017 is used for validation and the rest is for training.\n", 65 | " As our goal is to predict the first 90 days of 2018, we use only the first 90 days for validation.\n", 66 | " \"\"\"\n", 67 | " train = s.iloc[s.index < 
'2017-01-01']\n", 68 | " test = s.iloc[s.index >= '2017-01-01'].iloc[:fcast_len]\n", 69 | " \n", 70 | " return train, test" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## Load data" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "df = pd.read_csv('../data/train.csv')\n", 87 | "df['date'] = pd.to_datetime(df['date'])\n", 88 | "df.index = pd.DatetimeIndex(df['date'])" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "## Feature extraction" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 6, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "from rolling import Rolling" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 7, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def compute_over_interval(functions, interval):\n", 114 | " return np.array([f(interval) for f in functions])\n", 115 | "\n", 116 | "def features(prev_values, current_date, metadata=None):\n", 117 | " last_2months = prev_values[-60:]\n", 118 | " last_month = prev_values[-30:]\n", 119 | " last_week = prev_values[-7:]\n", 120 | " \n", 121 | " featureset = [\n", 122 | " lambda x: np.median(x),\n", 123 | " #lambda x: np.mean(x),\n", 124 | " #lambda x: np.min(x),\n", 125 | " #lambda x: np.max(x),\n", 126 | " lambda x: np.var(x),\n", 127 | " ]\n", 128 | " \n", 129 | " weekday_v = np.zeros(7)\n", 130 | " weekday_v[current_date.dayofweek] = 1\n", 131 | " \n", 132 | " return np.concatenate([\n", 133 | " #weekday_v,\n", 134 | " last_2months,\n", 135 | " compute_over_interval(featureset, last_month),\n", 136 | " prev_values[:30], # one month almost a year before\n", 137 | " ])\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "## Training" 145 | ] 146 | }, 147 | { 148 | "cell_type": 
"code", 149 | "execution_count": 8, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "def train_and_validate(serie):\n", 154 | " r = Rolling(window=365, extract_features=features)\n", 155 | "\n", 156 | " model = xgb.XGBRegressor(n_jobs=-1)\n", 157 | " \n", 158 | " train, test = serie_split(serie)\n", 159 | " train_X, train_y = r.make_training_data(train)\n", 160 | "\n", 161 | " model.fit(train_X, train_y)\n", 162 | "\n", 163 | " y = r.predict(model.predict, train)\n", 164 | "\n", 165 | " return smape(y, test)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "## Run all models" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 9, 178 | "metadata": { 179 | "scrolled": true 180 | }, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "23.87113765793728\n", 187 | "16.99160090105278\n", 188 | "16.141614314794275\n", 189 | "28.64544802575536\n", 190 | "26.453506047715074\n", 191 | "15.990990711570038\n", 192 | "14.017887875891574\n", 193 | "12.456157106776926\n", 194 | "13.245837265091568\n", 195 | "12.687403477013433\n", 196 | "11.997499667995571\n", 197 | "12.336846490887913\n", 198 | "12.313948453848758\n", 199 | "14.51198997929609\n", 200 | "11.671782363569319\n", 201 | "21.336957734547575\n", 202 | "17.081446784094602\n", 203 | "13.579809834786705\n", 204 | "14.899975363219918\n", 205 | "14.680308257559899\n", 206 | "16.899036250421325\n", 207 | "15.791734558427525\n", 208 | "17.211455848666738\n", 209 | "13.718790631921538\n", 210 | "10.703860694996468\n", 211 | "14.54913591003008\n", 212 | "23.49658293051857\n", 213 | "12.131304941515413\n", 214 | "14.092751893101362\n", 215 | "15.70123948218163\n", 216 | "14.023008009442437\n", 217 | "17.90507033840503\n", 218 | "13.133528274251614\n", 219 | "20.855897924540958\n", 220 | "10.950528703107484\n", 221 | "10.494224777943742\n", 222 | "20.463013334106375\n", 223 | 
"12.476130841230422\n", 224 | "15.923918245805567\n", 225 | "22.576516722950696\n", 226 | "23.69048516137382\n", 227 | "16.344950834713902\n", 228 | "14.451519875800939\n", 229 | "19.25797359346662\n", 230 | "11.45298399368433\n", 231 | "12.369521804596776\n", 232 | "22.321977942648832\n", 233 | "11.557411972591478\n", 234 | "19.476030293772315\n", 235 | "13.083247076272398\n", 236 | "17.13121087936198\n", 237 | "12.13667540317369\n", 238 | "13.368218702790463\n", 239 | "17.047586016195034\n", 240 | "22.517937297218662\n", 241 | "10.435155917708512\n", 242 | "10.776582936759267\n", 243 | "11.855884116557345\n", 244 | "11.114365555668504\n", 245 | "12.199764407963691\n", 246 | "11.440809790523481\n", 247 | "11.100593301668535\n", 248 | "11.198155128271555\n", 249 | "12.109589756082748\n", 250 | "10.356547944904461\n", 251 | "14.80646295819545\n", 252 | "14.257632648772793\n", 253 | "8.6513308444192\n", 254 | "13.17292066855217\n", 255 | "12.552270733576133\n", 256 | "12.115588389348984\n", 257 | "10.842307449344938\n", 258 | "17.46884946627776\n", 259 | "13.092736462106958\n", 260 | "12.290984559541975\n", 261 | "11.250041140169133\n", 262 | "16.203003032787763\n", 263 | "13.180231563942339\n", 264 | "10.912923949065842\n", 265 | "13.501286956875104\n", 266 | "14.388617587968888\n", 267 | "14.00458471518346\n", 268 | "11.38261202347833\n", 269 | "20.40628863834678\n", 270 | "10.007711217842958\n", 271 | "9.152971948645867\n", 272 | "15.452625187815244\n", 273 | "11.19123525259093\n", 274 | "13.004616884169488\n", 275 | "16.595199926340676\n", 276 | "19.479194819049557\n", 277 | "13.913773652013052\n", 278 | "11.995088394428608\n", 279 | "17.628433929888114\n", 280 | "11.012331322639701\n", 281 | "12.71171144993314\n", 282 | "17.525241979829307\n", 283 | "11.254314693193647\n", 284 | "15.793015484923684\n", 285 | "13.488072398623398\n", 286 | "14.992590515424114\n", 287 | "13.588215675219532\n", 288 | "15.847018478060601\n", 289 | "19.96351852563633\n", 290 | 
"20.74517072709113\n", 291 | "13.09241682971775\n", 292 | "11.323054622371313\n", 293 | "11.458371004410045\n", 294 | "13.692410619762967\n", 295 | "10.778263169196977\n", 296 | "11.234377728351145\n", 297 | "10.41214053483852\n", 298 | "11.298310569736854\n", 299 | "14.30802335439698\n", 300 | "10.91522576310903\n", 301 | "16.587409215268337\n", 302 | "15.739105177565886\n", 303 | "11.684859296077265\n", 304 | "13.347847040708231\n", 305 | "12.477952115653407\n", 306 | "13.89082928963118\n", 307 | "10.200756927088035\n", 308 | "15.749579114144916\n", 309 | "12.466784763799106\n", 310 | "12.246724134887321\n", 311 | "20.12769034205891\n", 312 | "21.020605218495213\n", 313 | "10.13356911563521\n", 314 | "11.748002424858685\n", 315 | "15.233176177760136\n", 316 | "11.956233368453347\n", 317 | "12.905940800703256\n", 318 | "11.516637215722433\n", 319 | "18.14092412741941\n", 320 | "12.145031614918093\n", 321 | "8.978314276915642\n", 322 | "16.571169453482625\n", 323 | "10.713295708539642\n", 324 | "17.217896032613\n", 325 | "14.59886777077799\n", 326 | "20.25422760075955\n", 327 | "15.731924226507115\n", 328 | "11.51183058782278\n", 329 | "18.40977262195443\n", 330 | "10.950245508381855\n", 331 | "13.001020637778943\n", 332 | "23.75876152341625\n", 333 | "12.638836421187667\n", 334 | "16.720164126899917\n", 335 | "10.285712018775534\n", 336 | "21.030970289175777\n", 337 | "12.16803192353621\n", 338 | "13.724731398984053\n", 339 | "20.47835077644959\n", 340 | "24.843564212341008\n", 341 | "12.867274293556296\n", 342 | "15.082059963114627\n", 343 | "11.836780957829383\n", 344 | "12.74637051305049\n", 345 | "12.48087577729408\n", 346 | "12.620647507707437\n", 347 | "13.101967657235766\n", 348 | "11.536565734137724\n", 349 | "12.751307361354172\n", 350 | "10.766815482767152\n", 351 | "19.852873895317714\n", 352 | "14.887034698059523\n", 353 | "11.81027886015163\n", 354 | "14.720175282771354\n", 355 | "13.73274021750489\n", 356 | "15.573137471297496\n", 357 | 
"9.856330187495741\n", 358 | "16.070601459816643\n", 359 | "11.59236247160353\n", 360 | "12.771103702472026\n", 361 | "13.232840080698397\n", 362 | "22.482554939186482\n", 363 | "13.634841384244833\n", 364 | "12.74172206536606\n", 365 | "17.81035884952235\n", 366 | "13.005313472201962\n", 367 | "16.80400124327286\n", 368 | "11.523701425037224\n", 369 | "20.5990406107118\n", 370 | "13.435155049508598\n", 371 | "13.456303717480985\n", 372 | "17.46358455792213\n", 373 | "10.619775634348962\n", 374 | "14.96956324346667\n", 375 | "18.948875308893527\n", 376 | "21.25915218434177\n", 377 | "17.05976794683174\n", 378 | "12.74863370414432\n", 379 | "19.63598891546681\n", 380 | "10.176283925487455\n", 381 | "15.024859467779363\n", 382 | "18.453588223176315\n", 383 | "13.54781256918899\n", 384 | "16.613857689689148\n", 385 | "11.04322758603172\n", 386 | "23.554388666033553\n", 387 | "16.586818473342856\n", 388 | "19.42672943198623\n", 389 | "28.404701832366108\n", 390 | "24.790276416398957\n", 391 | "14.848371362461762\n", 392 | "17.719007269134497\n", 393 | "13.73826315900906\n", 394 | "15.722191542337518\n", 395 | "10.936683470492493\n", 396 | "15.361681286147604\n", 397 | "13.686354622456829\n", 398 | "12.995619692720384\n", 399 | "16.29779809799747\n", 400 | "13.73105041924479\n", 401 | "25.12786692870503\n", 402 | "20.180141464648838\n", 403 | "12.45853133301449\n", 404 | "16.99744831619319\n", 405 | "17.444276694625408\n", 406 | "17.77908360322428\n", 407 | "12.747352034712337\n", 408 | "19.760308564049932\n", 409 | "14.766251459859783\n", 410 | "11.17437589910694\n", 411 | "17.397304000862555\n", 412 | "25.582265175927457\n", 413 | "9.540540963948484\n", 414 | "16.37206106179275\n", 415 | "16.914826564842215\n", 416 | "18.439648557909678\n", 417 | "14.238792779557096\n", 418 | "13.96367360133653\n", 419 | "20.825691207903898\n", 420 | "11.289598326529765\n", 421 | "13.87248964530023\n", 422 | "17.125098098392932\n", 423 | "11.749394893222911\n", 424 | 
"17.72055519966441\n", 425 | "24.58680481930025\n", 426 | "27.271519279827817\n", 427 | "21.46091009819667\n", 428 | "16.412933894435557\n", 429 | "18.82775894110208\n", 430 | "12.777844079613033\n", 431 | "15.657519783526297\n", 432 | "28.081855706229817\n", 433 | "15.72274318952631\n", 434 | "18.95181095176833\n", 435 | "15.315657323355728\n", 436 | "22.05758975049735\n", 437 | "15.356703337482308\n", 438 | "19.670076476252365\n", 439 | "20.990333454479188\n", 440 | "24.194166575197304\n", 441 | "13.378168519898885\n", 442 | "15.714416061079838\n", 443 | "12.878215426345502\n", 444 | "14.44693662917413\n", 445 | "14.353770620300518\n", 446 | "15.428740824232776\n", 447 | "13.868949197933473\n", 448 | "10.785831605502361\n", 449 | "15.179868433137566\n", 450 | "12.246021479870024\n", 451 | "20.413883658927443\n", 452 | "20.955006446244028\n", 453 | "12.95687969203753\n", 454 | "18.477550822052923\n", 455 | "16.15106256569919\n", 456 | "17.055240767716487\n", 457 | "13.539764698901942\n", 458 | "22.35648046977383\n", 459 | "15.171945996342233\n", 460 | "13.560789943140739\n", 461 | "16.88959549167191\n", 462 | "20.907924924241975\n", 463 | "9.478902812897086\n", 464 | "15.155591290450143\n", 465 | "21.538816431237283\n", 466 | "12.665825328702114\n", 467 | "20.625474545980637\n", 468 | "13.31659990126007\n", 469 | "22.189121399007547\n", 470 | "14.696859465658198\n", 471 | "14.089862398868933\n", 472 | "20.32955574008201\n", 473 | "12.274346682274883\n", 474 | "15.266061925222177\n", 475 | "26.405088729873412\n", 476 | "24.80125536847212\n", 477 | "19.963538756538853\n", 478 | "17.170617715949426\n", 479 | "20.625275061417447\n", 480 | "13.989870212374177\n", 481 | "15.024333946699098\n", 482 | "23.12561409464823\n", 483 | "16.335969363998196\n", 484 | "20.07283742311259\n", 485 | "13.19575675863355\n", 486 | "24.087665265060664\n", 487 | "17.645842383844002\n", 488 | "16.304972375463652\n", 489 | "20.885761555127097\n", 490 | "31.58593318447265\n", 491 | 
"15.863329833511742\n", 492 | "17.295065894241404\n", 493 | "11.715222604896432\n", 494 | "14.734274631775442\n", 495 | "16.406754922637408\n", 496 | "16.89821806841404\n", 497 | "12.575677551870148\n", 498 | "14.22279477740662\n", 499 | "16.09941631391621\n", 500 | "12.17685198225884\n", 501 | "25.07354328378021\n", 502 | "21.46769426512005\n", 503 | "13.020715402026733\n", 504 | "19.1686190993485\n", 505 | "16.721810185269828\n", 506 | "18.959276584283188\n", 507 | "12.866446872664511\n", 508 | "21.647153642460626\n", 509 | "13.708761485748472\n", 510 | "13.377224026458292\n", 511 | "17.270328255474706\n", 512 | "25.837692239062683\n", 513 | "14.82653219212987\n", 514 | "14.519964376913258\n", 515 | "16.975486588095738\n", 516 | "16.294555179111367\n", 517 | "16.06872129709943\n", 518 | "12.902436599168459\n", 519 | "22.542904263765774\n", 520 | "14.139343767424053\n", 521 | "14.723766726131135\n", 522 | "17.998466213452193\n", 523 | "14.144058199726503\n", 524 | "20.246044034304003\n", 525 | "20.56777630809112\n", 526 | "24.240864615123247\n", 527 | "21.482230180703226\n", 528 | "16.905253876943945\n", 529 | "20.639411374742245\n", 530 | "15.46281836457442\n", 531 | "19.367381503611185\n", 532 | "28.268661773431152\n", 533 | "17.775207857380256\n", 534 | "22.54420450347106\n", 535 | "14.352218477625003\n", 536 | "17.680939148516366\n", 537 | "13.786499131443957\n", 538 | "16.579607654931298\n", 539 | "21.146746859415607\n", 540 | "19.819786593261103\n", 541 | "10.690048712303192\n", 542 | "11.686828057151988\n", 543 | "10.508672982282015\n", 544 | "13.07284720819801\n", 545 | "11.495923412170871\n", 546 | "11.101030881000481\n", 547 | "13.103455685461446\n", 548 | "9.285744636011113\n", 549 | "12.247577645877932\n", 550 | "12.327731968681903\n", 551 | "17.497588307790146\n", 552 | "16.43173589547952\n", 553 | "10.822581366989363\n", 554 | "15.306727111562592\n", 555 | "13.657933390937291\n", 556 | "14.070070841386453\n", 557 | "12.190052593716945\n", 558 | 
"20.66046875236637\n", 559 | "11.400653787682225\n", 560 | "12.964833958934186\n", 561 | "14.071418252504115\n", 562 | "17.87423281143325\n", 563 | "9.120533404092718\n", 564 | "12.479936378867976\n", 565 | "15.807171545324529\n", 566 | "11.384942795282363\n", 567 | "13.406454346699288\n", 568 | "12.018460921232508\n", 569 | "16.06756746700191\n", 570 | "10.428844525227486\n", 571 | "12.532323272998065\n", 572 | "16.193290370761268\n", 573 | "10.266061276622642\n", 574 | "14.319386958053203\n", 575 | "21.10764780824413\n", 576 | "18.41641731171029\n", 577 | "15.575998246184481\n", 578 | "13.861322104630888\n", 579 | "17.45298578708685\n", 580 | "8.537481686884203\n", 581 | "11.298221124967542\n", 582 | "19.321655808416736\n", 583 | "13.170198294604239\n", 584 | "16.529855244269168\n", 585 | "10.94032727420105\n", 586 | "19.48154457562275\n", 587 | "14.045420183041744\n", 588 | "16.20197764143328\n", 589 | "23.05207604875756\n", 590 | "22.62245630823433\n", 591 | "12.590115510502653\n", 592 | "12.033681449175834\n", 593 | "11.01620456829503\n", 594 | "13.209038966949246\n", 595 | "10.755645982093096\n", 596 | "12.778838732933583\n", 597 | "10.796854331186918\n", 598 | "10.655524890468417\n", 599 | "11.89030576733016\n", 600 | "9.745006378652677\n", 601 | "18.428707584980707\n", 602 | "18.110043078419963\n", 603 | "13.708717445581389\n", 604 | "15.865918791766621\n", 605 | "14.626778526034878\n", 606 | "14.482666293656509\n", 607 | "11.93536044571356\n", 608 | "16.313409383514095\n", 609 | "11.865647107110915\n", 610 | "13.40698322852571\n", 611 | "13.86687298776843\n", 612 | "22.632717860136264\n", 613 | "10.808146793612627\n", 614 | "12.207685234131011\n", 615 | "15.079863075168468\n", 616 | "13.240540905746009\n", 617 | "13.522876204524321\n", 618 | "11.570821066306571\n", 619 | "17.980588834621308\n", 620 | "13.082322008270676\n", 621 | "10.822184887615267\n", 622 | "16.957564357751536\n", 623 | "14.012660480171403\n", 624 | "15.551562035333783\n" 625 | ] 626 | 
}, 627 | { 628 | "name": "stdout", 629 | "output_type": "stream", 630 | "text": [ 631 | "16.60247320411955\n", 632 | "20.35714418948094\n", 633 | "16.196684747066165\n", 634 | "13.729617039841477\n", 635 | "20.31487748441642\n", 636 | "10.302271527138254\n", 637 | "13.439512183322044\n", 638 | "21.005290011416697\n", 639 | "13.450612797022417\n", 640 | "16.501114507759482\n", 641 | "12.179044066284076\n", 642 | "21.556963139298627\n", 643 | "12.33327321078022\n", 644 | "17.197581458803608\n", 645 | "18.979082334131174\n", 646 | "20.44431146256982\n", 647 | "12.537804134490816\n", 648 | "12.922281573790592\n", 649 | "12.268027217773907\n", 650 | "10.330956229030264\n", 651 | "11.425936846481155\n", 652 | "10.18338957919352\n", 653 | "12.308412556396458\n", 654 | "10.56482732403133\n", 655 | "13.953919730602312\n", 656 | "10.404180612311524\n", 657 | "16.883577362962534\n", 658 | "19.141086411749345\n", 659 | "11.029219866619513\n", 660 | "16.45456449090657\n", 661 | "11.452054181308702\n", 662 | "16.337782557352266\n", 663 | "12.429624179847984\n", 664 | "15.463232503803608\n", 665 | "12.618158391445041\n", 666 | "11.520684638919885\n", 667 | "14.571931802399169\n", 668 | "17.349889283112354\n", 669 | "8.456667533909537\n", 670 | "10.375180853237094\n", 671 | "15.613341780585433\n", 672 | "11.83141098177827\n", 673 | "14.578257162871653\n", 674 | "11.630192467104441\n", 675 | "20.65821984344638\n", 676 | "12.618388826047353\n", 677 | "11.030055298113528\n", 678 | "19.357369880038195\n", 679 | "11.331278223690086\n", 680 | "14.214754751934601\n", 681 | "16.09430299783017\n", 682 | "23.641589899706215\n", 683 | "15.392940603911075\n", 684 | "11.352340188028073\n", 685 | "15.61531035448032\n", 686 | "12.66681205225313\n", 687 | "12.924287724785202\n", 688 | "20.84914822096515\n", 689 | "13.323548408932878\n", 690 | "19.07914665710944\n", 691 | "12.351793563402447\n" 692 | ] 693 | } 694 | ], 695 | "source": [ 696 | "errors = []\n", 697 | "for store in range(1,11):\n", 
698 | " for item in range(1,51):\n", 699 | " sales = df[(df[\"store\"] == store) & (df[\"item\"] == item)][\"sales\"]\n", 700 | " error = train_and_validate(sales)\n", 701 | " print(error)\n", 702 | " errors.append(error)" 703 | ] 704 | }, 705 | { 706 | "cell_type": "code", 707 | "execution_count": 10, 708 | "metadata": {}, 709 | "outputs": [ 710 | { 711 | "name": "stdout", 712 | "output_type": "stream", 713 | "text": [ 714 | "SMAPE = 15.37532205938642\n" 715 | ] 716 | } 717 | ], 718 | "source": [ 719 | "print(f\"SMAPE = {np.mean(errors)}\")" 720 | ] 721 | }, 722 | { 723 | "cell_type": "code", 724 | "execution_count": 11, 725 | "metadata": { 726 | "scrolled": false 727 | }, 728 | "outputs": [ 729 | { 730 | "data": { 731 | "image/png": "\n", 732 | "text/plain": [ 733 | "
" 734 | ] 735 | }, 736 | "metadata": {}, 737 | "output_type": "display_data" 738 | } 739 | ], 740 | "source": [ 741 | "sns.distplot(errors)\n", 742 | "plt.savefig('xgboost_features.svg')\n", 743 | "plt.show()" 744 | ] 745 | } 746 | ], 747 | "metadata": { 748 | "kernelspec": { 749 | "display_name": "Python [default]", 750 | "language": "python", 751 | "name": "python3" 752 | }, 753 | "language_info": { 754 | "codemirror_mode": { 755 | "name": "ipython", 756 | "version": 3 757 | }, 758 | "file_extension": ".py", 759 | "mimetype": "text/x-python", 760 | "name": "python", 761 | "nbconvert_exporter": "python", 762 | "pygments_lexer": "ipython3", 763 | "version": "3.6.5" 764 | } 765 | }, 766 | "nbformat": 4, 767 | "nbformat_minor": 2 768 | } 769 | -------------------------------------------------------------------------------- /XGB-simple.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Preamble" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 6, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "from scipy.stats import describe\n", 19 | "pd.options.display.max_columns = 12\n", 20 | "pd.options.display.max_rows = 24\n", 21 | "\n", 22 | "# disable warnings in Anaconda\n", 23 | "import warnings\n", 24 | "warnings.simplefilter('ignore')\n", 25 | "\n", 26 | "# plots inisde jupyter notebook\n", 27 | "%matplotlib inline\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n", 30 | "import seaborn as sns\n", 31 | "sns.set(style='darkgrid', palette='muted')\n", 32 | "color_scheme = {\n", 33 | " 'red': '#F1637A',\n", 34 | " 'green': '#6ABB3E',\n", 35 | " 'blue': '#3D8DEA',\n", 36 | " 'black': '#000000'\n", 37 | "}\n", 38 | "\n", 39 | "import datetime as dt\n", 40 | "import xgboost as xgb" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | 
"execution_count": 7, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "def smape(y_pred, y_true):\n", 50 | " # calculate error\n", 51 | " denom = (abs(y_pred) + abs(y_true)) / 2\n", 52 | " errors = abs(y_pred - y_true) / denom\n", 53 | " return 100 * np.sum(errors) / len(y_true)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 8, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "def serie_split(s, fcast_len = 90):\n", 63 | " \"\"\"\n", 64 | " We split our datasets: year 2017 is used for validation and the rest is for training.\n", 65 | " As our goal is to predict the first 90 days of 2018, we use only the first 90 days for validation.\n", 66 | " \"\"\"\n", 67 | " train = s.iloc[s.index < '2017-01-01']\n", 68 | " test = s.iloc[s.index >= '2017-01-01'].iloc[:fcast_len]\n", 69 | " \n", 70 | " return train, test" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## Load data" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 9, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "df = pd.read_csv('../data/train.csv')\n", 87 | "df['date'] = pd.to_datetime(df['date'])\n", 88 | "df.index = pd.DatetimeIndex(df['date'])" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "## Training" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 10, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "from rolling import Rolling" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 11, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "def train_and_validate(serie):\n", 114 | " r = Rolling(window=365)\n", 115 | "\n", 116 | " model = xgb.XGBRegressor(n_jobs=-1)\n", 117 | " \n", 118 | " train, test = serie_split(serie)\n", 119 | " train_X, train_y = r.make_training_data(train)\n", 120 | "\n", 121 | " model.fit(train_X, train_y)\n", 122 | 
"\n", 123 | " y = r.predict(model.predict, train)\n", 124 | "\n", 125 | " return smape(y, test)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## Run all models" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 12, 138 | "metadata": { 139 | "scrolled": true 140 | }, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "22.689162357309584\n", 147 | "15.933136269062517\n", 148 | "16.026707961098246\n", 149 | "20.225466231291445\n", 150 | "24.778601788395562\n", 151 | "14.897109251553532\n", 152 | "14.820856520478545\n", 153 | "12.832542859409747\n", 154 | "14.009638200506965\n", 155 | "12.951276558392589\n", 156 | "11.72057508791376\n", 157 | "12.411474588367865\n", 158 | "11.292299765594759\n", 159 | "14.028551047059297\n", 160 | "11.445466637311132\n", 161 | "20.92118566526828\n", 162 | "15.95219119986153\n", 163 | "13.344086687500095\n", 164 | "16.73486921150301\n", 165 | "13.675310834771217\n", 166 | "16.3575072897129\n", 167 | "13.345534222677225\n", 168 | "17.286568318988284\n", 169 | "13.053142977282237\n", 170 | "10.313500947983556\n", 171 | "16.680982732072636\n", 172 | "23.452430111455552\n", 173 | "12.745284090142642\n", 174 | "14.105701193713859\n", 175 | "15.243006762354677\n", 176 | "13.288044825150207\n", 177 | "17.095318647384143\n", 178 | "12.690455209420891\n", 179 | "19.100684462567507\n", 180 | "13.42892433070632\n", 181 | "11.011177021471864\n", 182 | "19.788817966797744\n", 183 | "12.093418245102647\n", 184 | "15.2509746341059\n", 185 | "21.59870587435788\n", 186 | "23.45756623930497\n", 187 | "18.3590818153269\n", 188 | "13.47522230088848\n", 189 | "17.781728190475523\n", 190 | "11.548469471420828\n", 191 | "11.23305061344768\n", 192 | "19.870837396537\n", 193 | "13.417975330517097\n", 194 | "18.918550510888164\n", 195 | "11.726966106186792\n", 196 | "17.899662319945524\n", 197 | "12.125572479818032\n", 198 | 
"13.325127484735134\n", 199 | "16.903883433734546\n", 200 | "22.011931308694653\n", 201 | "10.564950545871998\n", 202 | "11.59937678180833\n", 203 | "11.072777410070657\n", 204 | "10.74310363992739\n", 205 | "10.173401087046994\n", 206 | "11.04049041933104\n", 207 | "11.095445355949655\n", 208 | "10.241492674574832\n", 209 | "11.985830142677244\n", 210 | "10.501359520764101\n", 211 | "14.83578997298709\n", 212 | "14.06596578225513\n", 213 | "9.143243841542976\n", 214 | "12.555372730087015\n", 215 | "12.195077793312764\n", 216 | "10.659796769254326\n", 217 | "11.135164658845813\n", 218 | "19.22086164728629\n", 219 | "13.088411954311447\n", 220 | "11.940904402615692\n", 221 | "11.804486048075132\n", 222 | "16.50086061864994\n", 223 | "12.169657042306984\n", 224 | "10.839247740062591\n", 225 | "14.052421483039305\n", 226 | "14.056862648500234\n", 227 | "13.989662799805368\n", 228 | "11.437278789848872\n", 229 | "18.142771025630864\n", 230 | "10.523404404618093\n", 231 | "9.36751695558682\n", 232 | "15.636350879411696\n", 233 | "12.530098022726163\n", 234 | "12.244664703844638\n", 235 | "15.58574989655416\n", 236 | "19.164596757351926\n", 237 | "14.69990769666186\n", 238 | "12.091637896731699\n", 239 | "17.177406496280778\n", 240 | "11.257915278439562\n", 241 | "11.593090566064504\n", 242 | "18.162446072405324\n", 243 | "10.986719213045754\n", 244 | "16.50361934283281\n", 245 | "13.483503570419963\n", 246 | "16.14587208579454\n", 247 | "13.675864834741256\n", 248 | "15.375419786079062\n", 249 | "20.612413340858204\n", 250 | "19.87884677332824\n", 251 | "12.272652890423696\n", 252 | "11.34330382291547\n", 253 | "11.564197788109428\n", 254 | "14.125102995702065\n", 255 | "9.817930228099035\n", 256 | "11.632864144802605\n", 257 | "11.073551582854261\n", 258 | "9.357685695135434\n", 259 | "13.07735013647697\n", 260 | "10.803044208669126\n", 261 | "17.252125704022923\n", 262 | "15.67872622660318\n", 263 | "10.882755373682123\n", 264 | "12.472763889352693\n", 265 | 
"12.866504570189703\n", 266 | "13.862374360138217\n", 267 | "10.144263391754635\n", 268 | "15.45066821196627\n", 269 | "12.520877956831375\n", 270 | "11.111545012208001\n", 271 | "14.655894369372502\n", 272 | "19.40145560226068\n", 273 | "10.116034352236456\n", 274 | "12.121833609782648\n", 275 | "15.93907999106739\n", 276 | "12.24474298331014\n", 277 | "12.409733030905429\n", 278 | "11.361605016254636\n", 279 | "18.338166099304733\n", 280 | "12.351761558479309\n", 281 | "8.6375059409602\n", 282 | "16.40551232046509\n", 283 | "10.303219407129745\n", 284 | "16.916683766516226\n", 285 | "16.216108693070506\n", 286 | "19.180985469268897\n", 287 | "15.603075017758584\n", 288 | "12.404162905200678\n", 289 | "17.89877545151738\n", 290 | "10.286619101653363\n", 291 | "12.7712040985887\n", 292 | "20.743144778277234\n", 293 | "12.256610740269936\n", 294 | "15.619972648134423\n", 295 | "9.888861848292429\n", 296 | "21.480291151198962\n", 297 | "13.62713732946401\n", 298 | "14.668935517856353\n", 299 | "19.50665231399126\n", 300 | "24.05088989836412\n", 301 | "13.504356242544866\n", 302 | "14.37248237142913\n", 303 | "12.073263525309269\n", 304 | "13.024742680079147\n", 305 | "11.720718484235572\n", 306 | "12.329075641066396\n", 307 | "12.749472183172562\n", 308 | "9.841780031353965\n", 309 | "12.706804121997619\n", 310 | "10.322894656382346\n", 311 | "20.784854467729513\n", 312 | "14.62341496092043\n", 313 | "11.034250466028558\n", 314 | "15.682606093298839\n", 315 | "12.511796509543393\n", 316 | "14.805076981084696\n", 317 | "9.558874789351467\n", 318 | "17.311138879757145\n", 319 | "12.012666016451853\n", 320 | "11.577468680294073\n", 321 | "12.92838663934494\n", 322 | "23.85225511808796\n", 323 | "10.376446377693151\n", 324 | "12.037738497630954\n", 325 | "17.10447542522373\n", 326 | "10.87973849316099\n", 327 | "16.224840466674756\n", 328 | "11.506534201641415\n", 329 | "17.661575910742616\n", 330 | "13.474257710040757\n", 331 | "13.422029610589913\n", 332 | 
"17.111745121208806\n", 333 | "10.86890423904125\n", 334 | "14.137830039027389\n", 335 | "19.2011350645349\n", 336 | "22.02808201002264\n", 337 | "14.936747350904028\n", 338 | "13.177035473937563\n", 339 | "18.038527342803565\n", 340 | "9.922520077061622\n", 341 | "13.705867140398514\n", 342 | "17.427232570362317\n", 343 | "12.128255961444413\n", 344 | "17.030108669993982\n", 345 | "11.18642365875194\n", 346 | "22.14714208276171\n", 347 | "15.630262623826757\n", 348 | "19.47840501159135\n", 349 | "29.528957797261153\n", 350 | "24.27250819834894\n", 351 | "14.582876512282136\n", 352 | "15.885086319413677\n", 353 | "14.266984790794188\n", 354 | "15.201983645517448\n", 355 | "10.119543185463986\n", 356 | "14.12900495982926\n", 357 | "14.357889938965144\n", 358 | "13.015260977757567\n", 359 | "15.688321269552889\n", 360 | "13.961131990121412\n", 361 | "25.433142025725378\n", 362 | "18.225167187866003\n", 363 | "12.606831819477808\n", 364 | "17.41758095602241\n", 365 | "14.715177961451413\n", 366 | "15.330418603445418\n", 367 | "11.130945361213554\n", 368 | "18.717085163617288\n", 369 | "13.07757307805049\n", 370 | "11.07359942103203\n", 371 | "16.700172651259397\n", 372 | "25.466170869647417\n", 373 | "9.043990084470162\n", 374 | "14.831556001622971\n", 375 | "16.80519115947733\n", 376 | "17.91004971759213\n", 377 | "15.31738931667464\n", 378 | "13.612434405754337\n", 379 | "19.719276304459957\n", 380 | "11.891120818810487\n", 381 | "14.598674763420806\n", 382 | "18.29543268016767\n", 383 | "12.201524403573897\n", 384 | "18.348337749884934\n", 385 | "23.41321933143365\n", 386 | "25.24093176560065\n", 387 | "19.10680606055026\n", 388 | "16.50551223313307\n", 389 | "18.709563293185237\n", 390 | "12.67704618980239\n", 391 | "14.863702859666581\n", 392 | "27.245441852152545\n", 393 | "14.169380628738606\n", 394 | "19.993265979660105\n", 395 | "14.662783938326546\n", 396 | "21.253592927811134\n", 397 | "14.44927524715221\n", 398 | "19.472683355496862\n", 399 | 
"21.46681407082249\n", 400 | "25.226262679913287\n", 401 | "13.185581144309054\n", 402 | "14.406435222641244\n", 403 | "13.040904566917598\n", 404 | "14.479062761927006\n", 405 | "14.256301507196927\n", 406 | "14.894463455166692\n", 407 | "14.208588439829878\n", 408 | "10.196198590363887\n", 409 | "15.86392477066576\n", 410 | "11.593620524136824\n", 411 | "19.99918266789966\n", 412 | "19.618358684410364\n", 413 | "13.598515864271189\n", 414 | "17.49175171794763\n", 415 | "16.865040962524663\n", 416 | "16.360136320039633\n", 417 | "13.502790185403278\n", 418 | "22.245119568734967\n", 419 | "14.343799865507021\n", 420 | "13.824939924308188\n", 421 | "17.396618312292702\n", 422 | "20.581529949918213\n", 423 | "9.35537267822985\n", 424 | "13.961876614355129\n", 425 | "19.699183690665205\n", 426 | "11.654711296683146\n", 427 | "14.90593880469168\n", 428 | "12.787046969007998\n", 429 | "19.01994383287107\n", 430 | "16.378385834907647\n", 431 | "13.070233580015337\n", 432 | "20.905840735126347\n", 433 | "12.225486867200951\n", 434 | "15.00121840363095\n", 435 | "25.88097807499766\n", 436 | "24.704477027873292\n", 437 | "19.935797983231993\n", 438 | "16.266439630117876\n", 439 | "19.790826125307554\n", 440 | "14.572528077429526\n", 441 | "13.993728998075577\n", 442 | "23.999253250654704\n", 443 | "16.172648602154574\n", 444 | "21.003693842994462\n", 445 | "13.10133658723353\n", 446 | "24.908999450747135\n", 447 | "15.66458279259224\n", 448 | "17.108854089473212\n", 449 | "20.2540436461596\n", 450 | "32.63441997602402\n", 451 | "15.853848516730421\n", 452 | "15.798734005878671\n", 453 | "12.900223967905337\n", 454 | "16.210648519826574\n", 455 | "16.377465212500017\n", 456 | "16.79442472766424\n", 457 | "13.34996925203492\n", 458 | "12.5329995945093\n", 459 | "15.066864617312689\n", 460 | "12.921098337519226\n", 461 | "24.720290146862624\n", 462 | "21.277698660066342\n", 463 | "14.034471306606692\n", 464 | "16.691652936694034\n", 465 | "17.333520078625554\n", 466 | 
"17.52422047432821\n", 467 | "12.566960264771836\n", 468 | "22.076804074270154\n", 469 | "13.631946101898537\n", 470 | "13.344903016267972\n", 471 | "16.81559691228596\n", 472 | "26.010589232268497\n", 473 | "13.075416495321758\n", 474 | "12.671770888144154\n", 475 | "16.483014159990635\n", 476 | "16.22536827639414\n", 477 | "16.463271474684884\n", 478 | "12.804421348395454\n", 479 | "23.21885074267706\n", 480 | "14.294896951372268\n", 481 | "13.965097608430682\n", 482 | "19.319579095518392\n", 483 | "14.02169437383543\n", 484 | "18.73377187094678\n", 485 | "20.816422110111795\n", 486 | "25.49667774551089\n", 487 | "19.909338346357167\n", 488 | "16.785277475439297\n", 489 | "21.055059518674767\n", 490 | "12.922122081153757\n", 491 | "14.484270274053408\n", 492 | "26.857066898289265\n", 493 | "19.536206363309642\n", 494 | "22.698253095564098\n", 495 | "13.819638127746657\n", 496 | "18.710928224416822\n", 497 | "13.167550667214753\n", 498 | "16.215739686919505\n", 499 | "20.279670950681584\n", 500 | "20.310284372106437\n", 501 | "11.214633359281233\n", 502 | "12.1494868413472\n", 503 | "10.098760231408242\n", 504 | "13.332951119199127\n", 505 | "11.681328782637546\n", 506 | "10.770786170684278\n", 507 | "11.883415165024783\n", 508 | "9.151090543925001\n", 509 | "12.247541861822537\n", 510 | "10.842405147177981\n", 511 | "17.697487893293296\n", 512 | "15.355703307495785\n", 513 | "9.13270331141618\n", 514 | "14.890917274219294\n", 515 | "13.713974935942222\n", 516 | "13.611557373153573\n", 517 | "12.311156370195835\n", 518 | "16.51116566252099\n", 519 | "10.466378123248628\n", 520 | "10.365574656126542\n", 521 | "12.097013751939393\n", 522 | "17.9961321034557\n", 523 | "9.664420459944568\n", 524 | "12.384112748409219\n", 525 | "14.494305144797684\n", 526 | "10.679226647017263\n", 527 | "12.755548196532446\n", 528 | "12.23337840354084\n", 529 | "17.6229402778556\n", 530 | "10.001763474581846\n", 531 | "11.323360554792952\n", 532 | "15.802381492040107\n", 533 | 
"10.33753010775856\n", 534 | "14.835413008004442\n", 535 | "16.296171225451573\n", 536 | "19.370498476486954\n", 537 | "14.497922544492285\n", 538 | "13.146360226748648\n", 539 | "17.201317145231435\n", 540 | "8.251902479189445\n", 541 | "10.737825135767443\n", 542 | "20.199897407092973\n", 543 | "13.456641211335798\n", 544 | "16.367802967720323\n", 545 | "11.396113362155008\n", 546 | "18.876195734357346\n", 547 | "13.814865770636683\n", 548 | "15.741089240298827\n", 549 | "23.060991915159825\n", 550 | "22.571720031816106\n", 551 | "11.118991138394628\n", 552 | "11.635761046208737\n", 553 | "10.774575118737264\n", 554 | "13.631648386650225\n", 555 | "11.781984063699785\n", 556 | "12.026899752403663\n", 557 | "10.059186806987885\n", 558 | "10.848049277016788\n", 559 | "12.227827978161287\n", 560 | "9.770575249403056\n", 561 | "19.156738349458156\n", 562 | "17.293031916345804\n", 563 | "12.18837146965103\n", 564 | "16.333101084831064\n", 565 | "13.772948545602317\n", 566 | "15.502532416021142\n", 567 | "11.31345970106515\n", 568 | "17.494987796046008\n", 569 | "11.941767230995326\n", 570 | "10.273122673111963\n", 571 | "14.241058418133571\n", 572 | "22.770264995415236\n", 573 | "11.413999607771272\n", 574 | "12.190980948510216\n", 575 | "15.797925612863613\n", 576 | "12.86819113471635\n", 577 | "12.808993960319821\n", 578 | "11.11174570380878\n", 579 | "19.019013241441787\n", 580 | "12.691604196836945\n", 581 | "10.23701284622007\n", 582 | "16.29769834842618\n", 583 | "13.333033888405316\n", 584 | "15.653299770616394\n" 585 | ] 586 | }, 587 | { 588 | "name": "stdout", 589 | "output_type": "stream", 590 | "text": [ 591 | "17.325793993561494\n", 592 | "19.772733049768785\n", 593 | "15.821244381628741\n", 594 | "13.218674413328962\n", 595 | "20.955476164300713\n", 596 | "9.670311342215225\n", 597 | "12.936747286115937\n", 598 | "21.019373628970467\n", 599 | "14.197491821552019\n", 600 | "16.87577724545241\n", 601 | "11.189258360346686\n", 602 | "20.64460634287903\n", 
603 | "12.203667413494484\n", 604 | "16.780079326133517\n", 605 | "18.54829953034621\n", 606 | "20.395421761513266\n", 607 | "11.639313981150824\n", 608 | "11.96040397345782\n", 609 | "12.000695879952183\n", 610 | "10.561389419273874\n", 611 | "11.84554029654193\n", 612 | "11.354734221784451\n", 613 | "11.684622504304457\n", 614 | "10.65106606105049\n", 615 | "14.290115472761126\n", 616 | "10.34008654107451\n", 617 | "17.271882459963894\n", 618 | "15.299682122874286\n", 619 | "10.18338856221456\n", 620 | "15.874507462859258\n", 621 | "11.448619774319132\n", 622 | "16.049436305570136\n", 623 | "12.529018700770262\n", 624 | "15.930307564905887\n", 625 | "12.154199612492082\n", 626 | "10.980786043529886\n", 627 | "15.390446980773929\n", 628 | "17.495350123406\n", 629 | "8.521542639362256\n", 630 | "9.11315197925933\n", 631 | "14.49406537922168\n", 632 | "11.674506346114528\n", 633 | "14.108076478678091\n", 634 | "11.745042995557485\n", 635 | "20.522879787044854\n", 636 | "13.279472073265064\n", 637 | "10.905702749522634\n", 638 | "17.107884415156082\n", 639 | "10.255979248915208\n", 640 | "13.082519873148033\n", 641 | "16.73105764586341\n", 642 | "23.218211475045532\n", 643 | "16.140341781060965\n", 644 | "12.12708375817042\n", 645 | "15.702247023494111\n", 646 | "11.992740598880195\n", 647 | "12.962843048439545\n", 648 | "21.281439476984154\n", 649 | "13.953557516562862\n", 650 | "19.057965421028744\n", 651 | "11.870214922812046\n" 652 | ] 653 | } 654 | ], 655 | "source": [ 656 | "errors = []\n", 657 | "for store in range(1,11):\n", 658 | " for item in range(1,51):\n", 659 | " sales = df[(df[\"store\"] == store) & (df[\"item\"] == item)][\"sales\"]\n", 660 | " error = train_and_validate(sales)\n", 661 | " print(error)\n", 662 | " errors.append(error)" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 13, 668 | "metadata": {}, 669 | "outputs": [ 670 | { 671 | "name": "stdout", 672 | "output_type": "stream", 673 | "text": [ 674 | "SMAPE = 
15.041399804642642\n" 675 | ] 676 | } 677 | ], 678 | "source": [ 679 | "print(f\"SMAPE = {np.mean(errors)}\")" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 14, 685 | "metadata": { 686 | "scrolled": false 687 | }, 688 | "outputs": [ 689 | { 690 | "data": { 691 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xt0VPW9///nnlsyyUwyScgNCOEWuYNgLYh4IRgiRKpA0G9rzxGXHM/3HD0uylJb23NYq/y+56j90mU59PQckeL39Fvbb6U1WkmFYsByEQVRGeSOEAiQTG6TyUxuc9u/P0JSQkgmgUlm9uT9WEvJZPbseX/YyWs2n/35fLaiqqqKEEKImKWLdAFCCCEGlgS9EELEOAl6IYSIcRL0QggR4yTohRAixhkiXcD1amrcYdlPSkoCTmdzWPYVCVqvH6QN0UDr9YP22zBY9aenW3t8LmbP6A0GfaRLuCVarx+kDdFA6/WD9tsQDfXHbNALIYRoJ0EvhBAxToJeCCFinAS9EELEOAl6IYSIcRL0QggR4yTohRAixknQCyFEjJOgF0KIGBd1SyCIgbPb3tiv7edPTxqgSoQQg0nO6IUQIsZJ0AshRIyToBdCiBgnQS+EEDFOgl4IIWKcBL0QQsQ4CXohhIhxEvRCCBHjJOiFECLGSdALIUSMk6AXQogY16eg37NnD4WFhRQUFLBp06Zuzx86dIilS5cyefJktm/f3vn9EydO8Nhjj1FUVMSSJUv405/+FL7KhRBC9EnIRc0CgQDr1q3jzTffJDMzk+LiYvLz8xk/fnznNtnZ2bz88sts2bKly2vj4+N59dVXGT16NA6Hg+XLlzNv3jySkmSxLCGEGCwhg95ut5Obm0tOTg4ARUVFlJWVdQn6kSNHAqDTdf0HwpgxYzq/zszMJDU1lfr6egl6IYQYRCGD3uFwkJWV1fk4MzMTu93e7zey2+34fD5GjRrV63YpKQkYDPp+7/9G0tOtYdlPpIS7fou1bdDfX+vHALTfBq3XD9pvQ6TrDxn0qqp2+56iKP16k+rqal544QVeffXVbmf913M6m/u1756kp1upqXGHZV+RMBD1e9z9C/pbfX+tHwPQfhu0Xj9ovw2DVX9vHyYhL8ZmZWVRVVXV+djhcJCRkdHnN/d4PPz93/89q1ev5vbbb+/z64QQQoRHyKCfNm0a5eXlVFRU4PV6KS0tJT8/v08793q9PPPMMzz88MMsWrTolosVQgjRfyG7bgwGA2vXrmXVqlUEAgGWL19OXl4eGzZsYOrUqSxYsAC73c6zzz5LY2Mju3fvZuPGjZSWlvLBBx/w2Wef0dDQQElJCQCvvPIKkyZNGvCGCSGEaKeoN+qEj6Bw9WVJv153g33PWK0fA9B+G7ReP2i/DZrooxdCCKFtIbtuRHTq79m5EGLokqAfYvwBlSPnm6l3+1GBOKPCjNEJ2CzyoyBErJLf7iGkocnPR0cbcXoCXb5f7mhjxpgEpo9OQKfr3xwJIUT0k6AfImpcPrYfbsAfhAkj4rljfCI6ncKVOi8HTnr44lwzdW4/+dOT+j0hTggR3STohwB/QGXvMTf+INw31crYrPjO53Iz4shKMbLL3sjFGi9fnGtm1rjECFYrhAg3GXUzBHxxrglXc4DJOeYuId8hzqhj/rQkrGYdR843
c66qNQJVCiEGigR9jKt2+Th2oQWrWces8T2fqcebdCyYkYxRr7D/uBtPS6DHbYUQ2iJBH+MOnvKgAvMmWzHqe+97T7EYmDPBgj8IB097BqdAIcSAk6CPYdUuHzWNfnKGmchKMfXpNeOy48hINnChxsvR8vCsJCqEiCwJ+hh2/GILAFNGmfv8GkVRuGuiFQX4v7tr8fmjaoUMIcRNkKCPUZ7WAOXVbaRY9GSlGPv12lSrgYkj46ly+ig74hqgCoUQg0WCPkadqGhBVWHKqISbGhc/c1wi8SaF0kMNtPmCA1ChEGKwSNDHIH9A5fTlVuKNCmMy425qH3FGHQW3J+NqDrD7qKyrI4SWSdDHoMt1Xrx+lbzh8RhCjLTpzYN32Ig3tp/Ve+WsXgjNkqCPQeWO9nvDjr7Js/kOVrOegpnJuJrkrF4ILZOgjzH+gMrFWi8Ws440662vcPHgHTbijAp/+qwBf0BG4AihRRL0MeZKvRd/QGV0RlxYFiezmvXcNzUJpycgk6iE0CgJ+hjT2W2TcWvdNtcqmJmMAuz43EWU3XlSCNEHEvQxJBBUqaj1khinY1hS+BYmzbQZmTUukfOONk5fkQXPhNAaCfoYcqW+fbTN6MzwdNtcq/COZAB2HJYJVEJojaxHH0MqarxA+xrz4XDtfWlVVSXNauDw2Sb++KkTq1l/w9fMn54UlvcWQoSPnNHHkCv1Xox6hfQwdtt0UBSFyaPMqLTPuhVCaEefgn7Pnj0UFhZSUFDApk2buj1/6NAhli5dyuTJk9m+fXuX50pKSli4cCELFy6kpKQkPFWLbtwtAdwtQbJTjQN239cxmXGYTTpOX27F65cJVEJoRcigDwQCrFu3js2bN1NaWsq2bds4e/Zsl22ys7N5+eWXeeihh7p8v6GhgZ///Oe8/fbbbN26lZ///Oe4XNLHOxAq69u7bbJT+7Yc8c3Q6xQm5cTjC6ickYuyQmhGyKC32+3k5uaSk5ODyWSiqKiIsrKyLtuMHDmSiRMnotN13d2+ffu4++67sdlsJCcnc/fdd7N3797wtkAAcLneB8Dw1P6tVNlfE0aY0evgeEULQRlqKYQmhOzMdTgcZGVldT7OzMzEbrf3aec3eq3D4ej1NSkpCRgMN77Q11/p6daw7CdSeqvfYm3r/FpVVaqcPixmPSMyb261yr6yABNHJXKsvIlqt8r4EV3vQXt9zVo/BqD9Nmi9ftB+GyJdf8igv9EEmb4Gyc281ukMz12N0tOt1NS4w7KvSAhVv8f916Cva/TR6g0yPjuOJo93wGu7LdvEsfImDp9ykZXU9V9x19as9WMA2m+D1usH7bdhsOrv7cMkZNdNVlYWVVVVnY8dDgcZGRl9euNbea3ouyud3TYD1z9/LVuigZFpJqpdfmpcvkF5TyHEzQsZ9NOmTaO8vJyKigq8Xi+lpaXk5+f3aefz5s1j3759uFwuXC4X+/btY968ebdctOjqytULsYMV9ABTcttvT/jVRRlqKUS0C9l1YzAYWLt2LatWrSIQCLB8+XLy8vLYsGEDU6dOZcGCBdjtdp599lkaGxvZvXs3GzdupLS0FJvNxj/+4z9SXFwMwDPPPIPNZhvwRg0lwaBKdYMPW6Iec9zgTYvITjGSatFzwdGGe3ygxwlUQojIU9QoW6UqXH1Zsd6v1zFrtcblY9uhBiaMiGfupMG94HP2Sit7j7uZMsrMN2+zAF1nxmr9GID226D1+kH7bdBEH72IbtVX+8gzbAM7rPJGxmTFkRAnE6iEiHYS9BpX3eAHICN58IO+fQKVGd/Ve9QKIaKTBL2GqapKtcuH2aRgNUfmUE4YEY9BB8cvthAMRlUvoBDiKgl6DWtqDdLcFiQj2Tigk6R6E2fUkTc8nqa2IOXVbaFfIIQYdBL0GuaIYP/8tSaPSgDg2MUWuQOVEFFIgl7DqhuiI+iTEvSMSjdR2+jn1CXpqxci2kjQa1h1gx+9DtKskb9/zLTc9rP69z51RrgS
IcT1JOg1yucP4vT4GZZkQD9A68/3R4bNyPBUI8cutnD6ssyWFSKaSNBrVI3LjwqkR2BYZU9uH5sIwLufyFm9ENFEgl6jahsjN36+J5k2I1NGmfnqgpzVCxFNJOg1qqax/ULssAG4P+yteOSuFADeOeCUEThCRAkJeo2qa/RjNikkDOJCZn0xYYSZablmjl9s4fBp7a5PIkQsia6UEH3iavLT1BZkWFLkJkr15rF701CAX35wRWbLChEFJOg16LyjfQZqWpR123QYlR7HvClWyh3tq1sKISJLgl6DzlW1B3209c9fa/ncVOKMCu/sr6fNJytbChFJEvQa1HFGPywpekbcXC/VamDZvAycTQH+KJOohIgoCXqNUVWV8442EuN1mE3RffgevT+DNKuBP33WwOW6gb9puRDixqI7KUQ39Z4Ajc2BqO626RBv0vPd+cMIBOFXu2pkuKUQESJBrzHnqtoXDYvmbptrzRqXwIwxCZyoaOXASU+kyxFiSJKg15hyR/RfiL2Woij8zfxhmAwKb31Ui7slEOmShBhyJOg1pnNoZRSsWNlXGTYjy+em4m4J8tZHtZEuR4ghR4JeQ1RV5UJ1G+nJBuKM2jp0C2clMyYzjo9PeDhyvinS5QgxpGgrLYY4pyeAuyVIbnpcpEvpN71O4amF6eh18OaHtbR4ZWy9EIOlT0G/Z88eCgsLKSgoYNOmTd2e93q9rF69moKCAlasWMGlS5cA8Pl8fP/732fJkiUsWrSI119/PbzVDzEXrt6TNTdDe0EP7TNmH7ozhXq3n6176yJdjhBDRsigDwQCrFu3js2bN1NaWsq2bds4e/Zsl222bt1KUlISO3fuZOXKlaxfvx6A7du34/V6ef/993nnnXf43e9+1/khIPrvr0FvinAlN+9bs1MYnmqk7Egjp2QpYyEGRcigt9vt5ObmkpOTg8lkoqioiLKysi7b7Nq1i6VLlwJQWFjIgQMHUFUVRVFoaWnB7/fT2tqK0WjEYrEMTEuGgAs17ZOOtHpGD2A0KDy1MAOALX+uweuXLhwhBlrIoRsOh4OsrKzOx5mZmdjt9m7bZGdnt+/QYMBqteJ0OiksLKSsrIx58+bR2trKSy+9hM1m6/X9UlISMBj0N9OWbtLTrWHZT6RcX/+lugpsFgN5o22crYnOYYrX13yjY5CebuVbF9t47+NadtqbWVmYPVjl3ZRY+znSIq23IdL1hwz6G81mvH5p3J62sdvt6HQ69u7dS2NjI9/5zneYO3cuOTk5Pb6f09ncl7pDSk+3UlOj3ZUTr6+/qTWAw+llWq6Z2loPHndbBKvr2bU193YMimZZ2f9VA7/f42DqSGPU/isl1n6OtEjrbRis+nv7MAnZdZOVlUVVVVXnY4fDQUZGRrdtKisrAfD7/bjdbmw2G9u2beOee+7BaDSSlpbGrFmzOHr06M22Y0i7UN3ebTMqSgOxv+JNOp58IJ1AEH7552oCsm69EAMm5Bn9tGnTKC8vp6KigszMTEpLS/npT3/aZZv8/HxKSkqYOXMmO3bsYM6cOSiKQnZ2Np9++ikPP/wwLS0tHDlyhCeeeGLAGhPLLtRoY8TNbntj59cWa1vIf3ncM8XK3mNuth9uoOjOlIEuT4ghKeQZvcFgYO3ataxatYrFixezaNEi8vLy2LBhQ+dF2eLiYhoaGigoKODNN9/k+eefB+Dxxx+nqamJhx56iOLiYpYtW8bEiRMHtkUxKhZG3NzIt+9NIylBT8kBJ/Vuf6TLESImKWqULSkYrr6sWOvXe+m/L1LX6Oe/nh2DTlG6nDlHK4s1LuQZ/fzpSfzlq0Z++eca5k6y8D8XZQ5SdX0Taz9HWqT1NkRDH712FkwZwry+IJX1PsZlx6OLwnvE3ord9kZUVSXNauDjEx5SLAYykntfmXP+9KRBqk6I2CBLIGjApTovQRVGpcdWt00HRVGYPaF9fsWnpzyybr0QYSZBrwEVtVdH3GhwjZu+yrQZGZMZR22j
n/Lq6Bw6KoRWSdBrQEVNR9DH5hl9h1njElEU+OLrZoIy3FKIsJGg14CKmjYUYOSw2A76pAQ9ecPjcTUH+LpKzuqFCBcJ+iinqioXa7xk2IyaW4P+ZswYk4BOgS/PNckkKiHCJPaTQ+PqPQGa2oIx323TwRKvZ8JIM57WIGeutEa6HCFiggR9lKu4OiM2li/EXm/66AT0OjhaLn31QoSDBH2Uu3j1QmzOEDmjB0iI05E3PB5Pa1BG4AgRBhL0Ue6vQyuHTtADTB2VgALYy5tlXL0Qt0iCPspdrGkjIU5HmnVoTWK2JugZkxmH0xPgcp030uUIoWkS9FHM6wtS5fSRM8zU7R4AQ8G00QkA2MvlloNC3AoJ+ih2qc6Lqg6t/vlrpVoNjEgz4mjwUdvoi3Q5QmiWBH0U65gRmzNs6Iy4ud7kHDMAJyrkrF6ImyVBH8UuDpGlD3ozIs1EUoKe8442Wr1yI3EhboYEfRSrqB0aSx/0RlEUJo2MJxCEU5flrF6ImyFBH6U6lj7ITBkaSx/0ZvzweAx6hZOXWmUClRA3YWiN2Yti198xStU30dwWJCPZqIm7SQ0kk0FHXnYcJy61cqFGhloK0V9D+1QxitW62keZpFj0Ea4kOkySi7JC3DQJ+ijVMZwwdYhNlOpJcqKB4antQy0vyLIIQvSLBH2U6jijl6D/q8mj2s/qd37hinAlQmiLBH2UqnX5MBkUEuPkEHUYmWbCatZx4KQHd0sg0uUIoRmSIlHIH1BxefykWAxDcumDnrQPtTTjC6h8dHRoX6AWoj/6FPR79uyhsLCQgoICNm3a1O15r9fL6tWrKSgoYMWKFVy6dKnzuZMnT/LYY49RVFTEkiVLaGuT/tVQnB4/KpAqF2K7GT88njijQtmRRrkDlRB9FDLoA4EA69atY/PmzZSWlrJt2zbOnj3bZZutW7eSlJTEzp07WblyJevXrwfA7/fzwgsv8OMf/5jS0lJ+9atfYTBIn3MoTo8fkP75G4kz6rh7spV6t58vvm6KdDlCaELIoLfb7eTm5pKTk4PJZKKoqIiysrIu2+zatYulS5cCUFhYyIEDB1BVlf379zNhwgQmTpwIQEpKCnq9nKWGUu9p739OsUjQ38gDM5IB2PmlXJQVoi9CJonD4SArK6vzcWZmJna7vds22dnZ7Ts0GLBarTidTs6fP4+iKDz11FPU19ezePFi/u7v/q7X90tJScBgCM+HQXq6NSz7GQwW61+7tFzNrvalD7ISMBq0fRnFYg3/gmwzJ6UxY5yTI197aA4ayM00h/09rqWln6Mb0Xr9oP02RLr+kEF/o7v7XH+BsKdtAoEAhw8f5ve//z1ms5mVK1cydepU7rrrrh7fz+ls7kvdIaWnW6mpcYdlX4PB424PelVVqXX5SLYYaGvxoeUrGhZrXGe7wqmmxs39Uywc+drD27srWbkgPezv0UFrP0fX03r9oP02DFb9vX2YhDxdzMrKoqqqqvOxw+EgIyOj2zaVlZVAe7+82+3GZrORlZXFN7/5TVJTUzGbzdx7770cO3bsZtsxJHhag3j9KunJxkiXEtVuH5tAmtXA/uNumlplqKUQvQkZ9NOmTaO8vJyKigq8Xi+lpaXk5+d32SY/P5+SkhIAduzYwZw5c1AUhXnz5nHq1ClaWlrw+/0cOnSI8ePHD0xLYkS9u/1C7DAJ+l7pdQoP3J5Em0+GWgoRSsigNxgMrF27llWrVrF48WIWLVpEXl4eGzZs6LwoW1xcTENDAwUFBbz55ps8//zzACQnJ7Ny5UqKi4t55JFHmDx5Mvfff/+ANkjrOoI+3TZ0lybuq/unJRFnVNj5hQt/QIZaCtETRb1RB3sEhasvS2v9eh0rVJYdcXGxxstTi7IJ+vwRrurWDFQf/fzpSZ1f/3p3LX/+wsX/XJTB3Enhv+CltZ+j62m9ftB+GzTRRy8GV53bj9mkIyFehqH2ReGsZBQFth9uuOGgACGEBH1UafMFaWoNkmqV
kO+r9GQj3xifSHm1l5OXWiNdjhBRSYI+inT0z8uM2P5ZdIcNgPcPOiNciRDRSYI+inQGvcyI7Zfxw+OZPMrMVxdaOHNFzuqFuJ4EfRSpu7rGTZqc0ffb0jkpALx7oD7ClQgRfSToo0i9249BB9YE6aPvrwkjzUzKiefohRa+rpSzeiGuJUEfJQJBlYamAClWAzpZg/6mLJ2TCkCJnNUL0YUEfZRo8PhRVem2uRUTc9rP6u3lLRy7GJ41k4SIBRL0UaJOLsSGxbfvG4YCvLW7Tm5MIsRVEvRRomMNehlaeWtGZ8Rx71Qrl+q8sgaOEFdJ0EeJercfBbnZSDgU351KvEnhD/vrZWVLIZCgjwpBVaXe7ScpUY9BLxdib1VyooGHZ6fgaQ3ym7/URbocISJOgj4K1Lr8+AIqaXI2HzaFs2yMzoxj7zE3h8/KvWXF0CZBHwUu1LSv8Cj98+Fj0Cv8/YMZGPUKW3ZW09is7ZVAhbgVEvRR4GK1F5CgD7cRaSZWzEvF3RLkjR01BGUUjhiiJFmiQOcZvXTd9EnH2v19YTIoDE81cuR8M+tLKrkzz9Kn11275r0QWidn9FHgYnUbZpMOc5wcjnBTFIX7pyWRlKDnqwstnJVFz8QQJMkSYe6WAPWegMyIHUBxRh0PzEjCZFDYf8LN5TpvpEsSYlBJ0EfYRbkQOyiSEw3Mn56EorTfrvFKvYS9GDok6CNMLsQOnuGpJhbMSEZV4cMvJezF0CFBH2EXqjsuxMrSxINhRJqJ/BlJnWHf8fcvRCyToI+w8472C7FJsgb9oMkZFkfB7e03Fd99tFEu0IqYJ0EfQc1tASqdPkZnmlBkDfpBNTzNxIOzbBj1CnuPu2VZYxHT+hT0e/bsobCwkIKCAjZt2tTtea/Xy+rVqykoKGDFihVcunSpy/NXrlxh5syZ/PKXvwxP1TGi3NHeRzw2Mz7ClQxN6clGFn/Dhtmk4+DpJj7/uglVlUlVIvaEDPpAIMC6devYvHkzpaWlbNu2jbNnz3bZZuvWrSQlJbFz505WrlzJ+vXruzz/8ssvc88994S38hhw3tHeZTAmKy7ClQxdKRYDRXfasJp1HDnfzCenPBL2IuaEDHq73U5ubi45OTmYTCaKioooKyvrss2uXbtYunQpAIWFhRw4cKDzl+XDDz9k5MiR5OXlDUD52nauqv1C4JhMCfpIspr1LP6GjRSLnpOXWtlzzI0/IGEvYkfIMX0Oh4OsrKzOx5mZmdjt9m7bZGdnt+/QYMBqteJ0OomPj+eNN95gy5YtbNmypU8FpaQkYDCE58Jkero1LPsZKBdqLpKUqGfSuBTK64PdnrdYtf8BoJU2WKxQfF887x+o5VxVG2/srONH3xkNRP/PUSharx+034ZI1x8y6G/0z9jrLxz2tM3GjRt54oknSExM7HNBTmd4Loqlp1upqXGHZV8DobE5QHWDj+mjE6it9eBxdx3mZ7HGdfue1mixDQ/MSKLsSxcHjrtY/R8neXD2MJo8ocfbR+vaONH+e9AXWm/DYNXf24dJyKDPysqiqqqq87HD4SAjI6PbNpWVlWRlZeH3+3G73dhsNo4cOcKOHTtYv349jY2N6HQ64uLi+O53v3sLzYkNHf3zY6V/PqoY9QoLZiSx/XMXZyvb2HfUxe2j42VUlNC0kEE/bdo0ysvLqaioIDMzk9LSUn7605922SY/P5+SkhJmzpzJjh07mDNnDoqi8Jvf/KZzm40bN5KQkCAhf9V56Z+PWkaDjoLbk/nT4Qa+/NpDvAEm5ZgjXZYQNy3kxViDwcDatWtZtWoVixcvZtGiReTl5bFhw4bOi7LFxcU0NDRQUFDAm2++yfPPPz/ghWvdOcfVoJcz+qgUb9KxcGYyZpOOT097ZLkEoWmKGmVjycLVlxXN/XqqqvLcpgvoFfjZ06OB7musa7F/+3qx0IbGNijZV4PRoPDQnSk9zmCWPvqBo/U2REMfvcyMjYB6TwBXU4AxWTJRKtoN
HxbHXRMttPlUPjraSEDuUiU0SII+As5cXVtlfLZ022jBbSPM5A2Pp87t59AZudG40B4J+gjoWEQrb7ic0WvFnAkWbIl6TlS0yIqXQnMk6CPgzJVW9DoYLSNuNMOgb78loV4H+467aW4LRLokIfpMgn6QtfmCXKhuY3RmHCaD/PVrSYrFwJ15Frx+lY9PyJo4QjskaQbZ+ao2gqp022jVxJHxZKcYqaj18nWldOEIbZCgH2RnKq/2z2dL0GuRoijcPdmKQa/w6WkPTa3ShSOinwT9IOsccSNn9JplNev55m2JeP0q+0+4pQtHRD0J+kGkqipnrrQyLMlAikVuBq5ltw2PZ3iqkct1vs4PbyGilQT9IKpy+mhqDUr/fAzo6MIx6hUOnm6ittEX6ZKE6JEE/SA6I+PnY4olXs/sCRZ8AZXNf66RLhwRtSToB9HJS+1Bf9sICfpYMT47jpHDTBy/2NJtvSIhooUE/SBRVZXjFS1Y4nWMHGaKdDkiTBRF4e5JFhLidPx2Tx01LunCEdFHgn6QVLv81Lv9TMoxo5ObWMSUhDg9fzN/GG0+lc1/riYoXTgiykjQD5LjF1sAuYFFrJo7ycLMcQmcqGhl1xHpwhHRRYJ+kJyokKCPZYqi8OQD6STG6/h/e+qobpAuHBE9JOgHgaqqnKhoITlRz/BUY6TLEQPElmjgb+YPw+tXeWOHdOGI6CFBPwiu1PtwNQeYNNIsN5mOcXdNtHDH+EROXW7lz5+7Il2OEIAE/aDo6LaZPEq6bWKdoiisXDCMpAQ9v9tbR7lDFj4TkSdBPwikf35oSU408PSDGQSC8B+lDlq8wUiXJIY4CfoBFgiqHL/YQprVQEayrG8zVEwfncCiO5JxNPj47zKZNSsiS4J+gJ290kpTW5AZYxKkf36IWTEvjbFZcXx8wsPOL6S/XkSOBP0A+/J8MwC3j02IcCVisBn0Cs8tySI5Qc9v/lLHsYvNkS5JDFF9Cvo9e/ZQWFhIQUEBmzZt6va81+tl9erVFBQUsGLFCi5dugTA/v37WbZsGUuWLGHZsmUcOHAgvNVrwJfnmjAZFCZL//yQlGo18E9LMlEU+Pk2Bw6njK8Xgy9k0AcCAdatW8fmzZspLS1l27ZtnD17tss2W7duJSkpiZ07d7Jy5UrWr18PQEpKCv/5n//J+++/zyuvvMKLL744MK2IUjUuH5frfEzKMWMyyj+ehqrbRphZuSCdptYgr/7hCvVuf6RLEkNMyPSx2+3k5uaSk5ODyWSiqKiIsrKyLtvs2rWLpUuXAlBYWMiBAwdQVZXJkyeTmZkJQF5eHl6vF6/XOwDNiE5HrnbbzJRumyHvvmlJLJubQm2jn//9zhU8LXILQjF4Qg4DcTgcZGVldT7OzMzEbrd32yY7O7t9hwYDVqsVp9NJampq5zY7duyYMf0HAAAQL0lEQVRg0qRJmEy9r9yYkpKAwaDvVyN6kp5uDct+btbxS9UA5H8jg3Rb7+22WLuPt7ZY4wakrsE0VNrQl5+1VQ9ZCKDnvY9r+el7Dv7Xk2OxWQZ+pnSkfw/CQettiHT9IYP+RsPCrh89EmqbM2fOsH79erZs2RKyIKczPBes0tOt1NS4w7Kvm9HmC3Lkaw+j0k3ga6OmpveJMx531+ct1rhu39OaodSGvv6sLZ2dRKPby+6jjaz+j9N8vzibYUkDF/aR/j0IB623YbDq7+3DJGTXTVZWFlVVVZ2PHQ4HGRkZ3baprKwEwO/343a7sdlsAFRVVfHss8/y6quvMmrUqJtqgBbZzzfjC6jMGCPdNuKvdIrCygeG8dCdNhwNPv6/317mrNxzVgywkEE/bdo0ysvLqaiowOv1UlpaSn5+fpdt8vPzKSkpAdq7aObMmYOiKDQ2NvL000+zZs0a7rjjjoFpQZT65JQHgNkTLBGuREQbRVF49J40vnNfGg3NAf717ct8+KVLJlWJAROy68ZgMLB27VpWrVpFIBBg+fLl5OXlsWHDBqZO
ncqCBQsoLi7mhRdeoKCggOTkZF577TUAfv3rX3Px4kV+8Ytf8Itf/AKALVu2kJaWNrCtirAdnzfw+ddN2BL1nL3SyteV2u6+EKHdzG0EH7zDRk66iV+UOvjVrlqOljfzxIJ0Uq0yg1qEl6JG2WlEuPqyItmvt2l7NfuOu5k1LoEZYxJvah9DqX87mg1kG+ZPTwKg3u3n9e0OTlS0Em9SKJ6bSv6MZAz6W59JrfX+bdB+GzTRRy/671xVe5/r2Cy5CbgILdVq4AfFw3mqIB2dovDrj+r4wf+5yMcn3ASCUXUeJjRK/o0YZg1NfirrfaQnG7CawzNMVMSmG3X3fGt2CkfON3PqUgv/9UE1v/6olsk5ZvKGx1M4yxaBKkUskKAPs4OnPKjA2Ew5mxf9ZzbpmDPBwpRRZo6WN3O2spWDp5s4fLaJryvbmDfZypRcM3qdLJAn+k6CPoxUVeWjo24UBcZkan+ikIgcq1nP3ElWZo1L5PTlVs5UtvLJKQ+fnPKQnKDnrkkW5k60kpthklVRRUgS9GF08lIrl+q8jMmMwxwnlz/ErYs36Zg+JoFpo83kZsSx/7ibA6c8bD/sYvthF9kpRuZMtDBnooXslN5nX4uhS4I+jDrWHJc7SYlwUxSFcdnxjMuO5zv3D+PI+WYOnHTz5blmSg44KTngZHRmHHdNsDB7gkWGaIou5KchTGobfRz+uonRGSa5k5QYENdfvJ0yKoHbhsdzscbL11VtXKhuo9zRxm/31DE81cij96TxQJpM2BMS9GFTdqQRVYWCmcn4ZWFCMUiMBl3nmX6rN0h5dRtfV7Zypd7Hz96r4rd76imcmcS9U6yyVPYQJkEfBq3eIH852ojVrGP2BAv7j3siXZIYguJNOiaONDNxpJl6t5+Tl1o4W9nKr3bV8va+OmaMTmDCyN5H7HRM4hKxRT7iw2D75w14WoM8cHsyJoP8lYrIS7UamDvJysoHs5k+OoFAAD493cS7nzi5UN0m6+oMMXJGf4vcLQH+9FkDVrOOB++QCS0iuiTE6bljfCJTRpn58lwTJy+3ssveSKbNyJ15iaQnD/x6+CLy5PTzFv3xUyetXpVvzU7BbJK/ThGd4k065ky0snROCqPSTTgafGw71MD+425avcFIlycGmJzR34LaRh9lR1wMSzKQPz050uUIEVJyooEFM5Kpcnr55JSH01dauVDTxp15FsZnyyS/WCWnoDdJVVX+u6wWfwCWzU3FaJDZiUI7slJMfOubKXwjL5FAQGXfcTfbD7u4XDd07uk8lEjQ36Q9x9wcOd/M5FFm5k6SscpCe3Q6hWm5CSydm8qodBNVDT7++f9WsHVfHW0+6c6JJRL0N6G20cdbH9ViNulYtbB9aVkhtMoSr2fBjGQWzEjClmjg/YMN/PBXFRw51xTp0kSYSND3k8+v8l8fVNPqVfnu/LQBvbGzEINpVHocr6zMoegbNurdfn76bhUb36+i3u2PdGniFsnF2H4Iqiqvb3dw+nIr37wtkXmTe76jixBaFGfU8di9acydZOH/lNVy6EwTR8438+AdNorutMnIMo2So9ZHqqrym4/qOHi6iQkj4nn6wQxZHlbErJz0OH702HCeWphOQpyOP37q5PlfXqD0kFOGY2qQnNH3gT+g8t9lNfzlKzcj0oysfjhLZsCKmKdTFO6bmsScCRa2H3bxp88a+N3eekoPNfDA7cnMn55EikUiRAvkKIXQ2Oxn4/sOTl1uZXSGie89kk1ivNwiUAwdcUYdD89J4YHbk9j5hYsdn7t49xMn7x90MmtcInMnWZk+OkGGGEcxCfoeBIMqu4828vt99TS1BfnmbYn8XWEGcbICoIhhN7qP7bWSEw0snZvKuapWTlS0cOhME4fONJEQp2Pa6ASmjzYzZVSCrIcfZeRoXKfNF+TASQ9//tzFpTovZpOO796fRsHMZOmTFwIw6hUmjDBz2/B4xmbHc+CEm4Onm/j0lIdPT7Wv3Jpq0TM2K57s
VCOZtqv/pRhJTtDL71EE9Cno9+zZw7/+678SDAZZsWIFTz/9dJfnvV4vL774IseOHcNms/Haa68xcuRIAF5//XV+//vfo9Pp+Od//mfuueee8LfiFqiqSr0nwMmKFuzlzRw530xzWxCdAvMmW3n0nlRsifJ5KMT1FEVhdEYcozPi+B/3pnG5zoe9vJlTl1s4e6WVz852H4dv1CsYDQpxRoU4o454o4LRoMOgVzDqwaBXrn6tdH79wJ3Q2tTa+RqjXkGnU9Dr6PxTr1PQKXR+iOw60n63N7Xzf+1/XLtop3rN/+6ZkkTw6pOq2v5tnQImgw6DHs1/OIVMsEAgwLp163jzzTfJzMykuLiY/Px8xo8f37nN1q1bSUpKYufOnZSWlrJ+/Xp+9rOfcfbsWUpLSyktLcXhcPDkk0+yY8cO9PqB6eNubgvgbgniD6i4fM3U1rbiC6j4r/7naQ3iavLT0BTA1RTA6fFzqdZLU9tfRxGkWPQU3J7C/OlJ8s9PIfpIURRGDjMxcpiJxd+woaoqTk8AR4OPKqeP6gYfVQ0+6t1+ahp9eFqCOD19u0NP2ZHeu5O61EFnrvfLrz+q6/E5nULnh5LJqOv8wLnhn4bu38+oC9La3Nb5WK8DBQVFofM/aP87TE7QY9CH/0MlZJLZ7XZyc3PJyckBoKioiLKysi5Bv2vXLp599lkACgsLWbduHaqqUlZWRlFRESaTiZycHHJzc7Hb7cycOTPsDXE1+Vmz+SK+QN8PswJk2IxMHmVmbFYc00cnMHKYSfOf3kJEmqIopFoNpFoN3e6h3HEdIBBUafOp+PztJ2f+YPsIN5//6slZUMUXUJk4ykq9q5U2X5BWX5BAoP21wSD4gyrBoEpAbb+uFlTbf69dTYH2L7j6h9L5kL/+eiudz6UnGTq3Ua5+oart9bX5gp1/tvqCuJrbHw/Ekv4TR8bzw0dHhH2/IYPe4XCQlZXV+TgzMxO73d5tm+zs7PYdGgxYrVacTicOh4MZM2Z0ea3D4ej1/dLTb24SUno6/PF/pdzUa8Pt0QUykUqInsjvx+ALOYTkRneiuf6Mt6dt+vJaIYQQAytk0GdlZVFVVdX52OFwkJGR0W2byspKAPx+P263G5vN1qfXCiGEGFghg37atGmUl5dTUVGB1+ultLSU/Pz8Ltvk5+dTUlICwI4dO5gzZw6KopCfn09paSler5eKigrKy8uZPn36wLRECCHEDYXsozcYDKxdu5ZVq1YRCARYvnw5eXl5bNiwgalTp7JgwQKKi4t54YUXKCgoIDk5mddeew2AvLw8Fi1axOLFi9Hr9axdu3bARtwIIYS4MUWV28ELIURMk/n8QggR4yTohRAixsXk1M/8/HwSExPR6XTo9XreeeedSJfUq5deeomPPvqItLQ0tm3bBkBDQwPf+973uHz5MiNGjOBnP/sZycnJEa60Zzdqw8aNG3n77bdJTU0FYM2aNdx3332RLLNHlZWVvPjii9TW1qLT6Xj00Ud54oknNHMceqpfS8egra2Nxx9/HK/XSyAQoLCwkOeee46KigrWrFmDy+Vi8uTJ/OQnP8FkMkW63BvqqQ0/+MEPOHjwIFZr+xyCV155hUmTJg1eYWoMmj9/vlpXVxfpMvrs4MGD6ldffaUWFRV1fu/VV19VX3/9dVVVVfX1119Xf/KTn0SqvD65URv+/d//Xd28eXMEq+o7h8OhfvXVV6qqqqrb7VYXLlyonjlzRjPHoaf6tXQMgsGg6vF4VFVVVa/XqxYXF6tffPGF+txzz6nbtm1TVVVV/+Vf/kV96623Illmr3pqw/e//331gw8+iFhd0nUTBe68885uZ4llZWU88sgjADzyyCN8+OGHkSitz27UBi3JyMhgypQpAFgsFsaOHYvD4dDMceipfi1RFIXExESgfT6O3+9HURQ++eQTCgsLAVi6dCllZWWRLLNXPbUh0mI26J966imWLVvG7373u0iXclPq6uo6J5dlZGRQX18f
4YpuzltvvcWSJUt46aWXcLlckS6nTy5dusSJEyeYMWOGJo/DtfWDto5BIBDg4YcfZu7cucydO5ecnBySkpIwGNp7mbOysqL+A+z6NnQch9dee40lS5bwb//2b3i93kGtKSaD/re//S0lJSW88cYbvPXWWxw6dCjSJQ1J3/72t9m5cyfvvfceGRkZvPLKK5EuKaSmpiaee+45fvjDH2KxWCJdTr9dX7/WjoFer+e9997jL3/5C3a7nXPnznXbJhrOkHtzfRtOnz7NmjVr2L59O3/4wx9wuVxs2rRpUGuKyaDPzMwEIC0tjYKCgm6LsGlBWloa1dXVAFRXV3deTNOSYcOGodfr0el0rFixgqNHj0a6pF75fD6ee+45lixZwsKFCwFtHYcb1a+1Y9AhKSmJ2bNn8+WXX9LY2Ijf7wegqqpKM8uodLRh7969ZGRkoCgKJpOJZcuWDfpxiLmgb25uxuPxdH69f/9+8vLyIlxV/+Xn5/Puu+8C8O6777JgwYIIV9R/HQEJ8OGHH0b1cVBVlR/96EeMHTuWJ598svP7WjkOPdWvpWNQX19PY2P7Esatra18/PHHjBs3jtmzZ7Njxw4ASkpKui3BEk1u1IaxY8d2HgdVVSNyHGJuZmxFRQXPPPMM0N5X9tBDD/EP//APEa6qd2vWrOHgwYM4nU7S0tL4p3/6Jx544AFWr15NZWUl2dnZbNiwAZvNFulSe3SjNhw8eJCTJ08CMGLECNatWxe1Z2OfffYZjz/+OLfddhs6Xfv5z5o1a5g+fbomjkNP9W/btk0zx+DkyZP84Ac/IBAIoKoqDz74IM8++ywVFRV873vfw+VyMWnSJNavXx+1wyt7asPf/u3f4nQ6UVWViRMn8uMf/7jzou1giLmgF0II0VXMdd0IIYToSoJeCCFinAS9EELEOAl6IYSIcRL0QggR4yTohRAixknQCyFEjPv/AT0HLHjcBzzGAAAAAElFTkSuQmCC\n", 692 | "text/plain": [ 693 | "
# from https://stackoverflow.com/a/21230438
def running_view(arr, window, axis=-1):
    """
    return a running view of length 'window' over 'axis'
    the returned array has an extra last dimension, which spans the window

    Parameters:
        arr : array-like (np.array, pd.Series, etc.)
            Data to slide the window over.
        window : int
            Number of consecutive elements in each window.
        axis : int
            Axis along which the window slides.

    Return value:
        Read-only np.array whose shape is arr's shape with 'axis' shortened by window-1
        and a trailing dimension of size 'window'.
    """
    # Work on a plain ndarray: callers may pass a pd.Series, which need not expose .strides.
    arr = np.asarray(arr)
    shape = list(arr.shape)
    shape[axis] -= (window - 1)
    assert shape[axis] > 0, "window is longer than the array along the chosen axis"
    # as_strided lives in np.lib.stride_tricks; reaching it through np.lib.index_tricks
    # relied on that module happening to import it.
    return np.lib.stride_tricks.as_strided(
        arr,
        shape + [window],
        # stepping through the new window dimension advances along 'axis'
        arr.strides + (arr.strides[axis],),
        writeable=False
    )


def add_dates_to_array(array, start_date):
    """
    Converts a raw array to a Series with DatetimeIndex.

    Parameters:
        array: array-like (np.array, pd.Series, etc.)
        start_date : date
            Date corresponding to the first value of the series.

    Return value:
        pd.Series containing the same values as the array with a daily DatetimeIndex starting at start_date.
    """
    s = pd.Series(array)
    # pd.date_range is the supported way to build a regular index from start/periods/freq.
    s.index = pd.date_range(start=start_date, periods=len(array), freq="D")
    return s
57 | """ 58 | def __init__(self, 59 | window=365, 60 | extract_features=lambda prev_values, current_date, metadata: prev_values, 61 | pretransform=lambda values, start_date: values, 62 | posttransform=lambda values, start_date: values): 63 | self.window = window 64 | self.extract_features = extract_features 65 | self.vextract_features = np.vectorize(extract_features) 66 | self.pretransform = pretransform 67 | self.posttransform = posttransform 68 | 69 | def make_training_data(self, value_series, metadata=None, start_date=None): 70 | """ 71 | Generates training data from a raw series of values. 72 | 73 | Parameters: 74 | value_series : np.array or pd.Series 75 | Values of some time series for forecasting. 76 | metadata : object 77 | Any data that you want to be passed to extract_features. 78 | start_date : date or None 79 | The date corresponding to the first value in value_series. 80 | If not set, value_series has to be a pd.Series with a DatetimeIndex with 1 day interval. 81 | 82 | Return value: (np.array, np.array) 83 | Tuple of arrays X and y which are training data in the format supported by libraries like sklearn. 84 | """ 85 | if start_date is None: 86 | start_date = value_series.index.min() 87 | value_series = self.pretransform(value_series, start_date) 88 | X_base = value_series[:-1] # remove last value 89 | X = [] 90 | for i, row in enumerate(running_view(X_base, self.window)): 91 | X.append(self.extract_features(row, start_date + datetime.timedelta(days=i), metadata)) 92 | X = np.array(X) 93 | y = value_series[self.window:] 94 | assert len(X) == len(y) 95 | return X, y 96 | 97 | def predict(self, 98 | predictor, 99 | previous_values, 100 | metadata=None, 101 | prev_start_date=None, 102 | fcast_len=90): 103 | """ 104 | Makes future predictions based on previous data. 105 | 106 | Parameters: 107 | predictor : ([X]) -> [y] 108 | A function that takes an array containing a single X vector and returning an array containing a single y scalar. 
109 | previous_values: np.array or pd.Series 110 | Values that are used to kickoff the forecasting. Forecasted values will start on the day following the last value from this series. 111 | metadata : object 112 | Any data that you want to be passed to extract_features. 113 | prev_start_date : date or None 114 | The date corresponding to the first value in previous_values. 115 | If not set, previous_values has to be a pd.Series with a DatetimeIndex with 1 day interval. 116 | fcast_len : int 117 | The amount of days that will be forecasted (length of output) 118 | 119 | Return value: 120 | np.array of length fcast_len - forecasted values 121 | """ 122 | if prev_start_date is None: 123 | prev_start_date = previous_values.index.min() 124 | 125 | fcast_start = prev_start_date + datetime.timedelta(days=len(previous_values)) 126 | 127 | # TODO if we cut the window without pretransform we may gain some speed, but be cautious about not breaking prev_start_date 128 | # pretransform 129 | previous_values = self.pretransform(previous_values, prev_start_date) 130 | 131 | # we only need one window 132 | previous_values = previous_values[-self.window:] 133 | 134 | fcast_date = fcast_start 135 | fcast = np.zeros(fcast_len) 136 | for i in range(fcast_len): 137 | X = self.extract_features(previous_values, fcast_date, metadata) 138 | y = predictor([X])[0] 139 | fcast[i] = y 140 | fcast_date += datetime.timedelta(days=1) 141 | 142 | previous_values = np.roll(previous_values, -1) 143 | previous_values[-1] = y 144 | 145 | return self.posttransform(fcast, fcast_start) 146 | 147 | 148 | --------------------------------------------------------------------------------