├── .gitignore ├── 01-forecasting.ipynb ├── 02-forecasting_with_deep_learning.ipynb ├── 03-forecasting_with_gbdt.ipynb ├── README.md ├── img ├── deepar.svg ├── deepvar.svg ├── lgb.svg ├── prophet.svg └── var.svg ├── requirements.txt └── utils ├── evaluation.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | *DS_Store 2 | *idea 3 | *ipynb_checkpoints 4 | *__pycache__ 5 | models/ 6 | *.pkl 7 | *.csv 8 | -------------------------------------------------------------------------------- /03-forecasting_with_gbdt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c1121594-bd32-4723-be44-7486fbbe57ab", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "5f19a86b-a69a-4dc0-9abe-ce3b6321dcd9", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import itertools\n", 22 | "import multiprocessing\n", 23 | "import os\n", 24 | "import warnings\n", 25 | "import hyperopt\n", 26 | "import tsfresh\n", 27 | "import numpy as np\n", 28 | "import pandas as pd\n", 29 | "import altair as alt\n", 30 | "import lightgbm as lgb\n", 31 | "from hyperopt import fmin, hp, space_eval, STATUS_OK, tpe, Trials\n", 32 | "from sklearn.compose import make_column_transformer\n", 33 | "from sklearn.impute import SimpleImputer\n", 34 | "from sklearn.metrics import mean_tweedie_deviance\n", 35 | "from sklearn.model_selection import TimeSeriesSplit\n", 36 | "from sklearn.pipeline import make_pipeline\n", 37 | "from sklearn.preprocessing import OrdinalEncoder\n", 38 | "from tsfresh import extract_features\n", 39 | "from tsfresh.feature_extraction.settings import (\n", 40 | " EfficientFCParameters, \n", 41 | " MinimalFCParameters,\n", 42 | ")\n", 43 | "from tsfresh.utilities.dataframe_functions import roll_time_series\n", 44 | "from utils.evaluation import calc_eval_metric, WRMSSEEvaluator\n", 45 | "from utils.misc import dump_pickle, load_pickle\n", 46 | "\n", 47 | "np.random.seed(42)\n", 48 | "warnings.filterwarnings(\"ignore\")\n", 49 | "n_jobs = multiprocessing.cpu_count() - 1" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "a4cd41e9-420f-400a-8adb-5cf529c79b1a", 55 | "metadata": {}, 56 | "source": [ 57 | "The Kaggle dataset was saved in the local directory `~/data/mofc-demand-forecast` in advance." 
58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "id": "7fdefcfe-99a4-4217-bcb2-0bf92845c210", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "DATA_PATH = \"../../data/mofc-demand-forecast\"\n", 68 | "MODEL_PATH = \"models\"\n", 69 | "TUNE_PARAMS = True\n", 70 | "\n", 71 | "calendar = pd.read_csv(os.path.join(DATA_PATH, \"calendar.csv\"))\n", 72 | "selling_prices = pd.read_csv(os.path.join(DATA_PATH, \"sell_prices.csv\"))\n", 73 | "# df_train_valid = pd.read_csv(os.path.join(DATA_PATH, \"sales_train_validation.csv\"))\n", 74 | "df_train_eval = pd.read_csv(os.path.join(DATA_PATH, \"sales_train_evaluation.csv\"))\n", 75 | "# sample_submission = pd.read_csv(os.path.join(DATA_PATH, \"sample_submission.csv\"))" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "id": "e1cfc0db-1be8-4515-be2f-79669c02b1ba", 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "305 out of 30490 IDs were selected for validation and testing.\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "key_ids = [\"id\", \"date\"]\n", 94 | "all_ids = df_train_eval[\"id\"].unique()\n", 95 | "date_names = [\"d_\" + str(i) for i in range(1, 1942)]\n", 96 | "calendar[\"date\"] = pd.to_datetime(calendar[\"date\"])\n", 97 | "dates = calendar[\"date\"].unique()\n", 98 | "test_steps = 28\n", 99 | "\n", 100 | "key_pairs = list(itertools.product(all_ids, dates))\n", 101 | "key_pairs = pd.DataFrame(key_pairs, columns=key_ids)\n", 102 | "\n", 103 | "sample_ratio = 0.01\n", 104 | "\n", 105 | "if sample_ratio == 1.0:\n", 106 | " sampled_ids = all_ids\n", 107 | "else:\n", 108 | " sampled_ids = np.random.choice(\n", 109 | " all_ids, round(sample_ratio * len(all_ids)), replace=False\n", 110 | " ).tolist()\n", 111 | " \n", 112 | "print(\n", 113 | " f\"{len(sampled_ids)} out of {len(all_ids)} IDs were selected for validation and testing.\"\n", 114 | ")" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "id": "405c6b9d-ebdb-4936-9f68-cb77bcb85e49", 120 | "metadata": {}, 121 | "source": [ 122 | "# Data Preprocessing" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 5, 128 | "id": "1d8abba8-ee9f-4d9c-a7fd-a23748ae069e", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "sales = df_train_eval[[\"id\"] + date_names]\n", 133 | "date_dict = calendar[[\"date\", \"d\"]].set_index(\"d\").to_dict()[\"date\"]\n", 134 | "sales.columns = pd.Series(sales.columns).replace(date_dict)\n", 135 | "sales = pd.melt(\n", 136 | " sales,\n", 137 | " id_vars=\"id\",\n", 138 | " value_vars=sales.columns[1:],\n", 139 | " var_name=\"date\",\n", 140 | " value_name=\"sales\",\n", 141 | ")" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 6, 147 | "id": "e9a8b409-a978-4adc-b188-493071f8cf22", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "def split_list(lst, n):\n", 152 | " q = len(lst) // n\n", 153 | " chunks = []\n", 154 | " \n", 155 | " for i in range(n):\n", 156 | " if i == n - 1:\n", 157 | " chunks.append(lst[q * i : len(lst)])\n", 158 | " else:\n", 159 | " chunks.append(lst[q * i : q * (i + 1)])\n", 160 | " \n", 161 | " return chunks" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "id": "5bacee4e-35f9-44e2-8aee-b32eef134669", 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stderr", 172 | "output_type": "stream", 173 | "text": [ 174 | "Rolling: 100%|██████████| 25/25 
[00:11<00:00, 2.17it/s]\n", 175 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.23s/it]\n", 176 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 2.08it/s]\n", 177 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.20s/it]\n", 178 | "Rolling: 100%|██████████| 25/25 [00:13<00:00, 1.86it/s]\n", 179 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.21s/it]\n", 180 | "Rolling: 100%|██████████| 25/25 [00:20<00:00, 1.23it/s]\n", 181 | "Feature Extraction: 100%|██████████| 25/25 [00:33<00:00, 1.32s/it]\n", 182 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.26it/s]\n", 183 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.20s/it]\n", 184 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 2.04it/s]\n", 185 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.22s/it]\n", 186 | "Rolling: 100%|██████████| 25/25 [00:13<00:00, 1.89it/s]\n", 187 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.20s/it]\n", 188 | "Rolling: 100%|██████████| 25/25 [00:18<00:00, 1.35it/s]\n", 189 | "Feature Extraction: 100%|██████████| 25/25 [00:27<00:00, 1.09s/it]\n", 190 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.23it/s]\n", 191 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.19s/it]\n", 192 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 2.05it/s]\n", 193 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.23s/it]\n", 194 | "Rolling: 100%|██████████| 25/25 [00:14<00:00, 1.77it/s]\n", 195 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.19s/it]\n", 196 | "Rolling: 100%|██████████| 25/25 [00:19<00:00, 1.30it/s]\n", 197 | "Feature Extraction: 100%|██████████| 25/25 [00:26<00:00, 1.05s/it]\n", 198 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.24it/s]\n", 199 | "Feature Extraction: 100%|██████████| 25/25 [00:31<00:00, 1.26s/it]\n", 200 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 2.06it/s]\n", 201 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.23s/it]\n", 202 | "Rolling: 100%|██████████| 25/25 [00:14<00:00, 1.74it/s]\n", 203 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.18s/it]\n", 204 | "Rolling: 100%|██████████| 25/25 [00:19<00:00, 1.29it/s]\n", 205 | "Feature Extraction: 100%|██████████| 25/25 [00:26<00:00, 1.05s/it]\n", 206 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.24it/s]\n", 207 | "Feature Extraction: 100%|██████████| 25/25 [00:31<00:00, 1.26s/it]\n", 208 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 1.98it/s]\n", 209 | "Feature Extraction: 100%|██████████| 25/25 [00:31<00:00, 1.24s/it]\n", 210 | "Rolling: 100%|██████████| 25/25 [00:14<00:00, 1.74it/s]\n", 211 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.19s/it]\n", 212 | "Rolling: 100%|██████████| 25/25 [00:18<00:00, 1.33it/s]\n", 213 | "Feature Extraction: 100%|██████████| 25/25 [00:27<00:00, 1.12s/it]\n", 214 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.25it/s]\n", 215 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.20s/it]\n", 216 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 1.96it/s]\n", 217 | "Feature Extraction: 100%|██████████| 25/25 [00:31<00:00, 1.24s/it]\n", 218 | "Rolling: 100%|██████████| 25/25 [00:14<00:00, 1.75it/s]\n", 219 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.20s/it]\n", 220 | "Rolling: 100%|██████████| 25/25 [00:20<00:00, 1.24it/s]\n", 221 | "Feature Extraction: 100%|██████████| 25/25 [00:26<00:00, 1.06s/it]\n", 222 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.25it/s]\n", 223 | "Feature Extraction: 100%|██████████| 
25/25 [00:30<00:00, 1.22s/it]\n", 224 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 2.06it/s]\n", 225 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.24s/it]\n", 226 | "Rolling: 100%|██████████| 25/25 [00:13<00:00, 1.84it/s]\n", 227 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.24s/it]\n", 228 | "Rolling: 100%|██████████| 25/25 [00:18<00:00, 1.33it/s]\n", 229 | "Feature Extraction: 100%|██████████| 25/25 [00:27<00:00, 1.10s/it]\n", 230 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.26it/s]\n", 231 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.20s/it]\n", 232 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 2.06it/s]\n", 233 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.23s/it]\n", 234 | "Rolling: 100%|██████████| 25/25 [00:13<00:00, 1.86it/s]\n", 235 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.21s/it]\n", 236 | "Rolling: 100%|██████████| 25/25 [00:18<00:00, 1.33it/s]\n", 237 | "Feature Extraction: 100%|██████████| 25/25 [00:27<00:00, 1.10s/it]\n", 238 | "Rolling: 100%|██████████| 25/25 [00:11<00:00, 2.25it/s]\n", 239 | "Feature Extraction: 100%|██████████| 25/25 [00:29<00:00, 1.18s/it]\n", 240 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 1.99it/s]\n", 241 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.23s/it]\n", 242 | "Rolling: 100%|██████████| 25/25 [00:13<00:00, 1.80it/s]\n", 243 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.21s/it]\n", 244 | "Rolling: 100%|██████████| 25/25 [00:19<00:00, 1.30it/s]\n", 245 | "Feature Extraction: 100%|██████████| 25/25 [00:34<00:00, 1.37s/it]\n", 246 | "Rolling: 100%|██████████| 25/25 [00:12<00:00, 1.95it/s]\n", 247 | "Feature Extraction: 100%|██████████| 25/25 [00:35<00:00, 1.43s/it]\n", 248 | "Rolling: 100%|██████████| 25/25 [00:14<00:00, 1.73it/s]\n", 249 | "Feature Extraction: 100%|██████████| 25/25 [00:35<00:00, 1.42s/it]\n", 250 | "Rolling: 100%|██████████| 25/25 [00:16<00:00, 1.54it/s]\n", 251 | "Feature Extraction: 100%|██████████| 25/25 [00:35<00:00, 1.41s/it]\n", 252 | "Rolling: 100%|██████████| 25/25 [00:24<00:00, 1.01it/s]\n", 253 | "Feature Extraction: 100%|██████████| 25/25 [00:30<00:00, 1.23s/it]\n" 254 | ] 255 | }, 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "CPU times: user 27min 43s, sys: 2min 44s, total: 30min 28s\n", 261 | "Wall time: 36min\n" 262 | ] 263 | } 264 | ], 265 | "source": [ 266 | "%%time\n", 267 | "split = date_dict[date_names[-test_steps]]\n", 268 | "sales_train = (\n", 269 | " sales[sales[\"date\"] < split].set_index(\"id\").loc[sampled_ids].reset_index()\n", 270 | ")\n", 271 | "\n", 272 | "frequencies = [7, 30, 90, 365]\n", 273 | "default_fc_parameters = MinimalFCParameters()\n", 274 | "# default_fc_parameters = EfficientFCParameters()\n", 275 | "\n", 276 | "chunks = split_list(sampled_ids, 10)\n", 277 | "\n", 278 | "for i, chunk in enumerate(chunks):\n", 279 | " for j, frequency in enumerate(frequencies):\n", 280 | " df_rolled = roll_time_series(\n", 281 | " sales_train.set_index(\"id\").loc[chunk].reset_index(),\n", 282 | " column_id=\"id\",\n", 283 | " column_sort=\"date\",\n", 284 | " max_timeshift=frequency,\n", 285 | " min_timeshift=frequency,\n", 286 | " n_jobs=n_jobs,\n", 287 | " disable_progressbar=False,\n", 288 | " )\n", 289 | "\n", 290 | " df_extracted = extract_features(\n", 291 | " df_rolled[[\"id\", \"date\", \"sales\"]],\n", 292 | " default_fc_parameters=default_fc_parameters,\n", 293 | " column_id=\"id\",\n", 294 | " column_sort=\"date\",\n", 
295 | " n_jobs=n_jobs,\n", 296 | " pivot=True,\n", 297 | " )\n", 298 | "\n", 299 | " df_extracted.columns = df_extracted.columns + f\"__D{frequency}\"\n", 300 | "\n", 301 | " if j == 0:\n", 302 | " df_part = df_extracted\n", 303 | " else:\n", 304 | " df_part = df_part.merge(df_extracted, left_index=True, right_index=True)\n", 305 | "\n", 306 | " if i == 0:\n", 307 | " feat_dynamic_real = df_part\n", 308 | " else:\n", 309 | " feat_dynamic_real = pd.concat([feat_dynamic_real, df_part])\n", 310 | "\n", 311 | "feat_dynamic_real = feat_dynamic_real.reset_index()\n", 312 | "feat_dynamic_real.columns = key_ids + feat_dynamic_real.columns[2:].tolist()\n", 313 | "feat_dynamic_real = feat_dynamic_real.merge(sales_train, on=key_ids).rename(\n", 314 | " {\"sales\": \"sales__D1\"}, axis=1\n", 315 | ")\n", 316 | "feat_dynamic_real[\"date\"] = feat_dynamic_real[\"date\"].map(\n", 317 | " lambda x: x + pd.Timedelta(\"1 days\")\n", 318 | ")" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 8, 324 | "id": "a71861f7-104e-43c6-b5d3-14747c9f64de", 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "prices = (\n", 329 | " df_train_eval[[\"id\", \"store_id\", \"item_id\"]]\n", 330 | " .merge(selling_prices, how=\"left\")\n", 331 | " .drop([\"store_id\", \"item_id\"], axis=1)\n", 332 | ")\n", 333 | "week_to_date = calendar[[\"date\", \"wm_yr_wk\"]].drop_duplicates()\n", 334 | "prices = week_to_date.merge(prices, how=\"left\").drop(\n", 335 | " [\"wm_yr_wk\"], axis=1\n", 336 | ")\n", 337 | "\n", 338 | "snap = calendar[[\"date\", \"snap_CA\", \"snap_TX\", \"snap_WI\"]]\n", 339 | "snap.columns = [\"date\", \"CA\", \"TX\", \"WI\"]\n", 340 | "snap = pd.melt(\n", 341 | " snap,\n", 342 | " id_vars=\"date\",\n", 343 | " value_vars=[\"CA\", \"TX\", \"WI\"],\n", 344 | " var_name=\"state_id\",\n", 345 | " value_name=\"snap\",\n", 346 | ")\n", 347 | "snap = key_pairs.merge(df_train_eval[[\"id\", \"state_id\"]], how=\"left\").merge(\n", 348 | " snap, on=[\"date\", \"state_id\"], how=\"left\"\n", 349 | ").drop([\"state_id\"], axis=1)\n", 350 | "\n", 351 | "feat_dynamic_real = feat_dynamic_real.merge(prices, on=key_ids).merge(\n", 352 | " snap, on=key_ids\n", 353 | ")\n", 354 | "\n", 355 | "num_feature_names = feat_dynamic_real.columns.difference(key_ids)" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 9, 361 | "id": "4a74b9be-0de4-4d24-88e8-aa476d1cf9cd", 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "feat_dynamic_cat = calendar[\n", 366 | " [\n", 367 | " \"date\",\n", 368 | " \"wday\",\n", 369 | " \"month\",\n", 370 | " \"event_name_1\",\n", 371 | " \"event_type_1\",\n", 372 | " \"event_name_2\",\n", 373 | " \"event_type_2\",\n", 374 | " ]\n", 375 | "]\n", 376 | "feat_dynamic_cat[\"day\"] = (\n", 377 | " feat_dynamic_cat[\"date\"].astype(\"str\").map(lambda x: int(x.split(\"-\")[2]))\n", 378 | ")\n", 379 | "\n", 380 | "feat_static_cat = df_train_eval[\n", 381 | " [\"id\", \"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", 382 | "]\n", 383 | "\n", 384 | "cat_feature_names = feat_dynamic_cat.columns\n", 385 | "cat_feature_names = cat_feature_names.union(feat_static_cat.columns).difference(key_ids)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 10, 391 | "id": "96a7f726-63eb-435a-8fa1-39f2172a299c", 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "df_train = feat_dynamic_real.merge(feat_dynamic_cat).merge(feat_static_cat)\n", 396 | "df_recur = 
df_train[df_train[\"date\"] == split].set_index(key_ids)\n", 397 | "df_train = df_train.merge(sales_train, on=key_ids)\n", 398 | "\n", 399 | "all_feature_names = num_feature_names.union(cat_feature_names)\n", 400 | "feature_names_to_remove = all_feature_names[\n", 401 | " (df_train[all_feature_names].isna().sum() / df_train.shape[0] == 1.0)\n", 402 | " | (df_train[all_feature_names].std() == 0.0)\n", 403 | "]\n", 404 | "num_feature_names = num_feature_names.difference(feature_names_to_remove)\n", 405 | "cat_feature_names = cat_feature_names.difference(feature_names_to_remove)\n", 406 | "all_feature_names = all_feature_names.difference(feature_names_to_remove)\n", 407 | "\n", 408 | "train_dates = df_train[\"date\"].unique()\n", 409 | "df_x_train = (\n", 410 | " df_train[key_ids + all_feature_names.difference([\"sales\"]).tolist()]\n", 411 | " .set_index(key_ids)\n", 412 | " .swaplevel()\n", 413 | ")\n", 414 | "df_y_train = df_train[key_ids + [\"sales\"]].set_index(key_ids).swaplevel()" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "id": "f41561f3-3122-4573-894d-1422e4049d9f", 420 | "metadata": {}, 421 | "source": [ 422 | "# Hyperparameter Tuning" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 11, 428 | "id": "39446f59-9bc5-4a4e-95e9-48ec77514e46", 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "def objective(args):\n", 433 | " global n_jobs\n", 434 | " global train_dates\n", 435 | " global df_x_train, df_y_train\n", 436 | " global all_feature_names, num_feature_names, cat_feature_names\n", 437 | "\n", 438 | " tscv = TimeSeriesSplit(n_splits=3)\n", 439 | " cat_pipeline = make_pipeline(\n", 440 | " SimpleImputer(strategy=\"constant\", fill_value=\"\"),\n", 441 | " OrdinalEncoder(handle_unknown=\"use_encoded_value\", unknown_value=-1),\n", 442 | " )\n", 443 | " num_pipeline = SimpleImputer(strategy=\"median\")\n", 444 | " processor = make_column_transformer(\n", 445 | " (cat_pipeline, cat_feature_names), (num_pipeline, num_feature_names)\n", 446 | " )\n", 447 | "\n", 448 | " default_params = {\n", 449 | " \"objective\": \"tweedie\",\n", 450 | " \"num_threads\": n_jobs,\n", 451 | " \"device_type\": \"cpu\",\n", 452 | " \"seed\": 42,\n", 453 | " \"force_col_wise\": True,\n", 454 | " \"max_cat_threshold\": 32,\n", 455 | " \"verbosity\": 1,\n", 456 | " \"max_bin_by_feature\": None,\n", 457 | " \"min_data_in_bin\": 3,\n", 458 | " \"feature_pre_filter\": True,\n", 459 | " \"tweedie_variance_power\": 1.5,\n", 460 | " \"metric\": \"tweedie\",\n", 461 | " }\n", 462 | " params = {\n", 463 | " \"boosting\": args[\"boosting\"],\n", 464 | " \"learning_rate\": args[\"learning_rate\"],\n", 465 | " \"num_iterations\": int(args[\"num_iterations\"]),\n", 466 | " \"num_leaves\": int(args[\"num_leaves\"]),\n", 467 | " \"max_depth\": int(args[\"max_depth\"]),\n", 468 | " \"min_data_in_leaf\": int(args[\"min_data_in_leaf\"]),\n", 469 | " \"min_sum_hessian_in_leaf\": args[\"min_sum_hessian_in_leaf\"],\n", 470 | " \"bagging_fraction\": args[\"bagging_fraction\"],\n", 471 | " \"bagging_freq\": int(args[\"bagging_freq\"]),\n", 472 | " \"feature_fraction\": args[\"feature_fraction\"],\n", 473 | " \"extra_trees\": args[\"extra_trees\"],\n", 474 | " \"lambda_l1\": args[\"lambda_l1\"],\n", 475 | " \"lambda_l2\": args[\"lambda_l2\"],\n", 476 | " \"path_smooth\": args[\"path_smooth\"],\n", 477 | " \"max_bin\": int(args[\"max_bin\"]),\n", 478 | " }\n", 479 | " default_params.update(params)\n", 480 | "\n", 481 | " losses = []\n", 482 | "\n", 483 | " for 
train_index, test_index in tscv.split(train_dates):\n", 484 | " dtrain = df_x_train.loc[train_dates[train_index], :]\n", 485 | " dvalid = df_x_train.loc[train_dates[test_index], :]\n", 486 | "\n", 487 | " dtrain = processor.fit_transform(dtrain)\n", 488 | " dvalid = processor.transform(dvalid)\n", 489 | " dtrain = lgb.Dataset(dtrain, label=df_y_train.loc[train_dates[train_index], :])\n", 490 | "\n", 491 | " model = lgb.train(\n", 492 | " default_params,\n", 493 | " dtrain,\n", 494 | " feature_name=all_feature_names.tolist(),\n", 495 | " categorical_feature=cat_feature_names.tolist(),\n", 496 | " verbose_eval=False,\n", 497 | " )\n", 498 | "\n", 499 | " y_true = df_y_train.loc[train_dates[test_index], :]\n", 500 | " y_pred = model.predict(dvalid)\n", 501 | " loss = mean_tweedie_deviance(\n", 502 | " y_true, y_pred, power=default_params[\"tweedie_variance_power\"]\n", 503 | " )\n", 504 | "\n", 505 | " losses.append(loss)\n", 506 | "\n", 507 | " return {\"loss\": np.mean(losses), \"status\": STATUS_OK}" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": 12, 513 | "id": "a2a057f3-ae21-4135-b914-76342a8a06a9", 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "CPU times: user 3h 58min 25s, sys: 1min 15s, total: 3h 59min 41s\n", 521 | "Wall time: 54min 4s\n" 522 | ] 523 | } 524 | ], 525 | "source": [ 526 | "%%time\n", 527 | "%%capture\n", 528 | "if TUNE_PARAMS:\n", 529 | " space = {\n", 530 | " \"boosting\": hp.pchoice(\"boosting\", [(0.75, \"gbdt\"), (0.25, \"dart\")]),\n", 531 | " \"learning_rate\": 10 ** hp.uniform(\"learning_rate\", -2, 0),\n", 532 | " \"num_iterations\": hp.quniform(\"num_iterations\", 1, 1000, 1),\n", 533 | " \"num_leaves\": 2 ** hp.uniform(\"num_leaves\", 1, 8),\n", 534 | " \"max_depth\": -1,\n", 535 | " \"min_data_in_leaf\": 2 * 10 ** hp.uniform(\"min_data_in_leaf\", 0, 2),\n", 536 | " \"min_sum_hessian_in_leaf\": hp.uniform(\"min_sum_hessian_in_leaf\", 1e-4, 1e-2),\n", 537 | " \"bagging_fraction\": hp.uniform(\"bagging_fraction\", 0.5, 1.0),\n", 538 | " \"bagging_freq\": hp.qlognormal(\"bagging_freq\", 0.0, 1.0, 1),\n", 539 | " \"feature_fraction\": hp.uniform(\"feature_fraction\", 0.5, 1.0),\n", 540 | " \"extra_trees\": hp.pchoice(\"extra_trees\", [(0.75, False), (0.25, True)]),\n", 541 | " \"lambda_l1\": hp.lognormal(\"lambda_l1\", 0.0, 1.0),\n", 542 | " \"lambda_l2\": hp.lognormal(\"lambda_l2\", 0.0, 1.0),\n", 543 | " \"path_smooth\": hp.lognormal(\"path_smooth\", 0.0, 1.0),\n", 544 | " \"max_bin\": 2 ** hp.quniform(\"max_bin\", 6, 10, 1) - 1,\n", 545 | " }\n", 546 | "\n", 547 | " trials = Trials()\n", 548 | "\n", 549 | " best = fmin(\n", 550 | " objective,\n", 551 | " space=space,\n", 552 | " algo=tpe.suggest,\n", 553 | " max_evals=100,\n", 554 | " trials=trials,\n", 555 | " )\n", 556 | "\n", 557 | " best_params = space_eval(space, best)\n", 558 | " best_params[\"num_iterations\"] = int(best_params[\"num_iterations\"])\n", 559 | " best_params[\"num_leaves\"] = int(best_params[\"num_leaves\"])\n", 560 | " best_params[\"min_data_in_leaf\"] = int(best_params[\"min_data_in_leaf\"])\n", 561 | " best_params[\"bagging_freq\"] = int(best_params[\"bagging_freq\"])\n", 562 | " best_params[\"max_bin\"] = int(best_params[\"max_bin\"])\n", 563 | "\n", 564 | " os.makedirs(os.path.join(MODEL_PATH, \"lgb\"), exist_ok=True)\n", 565 | " dump_pickle(os.path.join(MODEL_PATH, \"lgb\", \"best_params.pkl\"), best_params)\n", 566 | " \n", 567 | "else:\n", 568 | " best_params = 
load_pickle(os.path.join(MODEL_PATH, \"lgb\", \"best_params.pkl\"))" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "id": "d416824f-a790-4274-9d81-568217e9e571", 574 | "metadata": {}, 575 | "source": [ 576 | "# Model Validation" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 13, 582 | "id": "5f5d6bf2-4856-48b4-abcf-a90cbe60af33", 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [ 586 | "%%capture\n", 587 | "cat_pipeline = make_pipeline(\n", 588 | " SimpleImputer(strategy=\"constant\", fill_value=\"\"),\n", 589 | " OrdinalEncoder(handle_unknown=\"use_encoded_value\", unknown_value=-1),\n", 590 | ")\n", 591 | "num_pipeline = SimpleImputer(strategy=\"median\")\n", 592 | "processor = make_column_transformer(\n", 593 | " (cat_pipeline, cat_feature_names), (num_pipeline, num_feature_names)\n", 594 | ")\n", 595 | "\n", 596 | "default_params = {\n", 597 | " \"objective\": \"tweedie\",\n", 598 | " \"num_threads\": n_jobs,\n", 599 | " \"device_type\": \"cpu\",\n", 600 | " \"seed\": 42,\n", 601 | " \"force_col_wise\": True,\n", 602 | " \"max_cat_threshold\": 32,\n", 603 | " \"verbosity\": 1,\n", 604 | " \"max_bin_by_feature\": None,\n", 605 | " \"min_data_in_bin\": 3,\n", 606 | " \"feature_pre_filter\": True,\n", 607 | " \"tweedie_variance_power\": 1.5,\n", 608 | " \"metric\": \"tweedie\",\n", 609 | "}\n", 610 | "default_params.update(best_params)\n", 611 | "\n", 612 | "dtrain = processor.fit_transform(df_x_train)\n", 613 | "dtrain = lgb.Dataset(dtrain, label=df_y_train)\n", 614 | "\n", 615 | "model = lgb.train(\n", 616 | " default_params,\n", 617 | " dtrain,\n", 618 | " feature_name=all_feature_names.tolist(),\n", 619 | " categorical_feature=cat_feature_names.tolist(),\n", 620 | " verbose_eval=False,\n", 621 | ")" 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": 14, 627 | "id": "c56f8251-91b6-4c5d-96aa-4fad041d66e5", 628 | "metadata": {}, 629 | "outputs": [ 630 | { 631 | "name": "stderr", 632 | "output_type": "stream", 633 | "text": [ 634 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 113.09it/s]\n", 635 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 144.22it/s]\n", 636 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 143.92it/s]\n", 637 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 129.73it/s]\n", 638 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 167.08it/s]\n", 639 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 150.02it/s]\n", 640 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 159.86it/s]\n", 641 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 131.44it/s]\n", 642 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 148.97it/s]\n", 643 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 176.19it/s]\n", 644 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 191.48it/s]\n", 645 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 131.04it/s]\n", 646 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 160.26it/s]\n", 647 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 164.07it/s]\n", 648 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 138.85it/s]\n", 649 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 130.38it/s]\n", 650 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 218.60it/s]\n", 651 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 252.47it/s]\n", 652 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 
149.70it/s]\n", 653 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 131.73it/s]\n", 654 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 153.53it/s]\n", 655 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 162.86it/s]\n", 656 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 147.86it/s]\n", 657 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 130.62it/s]\n", 658 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 151.66it/s]\n", 659 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 161.57it/s]\n", 660 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 138.81it/s]\n", 661 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 133.75it/s]\n", 662 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 173.34it/s]\n", 663 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 176.47it/s]\n", 664 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 165.16it/s]\n", 665 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.67it/s]\n", 666 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 217.98it/s]\n", 667 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 309.43it/s]\n", 668 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 138.83it/s]\n", 669 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 131.41it/s]\n", 670 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 219.16it/s]\n", 671 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 178.87it/s]\n", 672 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 148.35it/s]\n", 673 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.15it/s]\n", 674 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 179.92it/s]\n", 675 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 158.38it/s]\n", 676 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 149.11it/s]\n", 677 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.25it/s]\n", 678 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 219.60it/s]\n", 679 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 160.56it/s]\n", 680 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 154.62it/s]\n", 681 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.01it/s]\n", 682 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 184.67it/s]\n", 683 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 259.31it/s]\n", 684 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 141.43it/s]\n", 685 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 134.49it/s]\n", 686 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 220.60it/s]\n", 687 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 160.54it/s]\n", 688 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 141.03it/s]\n", 689 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 131.23it/s]\n", 690 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 170.68it/s]\n", 691 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 161.28it/s]\n", 692 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 139.50it/s]\n", 693 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 133.22it/s]\n", 694 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 149.68it/s]\n", 695 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 154.90it/s]\n", 696 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 149.38it/s]\n", 697 | "Feature 
Extraction: 100%|██████████| 24/24 [00:00<00:00, 130.37it/s]\n", 698 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 435.51it/s]\n", 699 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 153.87it/s]\n", 700 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 154.48it/s]\n", 701 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 131.59it/s]\n", 702 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 158.73it/s]\n", 703 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 155.52it/s]\n", 704 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 207.69it/s]\n", 705 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 134.11it/s]\n", 706 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 150.12it/s]\n", 707 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 159.96it/s]\n", 708 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 147.16it/s]\n", 709 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.98it/s]\n", 710 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 182.96it/s]\n", 711 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 146.77it/s]\n", 712 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 168.26it/s]\n", 713 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 131.48it/s]\n", 714 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 164.31it/s]\n", 715 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 290.33it/s]\n", 716 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 144.81it/s]\n", 717 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.58it/s]\n", 718 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 184.55it/s]\n", 719 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 155.17it/s]\n", 720 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 192.15it/s]\n", 721 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 134.05it/s]\n", 722 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 160.59it/s]\n", 723 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 182.55it/s]\n", 724 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 139.20it/s]\n", 725 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.98it/s]\n", 726 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 166.66it/s]\n", 727 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 167.13it/s]\n", 728 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 138.51it/s]\n", 729 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 133.70it/s]\n", 730 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 206.44it/s]\n", 731 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 151.60it/s]\n", 732 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 137.91it/s]\n", 733 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.58it/s]\n", 734 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 152.85it/s]\n", 735 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 149.28it/s]\n", 736 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 136.67it/s]\n", 737 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 132.15it/s]\n", 738 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 171.13it/s]\n", 739 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 162.65it/s]\n", 740 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 165.37it/s]\n", 741 | "Feature Extraction: 100%|██████████| 24/24 
[00:00<00:00, 129.52it/s]\n", 742 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 187.68it/s]\n", 743 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 174.72it/s]\n", 744 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 138.90it/s]\n", 745 | "Feature Extraction: 100%|██████████| 24/24 [00:00<00:00, 128.89it/s]\n" 746 | ] 747 | }, 748 | { 749 | "name": "stdout", 750 | "output_type": "stream", 751 | "text": [ 752 | "CPU times: user 8min 7s, sys: 3min 44s, total: 11min 52s\n", 753 | "Wall time: 11min 46s\n" 754 | ] 755 | } 756 | ], 757 | "source": [ 758 | "%%time\n", 759 | "df_test = df_recur.copy()\n", 760 | "sales_pred = sales_train.copy()\n", 761 | "\n", 762 | "for i in range(test_steps):\n", 763 | " dtest = processor.transform(df_test[all_feature_names])\n", 764 | " y_pred = model.predict(dtest)\n", 765 | "\n", 766 | " sales_recur = pd.DataFrame(\n", 767 | " y_pred, columns=[\"sales\"], index=df_test.index\n", 768 | " ).reset_index()\n", 769 | " sales_pred = pd.concat([sales_pred, sales_recur])\n", 770 | " pred_dates = sales_pred[\"date\"].unique()\n", 771 | "\n", 772 | " for j, frequency in enumerate(frequencies):\n", 773 | " df_extracted = extract_features(\n", 774 | " sales_pred.set_index(\"date\").loc[pred_dates[-frequency:], :].reset_index(),\n", 775 | " default_fc_parameters=default_fc_parameters,\n", 776 | " column_id=\"id\",\n", 777 | " column_sort=\"date\",\n", 778 | " n_jobs=n_jobs,\n", 779 | " disable_progressbar=False,\n", 780 | " )\n", 781 | "\n", 782 | " df_extracted.columns = df_extracted.columns + f\"__D{frequency}\"\n", 783 | "\n", 784 | " if j == 0:\n", 785 | " feat_dynamic_real = df_extracted\n", 786 | " else:\n", 787 | " feat_dynamic_real = feat_dynamic_real.merge(\n", 788 | " df_extracted, left_index=True, right_index=True\n", 789 | " )\n", 790 | "\n", 791 | " feat_dynamic_real = feat_dynamic_real.reset_index()\n", 792 | " feat_dynamic_real.columns = [\"id\"] + feat_dynamic_real.columns[1:].tolist()\n", 793 | " feat_dynamic_real = feat_dynamic_real.merge(sales_recur[[\"id\", \"sales\"]]).rename(\n", 794 | " {\"sales\": \"sales__D1\"}, axis=1\n", 795 | " )\n", 796 | " feat_dynamic_real[\"date\"] = pred_dates[-1] + pd.Timedelta(\"1 days\")\n", 797 | "\n", 798 | " feat_dynamic_real = feat_dynamic_real.merge(prices, on=key_ids).merge(\n", 799 | " snap, on=key_ids\n", 800 | " )\n", 801 | "\n", 802 | " df_test = (\n", 803 | " feat_dynamic_real.merge(feat_dynamic_cat)\n", 804 | " .merge(feat_static_cat)\n", 805 | " .set_index(key_ids)\n", 806 | " )" 807 | ] 808 | }, 809 | { 810 | "cell_type": "code", 811 | "execution_count": 15, 812 | "id": "11176f77-bbd3-461d-94ee-99a69163c85b", 813 | "metadata": {}, 814 | "outputs": [ 815 | { 816 | "data": { 817 | "application/vnd.jupyter.widget-view+json": { 818 | "model_id": "855d6457b7fe4a64a26891b3156f5e0c", 819 | "version_major": 2, 820 | "version_minor": 0 821 | }, 822 | "text/plain": [ 823 | " 0%| | 0/12 [00:00\n", 858 | "\n", 871 | "\n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 
916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | "
maermsesmapemase
count305.000000305.000000305.000000303.000000
mean0.8922341.1978491.4428640.939395
std0.7976571.0683000.4582040.322833
min0.0865090.1863690.2189830.551782
25%0.4176780.5445451.0237540.741302
50%0.6655070.8502041.6051900.850567
75%1.0283931.4301111.8438771.032841
max6.8584618.9798672.0000002.981946
\n", 940 | "" 941 | ], 942 | "text/plain": [ 943 | " mae rmse smape mase\n", 944 | "count 305.000000 305.000000 305.000000 303.000000\n", 945 | "mean 0.892234 1.197849 1.442864 0.939395\n", 946 | "std 0.797657 1.068300 0.458204 0.322833\n", 947 | "min 0.086509 0.186369 0.218983 0.551782\n", 948 | "25% 0.417678 0.544545 1.023754 0.741302\n", 949 | "50% 0.665507 0.850204 1.605190 0.850567\n", 950 | "75% 1.028393 1.430111 1.843877 1.032841\n", 951 | "max 6.858461 8.979867 2.000000 2.981946" 952 | ] 953 | }, 954 | "metadata": {}, 955 | "output_type": "display_data" 956 | } 957 | ], 958 | "source": [ 959 | "predictions = sales_pred[sales_pred[\"date\"] > dates[-2 * test_steps - 1]]\n", 960 | "predictions = predictions.pivot(index=\"date\", columns=\"id\", values=\"sales\")\n", 961 | "\n", 962 | "df_pred_sampled = predictions.T\n", 963 | "df_pred_sampled = df_pred_sampled.loc[sampled_ids]\n", 964 | "df_pred_sampled.columns = df_test_sampled.columns\n", 965 | "df_pred_sampled.index = range(len(sampled_ids))\n", 966 | "\n", 967 | "wrmsse = evaluator.score(df_pred_sampled)\n", 968 | "eval_metrics = calc_eval_metric(df_test_sampled, df_pred_sampled)\n", 969 | "\n", 970 | "print(f\"LightGBM WRMSSE: {wrmsse:.6f}\")\n", 971 | "display(eval_metrics.describe())" 972 | ] 973 | }, 974 | { 975 | "cell_type": "code", 976 | "execution_count": 17, 977 | "id": "ab0d63c6-390d-473c-9b86-50fcc1e3ac43", 978 | "metadata": {}, 979 | "outputs": [], 980 | "source": [ 981 | "def plot_forecast(source, test_steps, plot_id=None, model_name=None, start_date=None):\n", 982 | " if start_date is not None:\n", 983 | " source = source[source[\"time\"] >= start_date]\n", 984 | "\n", 985 | " points = (\n", 986 | " alt.Chart(source)\n", 987 | " .mark_circle(size=10.0, color=\"#000000\")\n", 988 | " .encode(\n", 989 | " x=alt.X(\"time:T\", axis=alt.Axis(title=\"Date\")),\n", 990 | " y=alt.Y(\"y\", axis=alt.Axis(title=\"Demand\")),\n", 991 | " tooltip=[\"time:T\", \"y:Q\"],\n", 992 | " )\n", 993 | " )\n", 994 | "\n", 995 | " line = (\n", 996 | " alt.Chart(source)\n", 997 | " .mark_line(size=1.0, color=\"#4267B2\")\n", 998 | " .encode(\n", 999 | " x=\"time:T\",\n", 1000 | " y=\"fcst\",\n", 1001 | " )\n", 1002 | " )\n", 1003 | "\n", 1004 | " band = (\n", 1005 | " alt.Chart(source)\n", 1006 | " .mark_area(opacity=0.25, color=\"#4267B2\")\n", 1007 | " .encode(\n", 1008 | " x=\"time:T\",\n", 1009 | " y=\"fcst_lower\",\n", 1010 | " y2=\"fcst_upper\",\n", 1011 | " )\n", 1012 | " )\n", 1013 | "\n", 1014 | " rule = (\n", 1015 | " alt.Chart(source[[\"time\"]].iloc[-test_steps : -test_steps + 1])\n", 1016 | " .mark_rule(size=1.0, color=\"#FF0000\", strokeDash=[2, 2])\n", 1017 | " .encode(x=\"time:T\")\n", 1018 | " )\n", 1019 | "\n", 1020 | " title = \"Demand Forecast\"\n", 1021 | " if plot_id is not None:\n", 1022 | " title += f\" for '{plot_id}'\"\n", 1023 | " if model_name is not None:\n", 1024 | " title = f\"{model_name}: \" + title\n", 1025 | "\n", 1026 | " return (points + line + band + rule).properties(title=title, width=1000, height=300)" 1027 | ] 1028 | }, 1029 | { 1030 | "cell_type": "code", 1031 | "execution_count": 18, 1032 | "id": "bd3708db-0776-4a5f-b9ce-79e37d70efed", 1033 | "metadata": {}, 1034 | "outputs": [ 1035 | { 1036 | "data": { 1037 | "text/html": [ 1038 | "\n", 1039 | "
\n", 1040 | "" 1088 | ], 1089 | "text/plain": [ 1090 | "alt.VConcatChart(...)" 1091 | ] 1092 | }, 1093 | "execution_count": 18, 1094 | "metadata": {}, 1095 | "output_type": "execute_result" 1096 | } 1097 | ], 1098 | "source": [ 1099 | "best_perf_indices = eval_metrics[\"smape\"].dropna().sort_values()[:3].index\n", 1100 | "plots = []\n", 1101 | "\n", 1102 | "for index in best_perf_indices:\n", 1103 | " plot_id = pd.Series(sampled_ids).iloc[index]\n", 1104 | "\n", 1105 | " y = (df_train_eval[df_train_eval[\"id\"] == plot_id].loc[:, date_names]).T\n", 1106 | " y.columns = [\"y\"]\n", 1107 | " y = calendar.merge(y, left_on=\"d\", right_index=True)[[\"date\", \"y\"]]\n", 1108 | " y[\"time\"] = pd.to_datetime(y[\"date\"])\n", 1109 | "\n", 1110 | " source = y.merge(predictions[plot_id].reset_index(), how=\"left\").drop([\"date\"], axis=1)\n", 1111 | " source.columns = [\"y\", \"time\", \"fcst\"]\n", 1112 | " source[\"fcst_lower\"] = np.nan\n", 1113 | " source[\"fcst_upper\"] = np.nan\n", 1114 | "\n", 1115 | " p = plot_forecast(\n", 1116 | " source, test_steps, plot_id=plot_id, model_name=\"LightGBM\", start_date=\"2015-05-23\"\n", 1117 | " )\n", 1118 | " \n", 1119 | " plots.append(p)\n", 1120 | " \n", 1121 | "alt.VConcatChart(vconcat=plots)" 1122 | ] 1123 | } 1124 | ], 1125 | "metadata": { 1126 | "kernelspec": { 1127 | "display_name": "Python 3", 1128 | "language": "python", 1129 | "name": "python3" 1130 | }, 1131 | "language_info": { 1132 | "codemirror_mode": { 1133 | "name": "ipython", 1134 | "version": 3 1135 | }, 1136 | "file_extension": ".py", 1137 | "mimetype": "text/x-python", 1138 | "name": "python", 1139 | "nbconvert_exporter": "python", 1140 | "pygments_lexer": "ipython3", 1141 | "version": "3.7.10" 1142 | } 1143 | }, 1144 | "nbformat": 4, 1145 | "nbformat_minor": 5 1146 | } 1147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOFC Demand Forecasting with Time Series Analysis 2 | ### Goals 3 | * Compare the accuracy of various time series forecasting algorithms such as *Prophet*, *DeepAR*, *VAR*, *DeepVAR*, and *[LightGBM](https://papers.nips.cc/paper/2017/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf)* 4 | * (Optional) Use `tsfresh` for automated feature engineering of time series data. 5 | 6 | ### Requirements 7 | * The dataset can be downloaded from [this Kaggle competition](https://www.kaggle.com/c/m5-forecasting-accuracy). 8 | * In addition to the [Anaconda](https://www.anaconda.com) libraries, you need to install `altair`, `vega_datasets`, `category_encoders`, `mxnet`, `gluonts`, `kats`, `lightgbm`, `hyperopt` and `pandarallel`. 9 | * `kats` requires Python 3.7 or higher. 10 | 11 | ## Competition, Datasets and Evaluation 12 | * [The M5 Competition](https://mofc.unic.ac.cy/m5-competition) aims to forecast daily sales for the next 28 days based on sales over the last 1,941 days for IDs of 30,490 items per Walmart store. 13 | * Data includes (i) time series of daily sales quantity by ID, (ii) sales prices, and (iii) holiday and event information. 14 | * Evaluation is done through *Weighted Root Mean Squared Scaled Error*. A detailed explanation is given in the M5 Participants Guide and the implementation is at [this link](https://www.kaggle.com/c/m5-forecasting-accuracy/discussion/133834). 15 | * For hyperparameter tuning, 0.1% of IDs were randomly selected and used, and 1% were used to measure test set performance. 
16 | 17 | ## Algorithms 18 | ### Kats: Prophet 19 | * *Prophet* can incorporate forward-looking related time series into the model, so additional features were created from the holiday and event information. 20 | * Since a separate *Prophet* model has to be fitted for each ID, fitting was driven through the pandas `DataFrame.apply` function and parallelized with `pandarallel` to make full use of the available cores. 21 | * *Prophet* hyperparameters were tuned through 3-fold CV using the *Bayesian Optimization* module built into the `Kats` library, with *[Tweedie](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_tweedie_deviance.html)* deviance as the loss function. Below is the hyperparameter tuning result. 22 | 23 | |seasonality_prior_scale|changepoint_prior_scale|changepoint_range|n_changepoints|holidays_prior_scale|seasonality_mode| 24 | |:---:|:---:|:---:|:---:|:---:|:---:| 25 | |0.01|0.046|0.93|5|100.00|multiplicative| 26 | 27 | * The figures below show the actual sales (black dots), the point forecasts with their confidence intervals (blue lines and bands), and a red dotted line marking the start of the test period. 28 | 29 | ![Forecasting](./img/prophet.svg) 30 | 31 | ### Kats: VAR 32 | * Since *VAR* is a multivariate time series model, fitting more IDs simultaneously improves accuracy, but the memory requirement also grows very quickly with the number of series. 33 | 34 | ![Forecasting](./img/var.svg) 35 | 36 | ### GluonTS: DeepAR 37 | * *DeepAR* can incorporate metadata and forward-looking related time series into the model, so additional features were created from the sales prices and the holiday and event information. Dynamic categorical variables were encoded via [Feature Hashing](https://alex.smola.org/papers/2009/Weinbergeretal09.pdf). 38 | * The choice of output probability distribution is a critical hyperparameter; here the *Negative Binomial* distribution was used, which suits non-negative count data. 39 | 40 | ![Forecasting](./img/deepar.svg) 41 | 42 | ### GluonTS: DeepVAR 43 | * For *DeepVAR*, a multivariate model, the choice of output probability distribution is much more limited (essentially a *Multivariate Gaussian*), which leads to a drop in performance on this count-valued data. 44 | 45 | ![Forecasting](./img/deepvar.svg) 46 | 47 | ### LightGBM 48 | * `tsfresh` was used to convert the raw time series into tabular features, which consumes a lot of computational resources even with the minimal feature settings (a simplified sketch of this step follows the list below). 49 | * A *LightGBM* *Tweedie* regression model was then fitted. Hyperparameters were tuned via 3-fold CV using the *Bayesian Optimization* functionality of the `hyperopt` library. The following is the hyperparameter tuning result. 50 | 51 | |boosting|learning_rate|num_iterations|num_leaves|min_data_in_leaf|min_sum_hessian_in_leaf|bagging_fraction|bagging_freq|feature_fraction|extra_trees|lambda_l1|lambda_l2|path_smooth|max_bin| 52 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 53 | |gbdt|0.01773|522|11|33|0.0008|0.5297|4|0.5407|False|2.9114|0.2127|217.3879|1023| 54 | 55 | * The forecast for day D+1 was fed back through the feature-engineering step to predict day D+2, and this recursive process was repeated to obtain the full 28-day test-set forecast used for evaluation.
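For reference, the heart of the `tsfresh` feature-engineering step described above reduces to the sketch below. It is simplified from `03-forecasting_with_gbdt.ipynb`: the helper name `rolling_window_features` and the single-window interface are illustrative, and the notebook actually runs this for windows of 7, 30, 90 and 365 days over chunks of ids and joins the results before training.

```python
from tsfresh import extract_features
from tsfresh.feature_extraction.settings import MinimalFCParameters
from tsfresh.utilities.dataframe_functions import roll_time_series


def rolling_window_features(sales_long, window, n_jobs=1):
    """Summarize the trailing window of each series with tsfresh features.

    sales_long: long-format DataFrame with columns ["id", "date", "sales"],
    obtained by melting the wide M5 sales table.
    """
    # Build one sub-series per (id, end date) covering the trailing `window` days.
    rolled = roll_time_series(
        sales_long,
        column_id="id",
        column_sort="date",
        max_timeshift=window,
        min_timeshift=window,
        n_jobs=n_jobs,
    )
    # Collapse every sub-series into a small set of summary statistics.
    features = extract_features(
        rolled[["id", "date", "sales"]],
        default_fc_parameters=MinimalFCParameters(),
        column_id="id",
        column_sort="date",
        n_jobs=n_jobs,
    )
    # Tag the columns with the window length, e.g. "sales__mean__D7".
    features.columns = features.columns + f"__D{window}"
    return features
```

During the recursive test-set forecast the same extraction is re-run each day on the history extended with the latest prediction, which is why producing the 28-day forecast takes several minutes even for 1% of the ids.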
56 | 57 | ![Forecasting](./img/lgb.svg) 58 | 59 | ## Algorithm Performance Summary 60 | |Algorithm|WRMSSE|sMAPE|MAE|MASE|RMSE| 61 | |:---:|:---:|:---:|:---:|:---:|:---:| 62 | |DeepAR|0.7513|1.4200|0.8795|0.9269|1.1614| 63 | |LightGBM|1.0701|1.4429|0.8922|0.9394|1.1978| 64 | |Prophet|1.0820|1.4174|1.1014|1.0269|1.4410| 65 | |VAR|1.2876|2.3818|1.5545|1.6871|1.9502| 66 | |Naive Method|1.3430|1.5074|1.3730|1.1077|1.7440| 67 | |Mean Method|1.5984|1.4616|1.1997|1.0708|1.5352| 68 | |DeepVAR|4.6933|4.6847|1.9201|1.3683|2.3195| 69 | 70 | Based on these results, *DeepAR* was selected in the end; its predictions were submitted to Kaggle and achieved a WRMSSE of 0.8112 on the private leaderboard. 71 | 72 | ### References 73 | * [Taylor SJ, Letham B. 2017. Forecasting at scale. *PeerJ Preprints* 5:e3190v2](https://peerj.com/preprints/3190.pdf) 74 | * [Prophet: Forecasting at Scale](https://research.fb.com/blog/2017/02/prophet-forecasting-at-scale) 75 | * [Stock, James, H., Mark W. Watson. 2001. Vector Autoregressions. *Journal of Economic Perspectives*, 15 (4): 101-115.](https://www.princeton.edu/~mwatson/papers/Stock_Watson_JEP_2001.pdf) 76 | * [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski. 2020. DeepAR: Probabilistic forecasting with autoregressive recurrent networks, *International Journal of Forecasting*, 36 (3): 1181-1191.](https://arxiv.org/pdf/1704.04110.pdf) 77 | * [David Salinas, Michael Bohlke-Schneider, Laurent Callot, Roberto Medico, 78 | Jan Gasthaus. 2019. High-dimensional multivariate forecasting with low-rank Gaussian Copula Processes. *In Advances in Neural Information Processing Systems*. 6827–6837.](https://arxiv.org/pdf/1910.03002.pdf) 79 | * [Kats - One Stop Shop for Time Series Analysis in Python](https://facebookresearch.github.io/Kats/) 80 | * [GluonTS - Probabilistic Time Series Modeling](https://ts.gluon.ai/index.html) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | altair==4.1.0 2 | category_encoders==2.2.2 3 | gluonts==0.8.0 4 | hyperopt==0.2.5 5 | kats==0.1.0 6 | lightgbm==3.1.1 7 | mxnet==1.8.0.post0 8 | pandarallel==1.4.8 9 | tsfresh==0.17.0 10 | vega_datasets==0.9.0 -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from typing import Union 4 | from sklearn.metrics import mean_absolute_error, mean_squared_error 5 | from tqdm.notebook import tqdm 6 | 7 | 8 | def calc_eval_metric(y_true, y_pred): 9 | eval_metrics = dict() 10 | 11 | for index in y_true.index: 12 | eval_metrics_by_id = dict() 13 | 14 | eval_metrics_by_id["mae"] = mean_absolute_error( 15 | y_true.loc[index, :], y_pred.loc[index, :] 16 | ) 17 | eval_metrics_by_id["rmse"] = np.sqrt( 18 | mean_squared_error(y_true.loc[index, :], y_pred.loc[index, :]) 19 | ) 20 | eval_metrics_by_id["smape"] = mean_absolute_percentage_error( 21 | y_true.loc[index, :], y_pred.loc[index, :], is_symmetric=True 22 | ) 23 | eval_metrics_by_id["mase"] = mean_absolute_scaled_error( 24 | y_true.loc[index, :], y_pred.loc[index, :] 25 | ) 26 | 27 | eval_metrics[index] = eval_metrics_by_id 28 | 29 | return pd.DataFrame(eval_metrics).T 30 | 31 | 32 | def mean_absolute_percentage_error(y_true, y_pred, is_symmetric=False): 33 | if is_symmetric: 34 | return np.nanmean(2 * np.abs((y_true - y_pred) / (y_true + 
y_pred))) 35 | else: 36 | return np.nanmean(np.abs((y_true - y_pred) / y_true)) 37 | 38 | 39 | def mean_absolute_scaled_error(y_true, y_pred, seasonality=1): 40 | naive_forecast = y_true[:-seasonality] 41 | denominator = mean_absolute_error(y_true[seasonality:], naive_forecast) 42 | return ( 43 | mean_absolute_error(y_true, y_pred) / denominator 44 | if denominator > 0.0 45 | else np.nan 46 | ) 47 | 48 | 49 | class WRMSSEEvaluator(object): 50 | def __init__( 51 | self, 52 | df_train: pd.DataFrame, 53 | df_test: pd.DataFrame, 54 | calendar: pd.DataFrame, 55 | selling_prices: pd.DataFrame, 56 | test_steps: int, 57 | ): 58 | train = df_train.copy() 59 | test = df_test.copy() 60 | 61 | target = train.loc[:, train.columns.str.startswith("d_")] 62 | train_target_columns = target.columns.tolist() 63 | weight_columns = target.iloc[:, -test_steps:].columns.tolist() 64 | train["all_id"] = 0 65 | key_columns = train.loc[:, ~train.columns.str.startswith("d_")].columns.tolist() 66 | test_target_columns = test.loc[ 67 | :, test.columns.str.startswith("d_") 68 | ].columns.tolist() 69 | 70 | if not all([column in test.columns for column in key_columns]): 71 | test = pd.concat([train[key_columns], test], axis=1, sort=False) 72 | 73 | self.train = train 74 | self.test = test 75 | self.calendar = calendar 76 | self.selling_prices = selling_prices 77 | self.weight_columns = weight_columns 78 | self.key_columns = key_columns 79 | self.test_target_columns = test_target_columns 80 | 81 | sales_weights = self.get_sales_weight() 82 | 83 | self.group_ids = ( 84 | "all_id", 85 | "state_id", 86 | "store_id", 87 | "cat_id", 88 | "dept_id", 89 | ["state_id", "cat_id"], 90 | ["state_id", "dept_id"], 91 | ["store_id", "cat_id"], 92 | ["store_id", "dept_id"], 93 | "item_id", 94 | ["item_id", "state_id"], 95 | ["item_id", "store_id"], 96 | ) 97 | 98 | for i, group_id in enumerate(tqdm(self.group_ids)): 99 | train_total_quantities = train.groupby(group_id)[train_target_columns].sum() 100 | scale = [] 101 | for _, row in train_total_quantities.iterrows(): 102 | series = row.values[np.argmax(row.values != 0) :] 103 | scale.append(((series[1:] - series[:-1]) ** 2).mean()) 104 | setattr(self, f"level-{i + 1}_scale", np.array(scale)) 105 | setattr( 106 | self, f"level-{i + 1}_train_total_quantities", train_total_quantities 107 | ) 108 | setattr( 109 | self, 110 | f"level-{i + 1}_test_total_quantities", 111 | test.groupby(group_id)[test_target_columns].sum(), 112 | ) 113 | level_weight = ( 114 | sales_weights.groupby(group_id)[weight_columns].sum().sum(axis=1) 115 | ) 116 | setattr(self, f"level-{i + 1}_weight", level_weight / level_weight.sum()) 117 | 118 | def get_sales_weight(self) -> pd.DataFrame: 119 | day_to_week = self.calendar.set_index("d")["wm_yr_wk"].to_dict() 120 | sales_weights = self.train[ 121 | ["item_id", "store_id"] + self.weight_columns 122 | ].set_index(["item_id", "store_id"]) 123 | sales_weights = ( 124 | sales_weights.stack() 125 | .reset_index() 126 | .rename(columns={"level_2": "d", 0: "value"}) 127 | ) 128 | sales_weights["wm_yr_wk"] = sales_weights["d"].map(day_to_week) 129 | 130 | sales_weights = sales_weights.merge( 131 | self.selling_prices, how="left", on=["item_id", "store_id", "wm_yr_wk"] 132 | ) 133 | sales_weights["value"] = sales_weights["value"] * sales_weights["sell_price"] 134 | sales_weights = sales_weights.set_index(["item_id", "store_id", "d"]).unstack( 135 | level=2 136 | )["value"] 137 | sales_weights = sales_weights.loc[ 138 | zip(self.train["item_id"], self.train["store_id"]), : 139 | 
].reset_index(drop=True) 140 | sales_weights = pd.concat( 141 | [self.train[self.key_columns], sales_weights], axis=1, sort=False 142 | ) 143 | return sales_weights 144 | 145 | def rmsse(self, prediction: pd.DataFrame, level: int) -> pd.Series: 146 | test_total_quantities = getattr(self, f"level-{level}_test_total_quantities") 147 | score = ((test_total_quantities - prediction) ** 2).mean(axis=1) 148 | scale = getattr(self, f"level-{level}_scale") 149 | return (score / scale).map(np.sqrt) 150 | 151 | def score(self, predictions: Union[pd.DataFrame, np.ndarray]) -> float: 152 | assert self.test[self.test_target_columns].shape == predictions.shape 153 | 154 | if isinstance(predictions, np.ndarray): 155 | predictions = pd.DataFrame(predictions, columns=self.test_target_columns) 156 | 157 | predictions = pd.concat( 158 | [self.test[self.key_columns], predictions], axis=1, sort=False 159 | ) 160 | 161 | all_scores = [] 162 | for i, group_id in enumerate(self.group_ids): 163 | level_scores = self.rmsse( 164 | predictions.groupby(group_id)[self.test_target_columns].sum(), i + 1 165 | ) 166 | weight = getattr(self, f"level-{i + 1}_weight") 167 | level_scores = pd.concat([weight, level_scores], axis=1, sort=False).prod( 168 | axis=1 169 | ) 170 | all_scores.append(level_scores.sum()) 171 | 172 | return np.mean(all_scores) 173 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from numba import jit 4 | 5 | 6 | def dump_pickle(file_path, obj): 7 | with open(file_path, "wb") as f: 8 | pickle.dump(obj, f) 9 | 10 | 11 | def load_pickle(file_path): 12 | with open(file_path, "rb") as f: 13 | obj = pickle.load(f) 14 | return obj 15 | 16 | 17 | @jit(nopython=True) 18 | def remove_outlier(series, window_size, fill_na=True, n_sigmas=3): 19 | n = len(series) 20 | copied = series.copy() 21 | indices = [] 22 | k = 1.4826 23 | 24 | for i in range(window_size, (n - window_size)): 25 | median = np.nanmedian(series[(i - window_size) : (i + window_size)]) 26 | stat = k * np.nanmedian( 27 | np.abs(series[(i - window_size) : (i + window_size)] - median) 28 | ) 29 | if np.abs(series[i] - median) > n_sigmas * stat: 30 | copied[i] = median if fill_na else np.nan 31 | indices.append(i) 32 | 33 | return copied, indices 34 | --------------------------------------------------------------------------------