","text/html":"\n \n \n
\n [200/200 1:04:48, Epoch 0/1]\n
\n \n \n \n | Step | \n Training Loss | \n
\n \n \n \n | 10 | \n 2.871600 | \n
\n \n | 20 | \n 2.576800 | \n
\n \n | 30 | \n 2.196500 | \n
\n \n | 40 | \n 2.024300 | \n
\n \n | 50 | \n 1.784300 | \n
\n \n | 60 | \n 1.618500 | \n
\n \n | 70 | \n 1.524300 | \n
\n \n | 80 | \n 1.448200 | \n
\n \n | 90 | \n 1.450400 | \n
\n \n | 100 | \n 1.406500 | \n
\n \n | 110 | \n 1.392100 | \n
\n \n | 120 | \n 1.427900 | \n
\n \n | 130 | \n 1.417400 | \n
\n \n | 140 | \n 1.402700 | \n
\n \n | 150 | \n 1.402300 | \n
\n \n | 160 | \n 1.356500 | \n
\n \n | 170 | \n 1.375300 | \n
\n \n | 180 | \n 1.382400 | \n
\n \n | 190 | \n 1.426700 | \n
\n \n | 200 | \n 1.352500 | \n
\n \n
"},"metadata":{}},{"execution_count":10,"output_type":"execute_result","data":{"text/plain":"TrainOutput(global_step=200, training_loss=1.6418570566177368, metrics={'train_runtime': 3898.6987, 'train_samples_per_second': 0.821, 'train_steps_per_second': 0.051, 'total_flos': 9810597599281152.0, 'train_loss': 1.6418570566177368})"},"metadata":{}}],"execution_count":10},{"cell_type":"markdown","source":"## Inference","metadata":{}},{"cell_type":"code","source":"sample = dataset[\"test\"][0]\nprint(sample)\n\n\ndef format_chat_template(row):\n messages = [\n {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n {\"role\": \"user\", \"content\": row[\"instruction\"]}\n ] \n prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n return prompt\n\nprompt = format_chat_template(sample)\nprompt","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T22:27:03.428724Z","iopub.execute_input":"2025-05-17T22:27:03.428978Z","iopub.status.idle":"2025-05-17T22:27:03.436206Z","shell.execute_reply.started":"2025-05-17T22:27:03.428960Z","shell.execute_reply":"2025-05-17T22:27:03.435494Z"}},"outputs":[{"name":"stdout","text":"{'instruction': 'Provide the necessary materials for the given project.', 'input': 'Build a birdhouse', 'output': 'Materials Needed for Building a Birdhouse:\\n-Pieces of wood for the base, walls and roof of the birdhouse \\n-Saw \\n-Screws \\n-Screwdriver \\n-Nails \\n-Hammer \\n-Paint\\n-Paintbrushes \\n-Drill and bits \\n-Gravel (optional)', 'text': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nProvide the necessary materials for the given project.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nMaterials Needed for Building a Birdhouse:\\n-Pieces of wood for the base, walls and roof of the birdhouse \\n-Saw \\n-Screws \\n-Screwdriver \\n-Nails \\n-Hammer 
\\n-Paint\\n-Paintbrushes \\n-Drill and bits \\n-Gravel (optional)<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'}\n","output_type":"stream"},{"execution_count":11,"output_type":"execute_result","data":{"text/plain":"'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nProvide the necessary materials for the given project.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'"},"metadata":{}}],"execution_count":11},{"cell_type":"code","source":"# sft_model = AutoModelForCausalLM.from_pretrained(\"./llama3_sft/\")\n\nsft_model = AutoModelForCausalLM.from_pretrained(\"/kaggle/input/llama-3.2/transformers/3b-instruct/1\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T22:27:03.436908Z","iopub.execute_input":"2025-05-17T22:27:03.437240Z","iopub.status.idle":"2025-05-17T22:27:10.663130Z","shell.execute_reply.started":"2025-05-17T22:27:03.437221Z","shell.execute_reply":"2025-05-17T22:27:10.662463Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"668ac75fc0574f1f82a18ceb7e774b35"}},"metadata":{}}],"execution_count":12},{"cell_type":"code","source":"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n\noutput = model.generate(\n **inputs,\n max_new_tokens=100,\n do_sample=True,\n temperature=0.7,\n pad_token_id=tokenizer.eos_token_id,\n)\n\ngenerated_text = tokenizer.decode(output[0], skip_special_tokens=False)\nprint(generated_text[len(prompt):].split(\"<|eot_id|>\")[0]) 
","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-05-17T22:27:10.663898Z","iopub.execute_input":"2025-05-17T22:27:10.664164Z","iopub.status.idle":"2025-05-17T22:27:17.590203Z","shell.execute_reply.started":"2025-05-17T22:27:10.664145Z","shell.execute_reply":"2025-05-17T22:27:17.589393Z"}},"outputs":[{"name":"stderr","text":"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py:745: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n return fn(*args, **kwargs)\n/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n warnings.warn(\n","output_type":"stream"},{"name":"stdout","text":"The materials needed for this project are:\n- A piece of wood\n- A drill\n- A hammer\n- A saw\n- A sandpaper\n- Paint\n- Paintbrushes\n- A paint tray\n","output_type":"stream"}],"execution_count":13}]}
--------------------------------------------------------------------------------
/time_series/vehicle-sales-prediction-tensorflow-lstm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "bb0d377a",
6 | "metadata": {
7 | "papermill": {
8 | "duration": 0.002831,
9 | "end_time": "2025-05-16T02:49:48.027348",
10 | "exception": false,
11 | "start_time": "2025-05-16T02:49:48.024517",
12 | "status": "completed"
13 | },
14 | "tags": []
15 | },
16 | "source": [
17 | "vehicle sales data\n",
18 | "- data in [kaggle dataset](https://www.kaggle.com/datasets/brendayue/china-vehicle-sales-data)"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "id": "af6c017a",
25 | "metadata": {
26 | "execution": {
27 | "iopub.execute_input": "2025-05-16T02:49:48.032819Z",
28 | "iopub.status.busy": "2025-05-16T02:49:48.032563Z",
29 | "iopub.status.idle": "2025-05-16T02:49:52.474282Z",
30 | "shell.execute_reply": "2025-05-16T02:49:52.473247Z"
31 | },
32 | "papermill": {
33 | "duration": 4.446267,
34 | "end_time": "2025-05-16T02:49:52.475945",
35 | "exception": false,
36 | "start_time": "2025-05-16T02:49:48.029678",
37 | "status": "completed"
38 | },
39 | "tags": []
40 | },
41 | "outputs": [
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.4/74.4 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n",
47 | "\u001b[?25h"
48 | ]
49 | }
50 | ],
51 | "source": [
52 | "%pip install tfts --quiet"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 2,
58 | "id": "cd93f2b1",
59 | "metadata": {
60 | "execution": {
61 | "iopub.execute_input": "2025-05-16T02:49:52.482353Z",
62 | "iopub.status.busy": "2025-05-16T02:49:52.482091Z",
63 | "iopub.status.idle": "2025-05-16T02:50:09.490987Z",
64 | "shell.execute_reply": "2025-05-16T02:50:09.489995Z"
65 | },
66 | "papermill": {
67 | "duration": 17.014133,
68 | "end_time": "2025-05-16T02:50:09.492944",
69 | "exception": false,
70 | "start_time": "2025-05-16T02:49:52.478811",
71 | "status": "completed"
72 | },
73 | "tags": []
74 | },
75 | "outputs": [
76 | {
77 | "name": "stderr",
78 | "output_type": "stream",
79 | "text": [
80 | "2025-05-16 02:49:55.547627: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
81 | "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
82 | "E0000 00:00:1747363795.751419 19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
83 | "E0000 00:00:1747363795.806600 19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
84 | "I0000 00:00:1747363809.433713 19 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory: -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0\n"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "import logging\n",
90 | "from typing import List, Optional, Union\n",
91 | "import numpy as np\n",
92 | "import pandas as pd\n",
93 | "import tensorflow as tf\n",
94 | "from tfts import AutoModel, AutoConfig, KerasTrainer"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 3,
100 | "id": "b79c26aa",
101 | "metadata": {
102 | "execution": {
103 | "iopub.execute_input": "2025-05-16T02:50:09.502215Z",
104 | "iopub.status.busy": "2025-05-16T02:50:09.501552Z",
105 | "iopub.status.idle": "2025-05-16T02:50:09.506557Z",
106 | "shell.execute_reply": "2025-05-16T02:50:09.505778Z"
107 | },
108 | "papermill": {
109 | "duration": 0.011339,
110 | "end_time": "2025-05-16T02:50:09.508436",
111 | "exception": false,
112 | "start_time": "2025-05-16T02:50:09.497097",
113 | "status": "completed"
114 | },
115 | "tags": []
116 | },
117 | "outputs": [],
118 | "source": [
119 | "class CFG:  # notebook-wide settings read by the loading, dataset and model cells\n",
120 | " input_dir = \"/kaggle/input/china-vehicle-sales-data/china_vehicle_sales_data.csv\"  # a CSV file path (passed to pd.read_csv), despite the name\n",
121 | " train_sequence_length = 12  # time steps in each input window\n",
122 | " predict_sequence_length = 3  # time steps in each target window\n"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 4,
128 | "id": "2457152a",
129 | "metadata": {
130 | "execution": {
131 | "iopub.execute_input": "2025-05-16T02:50:09.518205Z",
132 | "iopub.status.busy": "2025-05-16T02:50:09.517910Z",
133 | "iopub.status.idle": "2025-05-16T02:50:09.628283Z",
134 | "shell.execute_reply": "2025-05-16T02:50:09.627424Z"
135 | },
136 | "papermill": {
137 | "duration": 0.116531,
138 | "end_time": "2025-05-16T02:50:09.629616",
139 | "exception": false,
140 | "start_time": "2025-05-16T02:50:09.513085",
141 | "status": "completed"
142 | },
143 | "tags": []
144 | },
145 | "outputs": [
146 | {
147 | "data": {
148 | "text/html": [
149 | "
\n",
150 | "\n",
163 | "
\n",
164 | " \n",
165 | " \n",
166 | " | \n",
167 | " Date | \n",
168 | " province | \n",
169 | " provinceId | \n",
170 | " popularity | \n",
171 | " model | \n",
172 | " bodyType | \n",
173 | " salesVolume | \n",
174 | "
\n",
175 | " \n",
176 | " \n",
177 | " \n",
178 | " | 0 | \n",
179 | " 201601 | \n",
180 | " Shanghai | \n",
181 | " 310000 | \n",
182 | " 1479 | \n",
183 | " 3c974920a76ac9c1 | \n",
184 | " SUV | \n",
185 | " 292 | \n",
186 | "
\n",
187 | " \n",
188 | " | 1 | \n",
189 | " 201601 | \n",
190 | " Yunnan | \n",
191 | " 530000 | \n",
192 | " 1594 | \n",
193 | " 3c974920a76ac9c1 | \n",
194 | " SUV | \n",
195 | " 466 | \n",
196 | "
\n",
197 | " \n",
198 | " | 2 | \n",
199 | " 201601 | \n",
200 | " Inner Mongolia | \n",
201 | " 150000 | \n",
202 | " 1479 | \n",
203 | " 3c974920a76ac9c1 | \n",
204 | " SUV | \n",
205 | " 257 | \n",
206 | "
\n",
207 | " \n",
208 | " | 3 | \n",
209 | " 201601 | \n",
210 | " Beijing | \n",
211 | " 110000 | \n",
212 | " 2370 | \n",
213 | " 3c974920a76ac9c1 | \n",
214 | " SUV | \n",
215 | " 408 | \n",
216 | "
\n",
217 | " \n",
218 | " | 4 | \n",
219 | " 201601 | \n",
220 | " Sichuan | \n",
221 | " 510000 | \n",
222 | " 3562 | \n",
223 | " 3c974920a76ac9c1 | \n",
224 | " SUV | \n",
225 | " 610 | \n",
226 | "
\n",
227 | " \n",
228 | "
\n",
229 | "
"
230 | ],
231 | "text/plain": [
232 | " Date province provinceId popularity model bodyType \\\n",
233 | "0 201601 Shanghai 310000 1479 3c974920a76ac9c1 SUV \n",
234 | "1 201601 Yunnan 530000 1594 3c974920a76ac9c1 SUV \n",
235 | "2 201601 Inner Mongolia 150000 1479 3c974920a76ac9c1 SUV \n",
236 | "3 201601 Beijing 110000 2370 3c974920a76ac9c1 SUV \n",
237 | "4 201601 Sichuan 510000 3562 3c974920a76ac9c1 SUV \n",
238 | "\n",
239 | " salesVolume \n",
240 | "0 292 \n",
241 | "1 466 \n",
242 | "2 257 \n",
243 | "3 408 \n",
244 | "4 610 "
245 | ]
246 | },
247 | "execution_count": 4,
248 | "metadata": {},
249 | "output_type": "execute_result"
250 | }
251 | ],
252 | "source": [
253 | "data = pd.read_csv(CFG.input_dir)\n",
254 | "\n",
255 | "data.head()"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": 5,
261 | "id": "6b674127",
262 | "metadata": {
263 | "execution": {
264 | "iopub.execute_input": "2025-05-16T02:50:09.635994Z",
265 | "iopub.status.busy": "2025-05-16T02:50:09.635773Z",
266 | "iopub.status.idle": "2025-05-16T02:50:09.642094Z",
267 | "shell.execute_reply": "2025-05-16T02:50:09.641305Z"
268 | },
269 | "papermill": {
270 | "duration": 0.01082,
271 | "end_time": "2025-05-16T02:50:09.643286",
272 | "exception": false,
273 | "start_time": "2025-05-16T02:50:09.632466",
274 | "status": "completed"
275 | },
276 | "tags": []
277 | },
278 | "outputs": [],
279 | "source": [
280 | "# https://github.com/hongyingyue/Vehicle-sales-predictor/blob/main/vehicle_ml/feature/ts_feature.py\n",
281 | "\n",
282 | "logger = logging.getLogger(__name__)\n",
283 | "\n",
284 | "def add_lagging_feature(\n",
285 | "    data: pd.DataFrame,\n",
286 | "    groupby_column: Union[str, List[str]],\n",
287 | "    value_columns: List[str],\n",
288 | "    lags: List[int],\n",
289 | "    feature_columns: Optional[List[str]] = None,\n",
290 | "):\n",
291 | "    \"\"\"Add `{column}_lag{lag}` columns to `data` in place (per-group shift) and return `data`.\n",
292 | "\n",
293 | "    Rows must already be in time order within each group; nothing is sorted here.\n",
294 | "    Created names are appended to `feature_columns` when a list is passed (existing names are not repeated).\n",
295 | "    Raises TypeError/ValueError for badly typed arguments or an unknown value column, before any column is added.\n",
296 | "    \"\"\"\n",
297 | "    if not isinstance(groupby_column, (str, list)):\n",
298 | "        raise TypeError(f\"'groupby_column' must be a string or a list of strings, but got {type(groupby_column)}.\")\n",
299 | "    if not isinstance(value_columns, (list, tuple)):\n",
300 | "        raise TypeError(f\"'value_columns' must be a list of strings, but got {type(value_columns)}.\")\n",
301 | "    for column in value_columns:  # validate everything first so a bad name cannot leave `data` half-updated\n",
302 | "        if column not in data.columns:\n",
303 | "            raise ValueError(f\"Value column '{column}' not found in DataFrame.\")\n",
304 | "    feature_columns = feature_columns if feature_columns is not None else []\n",
305 | "    for column in value_columns:\n",
306 | "        for lag in lags:\n",
307 | "            feature_col_name = f\"{column}_lag{lag}\"\n",
308 | "            if feature_col_name not in feature_columns:  # stay duplicate-free when the cell is re-run\n",
309 | "                feature_columns.append(feature_col_name)\n",
310 | "            logger.debug(\"Creating lagging feature: %s for column '%s' with lag %s and groupby '%s'.\", feature_col_name, column, lag, groupby_column)\n",
311 | "            data[feature_col_name] = data.groupby(groupby_column)[column].shift(lag)\n",
312 | "    return data"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 6,
318 | "id": "6c26a66d",
319 | "metadata": {
320 | "execution": {
321 | "iopub.execute_input": "2025-05-16T02:50:09.649177Z",
322 | "iopub.status.busy": "2025-05-16T02:50:09.648793Z",
323 | "iopub.status.idle": "2025-05-16T02:50:09.734527Z",
324 | "shell.execute_reply": "2025-05-16T02:50:09.733463Z"
325 | },
326 | "papermill": {
327 | "duration": 0.090844,
328 | "end_time": "2025-05-16T02:50:09.736508",
329 | "exception": false,
330 | "start_time": "2025-05-16T02:50:09.645664",
331 | "status": "completed"
332 | },
333 | "tags": []
334 | },
335 | "outputs": [
336 | {
337 | "name": "stderr",
338 | "output_type": "stream",
339 | "text": [
340 | "/usr/local/lib/python3.11/dist-packages/pandas/io/formats/format.py:1458: RuntimeWarning: invalid value encountered in greater\n",
341 | " has_large_values = (abs_vals > 1e6).any()\n",
342 | "/usr/local/lib/python3.11/dist-packages/pandas/io/formats/format.py:1459: RuntimeWarning: invalid value encountered in less\n",
343 | " has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()\n",
344 | "/usr/local/lib/python3.11/dist-packages/pandas/io/formats/format.py:1459: RuntimeWarning: invalid value encountered in greater\n",
345 | " has_small_values = ((abs_vals < 10 ** (-self.digits)) & (abs_vals > 0)).any()\n"
346 | ]
347 | },
348 | {
349 | "data": {
350 | "text/html": [
351 | "\n",
352 | "\n",
365 | "
\n",
366 | " \n",
367 | " \n",
368 | " | \n",
369 | " Date | \n",
370 | " province | \n",
371 | " provinceId | \n",
372 | " popularity | \n",
373 | " model | \n",
374 | " bodyType | \n",
375 | " salesVolume | \n",
376 | " salesVolume_lag1 | \n",
377 | " salesVolume_lag2 | \n",
378 | " salesVolume_lag3 | \n",
379 | " salesVolume_lag4 | \n",
380 | " salesVolume_lag5 | \n",
381 | " salesVolume_lag6 | \n",
382 | " salesVolume_lag7 | \n",
383 | " salesVolume_lag8 | \n",
384 | " salesVolume_lag9 | \n",
385 | " salesVolume_lag10 | \n",
386 | " salesVolume_lag11 | \n",
387 | "
\n",
388 | " \n",
389 | " \n",
390 | " \n",
391 | " | 0 | \n",
392 | " 201601 | \n",
393 | " Shanghai | \n",
394 | " 310000 | \n",
395 | " 1479 | \n",
396 | " 3c974920a76ac9c1 | \n",
397 | " SUV | \n",
398 | " 292 | \n",
399 | " NaN | \n",
400 | " NaN | \n",
401 | " NaN | \n",
402 | " NaN | \n",
403 | " NaN | \n",
404 | " NaN | \n",
405 | " NaN | \n",
406 | " NaN | \n",
407 | " NaN | \n",
408 | " NaN | \n",
409 | " NaN | \n",
410 | "
\n",
411 | " \n",
412 | " | 1 | \n",
413 | " 201601 | \n",
414 | " Yunnan | \n",
415 | " 530000 | \n",
416 | " 1594 | \n",
417 | " 3c974920a76ac9c1 | \n",
418 | " SUV | \n",
419 | " 466 | \n",
420 | " NaN | \n",
421 | " NaN | \n",
422 | " NaN | \n",
423 | " NaN | \n",
424 | " NaN | \n",
425 | " NaN | \n",
426 | " NaN | \n",
427 | " NaN | \n",
428 | " NaN | \n",
429 | " NaN | \n",
430 | " NaN | \n",
431 | "
\n",
432 | " \n",
433 | " | 2 | \n",
434 | " 201601 | \n",
435 | " Inner Mongolia | \n",
436 | " 150000 | \n",
437 | " 1479 | \n",
438 | " 3c974920a76ac9c1 | \n",
439 | " SUV | \n",
440 | " 257 | \n",
441 | " NaN | \n",
442 | " NaN | \n",
443 | " NaN | \n",
444 | " NaN | \n",
445 | " NaN | \n",
446 | " NaN | \n",
447 | " NaN | \n",
448 | " NaN | \n",
449 | " NaN | \n",
450 | " NaN | \n",
451 | " NaN | \n",
452 | "
\n",
453 | " \n",
454 | " | 3 | \n",
455 | " 201601 | \n",
456 | " Beijing | \n",
457 | " 110000 | \n",
458 | " 2370 | \n",
459 | " 3c974920a76ac9c1 | \n",
460 | " SUV | \n",
461 | " 408 | \n",
462 | " NaN | \n",
463 | " NaN | \n",
464 | " NaN | \n",
465 | " NaN | \n",
466 | " NaN | \n",
467 | " NaN | \n",
468 | " NaN | \n",
469 | " NaN | \n",
470 | " NaN | \n",
471 | " NaN | \n",
472 | " NaN | \n",
473 | "
\n",
474 | " \n",
475 | " | 4 | \n",
476 | " 201601 | \n",
477 | " Sichuan | \n",
478 | " 510000 | \n",
479 | " 3562 | \n",
480 | " 3c974920a76ac9c1 | \n",
481 | " SUV | \n",
482 | " 610 | \n",
483 | " NaN | \n",
484 | " NaN | \n",
485 | " NaN | \n",
486 | " NaN | \n",
487 | " NaN | \n",
488 | " NaN | \n",
489 | " NaN | \n",
490 | " NaN | \n",
491 | " NaN | \n",
492 | " NaN | \n",
493 | " NaN | \n",
494 | "
\n",
495 | " \n",
496 | "
\n",
497 | "
"
498 | ],
499 | "text/plain": [
500 | " Date province provinceId popularity model bodyType \\\n",
501 | "0 201601 Shanghai 310000 1479 3c974920a76ac9c1 SUV \n",
502 | "1 201601 Yunnan 530000 1594 3c974920a76ac9c1 SUV \n",
503 | "2 201601 Inner Mongolia 150000 1479 3c974920a76ac9c1 SUV \n",
504 | "3 201601 Beijing 110000 2370 3c974920a76ac9c1 SUV \n",
505 | "4 201601 Sichuan 510000 3562 3c974920a76ac9c1 SUV \n",
506 | "\n",
507 | " salesVolume salesVolume_lag1 salesVolume_lag2 salesVolume_lag3 \\\n",
508 | "0 292 NaN NaN NaN \n",
509 | "1 466 NaN NaN NaN \n",
510 | "2 257 NaN NaN NaN \n",
511 | "3 408 NaN NaN NaN \n",
512 | "4 610 NaN NaN NaN \n",
513 | "\n",
514 | " salesVolume_lag4 salesVolume_lag5 salesVolume_lag6 salesVolume_lag7 \\\n",
515 | "0 NaN NaN NaN NaN \n",
516 | "1 NaN NaN NaN NaN \n",
517 | "2 NaN NaN NaN NaN \n",
518 | "3 NaN NaN NaN NaN \n",
519 | "4 NaN NaN NaN NaN \n",
520 | "\n",
521 | " salesVolume_lag8 salesVolume_lag9 salesVolume_lag10 salesVolume_lag11 \n",
522 | "0 NaN NaN NaN NaN \n",
523 | "1 NaN NaN NaN NaN \n",
524 | "2 NaN NaN NaN NaN \n",
525 | "3 NaN NaN NaN NaN \n",
526 | "4 NaN NaN NaN NaN "
527 | ]
528 | },
529 | "execution_count": 6,
530 | "metadata": {},
531 | "output_type": "execute_result"
532 | }
533 | ],
534 | "source": [
535 | "feature_columns = []\n",
536 | "\n",
537 | "data = add_lagging_feature(data, groupby_column=[\"provinceId\", \"model\"], value_columns=[\"salesVolume\"], lags=list(range(1, 12)), feature_columns=feature_columns)\n",
538 | "\n",
539 | "data.head()"
540 | ]
541 | },
542 | {
543 | "cell_type": "code",
544 | "execution_count": 7,
545 | "id": "457d453e",
546 | "metadata": {
547 | "execution": {
548 | "iopub.execute_input": "2025-05-16T02:50:09.745696Z",
549 | "iopub.status.busy": "2025-05-16T02:50:09.745164Z",
550 | "iopub.status.idle": "2025-05-16T02:50:11.128034Z",
551 | "shell.execute_reply": "2025-05-16T02:50:11.127244Z"
552 | },
553 | "papermill": {
554 | "duration": 1.3884,
555 | "end_time": "2025-05-16T02:50:11.129283",
556 | "exception": false,
557 | "start_time": "2025-05-16T02:50:09.740883",
558 | "status": "completed"
559 | },
560 | "tags": []
561 | },
562 | "outputs": [
563 | {
564 | "name": "stderr",
565 | "output_type": "stream",
566 | "text": [
567 | "/tmp/ipykernel_19/3102104721.py:1: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
568 | " grouped_sequence = data.groupby([\"provinceId\", \"model\"]).apply(\n"
569 | ]
570 | },
571 | {
572 | "data": {
573 | "text/plain": [
574 | "array([[[ 799., nan, nan, nan],\n",
575 | " [ 424., 799., nan, nan],\n",
576 | " [ 733., 424., 799., nan],\n",
577 | " ...,\n",
578 | " [ 544., 659., 630., 670.],\n",
579 | " [ 647., 544., 659., 630.],\n",
580 | " [ 640., 647., 544., 659.]],\n",
581 | "\n",
582 | " [[ 135., nan, nan, nan],\n",
583 | " [ 57., 135., nan, nan],\n",
584 | " [ 160., 57., 135., nan],\n",
585 | " ...,\n",
586 | " [ 105., 201., 120., 135.],\n",
587 | " [ 148., 105., 201., 120.],\n",
588 | " [ 112., 148., 105., 201.]],\n",
589 | "\n",
590 | " [[ 872., nan, nan, nan],\n",
591 | " [ 197., 872., nan, nan],\n",
592 | " [ 494., 197., 872., nan],\n",
593 | " ...,\n",
594 | " [ 152., 170., 181., 159.],\n",
595 | " [ 213., 152., 170., 181.],\n",
596 | " [ 226., 213., 152., 170.]],\n",
597 | "\n",
598 | " ...,\n",
599 | "\n",
600 | " [[ 181., nan, nan, nan],\n",
601 | " [ 60., 181., nan, nan],\n",
602 | " [ 111., 60., 181., nan],\n",
603 | " ...,\n",
604 | " [ 330., 297., 252., 199.],\n",
605 | " [ 178., 330., 297., 252.],\n",
606 | " [ 185., 178., 330., 297.]],\n",
607 | "\n",
608 | " [[1023., nan, nan, nan],\n",
609 | " [ 517., 1023., nan, nan],\n",
610 | " [ 513., 517., 1023., nan],\n",
611 | " ...,\n",
612 | " [1110., 991., 975., 798.],\n",
613 | " [ 967., 1110., 991., 975.],\n",
614 | " [1581., 967., 1110., 991.]],\n",
615 | "\n",
616 | " [[ 170., nan, nan, nan],\n",
617 | " [ 37., 170., nan, nan],\n",
618 | " [ 124., 37., 170., nan],\n",
619 | " ...,\n",
620 | " [ 229., 236., 208., 749.],\n",
621 | " [ 240., 229., 236., 208.],\n",
622 | " [ 337., 240., 229., 236.]]])"
623 | ]
624 | },
625 | "execution_count": 7,
626 | "metadata": {},
627 | "output_type": "execute_result"
628 | }
629 | ],
630 | "source": [
631 | "# Select the needed columns after groupby so apply() never sees the grouping keys (avoids the pandas DeprecationWarning).\n",
632 | "sequence_columns = [\"salesVolume\", \"salesVolume_lag1\", \"salesVolume_lag2\", \"salesVolume_lag3\"]\n",
633 | "grouped_sequence = data.groupby([\"provinceId\", \"model\"])[[\"Date\"] + sequence_columns].apply(\n",
634 | "    lambda group: group.sort_values(\"Date\")[sequence_columns].to_numpy()\n",
635 | ")\n",
636 | "data_3d = np.stack(grouped_sequence.values)  # (series, time, column); np.stack needs every series to have the same length\n",
637 | "data_3d"
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": 8,
643 | "id": "b141753f",
644 | "metadata": {
645 | "execution": {
646 | "iopub.execute_input": "2025-05-16T02:50:11.136255Z",
647 | "iopub.status.busy": "2025-05-16T02:50:11.135797Z",
648 | "iopub.status.idle": "2025-05-16T02:50:11.144981Z",
649 | "shell.execute_reply": "2025-05-16T02:50:11.144469Z"
650 | },
651 | "papermill": {
652 | "duration": 0.013747,
653 | "end_time": "2025-05-16T02:50:11.146090",
654 | "exception": false,
655 | "start_time": "2025-05-16T02:50:11.132343",
656 | "status": "completed"
657 | },
658 | "tags": []
659 | },
660 | "outputs": [],
661 | "source": [
662 | "from tensorflow.keras.utils import Sequence\n",
663 | "\n",
664 | "\n",
665 | "class TimeDataset(Sequence):\n",
666 | "    \"\"\"Keras `Sequence` that serves sliding (history, target) windows cut from a 3-D array.\n",
667 | "\n",
668 | "    `data` is indexed as [series, time step, column]; column 0 supplies the targets and\n",
669 | "    columns 1: supply the input features. Every series yields\n",
670 | "    `time steps - train_sequence_length - predict_sequence_length + 1` windows, served in a\n",
671 | "    fixed (series, start) order. NaNs are replaced by 0 in each returned batch.\n",
672 | "    \"\"\"\n",
673 | "\n",
674 | "    def __init__(self, data, train_sequence_length, predict_sequence_length, batch_size: int = 64, **kwargs):\n",
675 | "        super().__init__(**kwargs)  # Keras 3 warns when a Sequence/PyDataset subclass skips the base constructor\n",
676 | "        self.data = data\n",
677 | "        self.train_seq_len = train_sequence_length\n",
678 | "        self.pred_seq_len = predict_sequence_length\n",
679 | "        self.batch_size = batch_size\n",
680 | "        self.num_ids, self.max_seq_len, self.feature_dim = data.shape[:3]\n",
681 | "\n",
682 | "        # Windows per series, clamped at 0 so a too-short series gives an empty dataset rather than a negative count.\n",
683 | "        self.samples_per_id = max(0, self.max_seq_len - self.train_seq_len - self.pred_seq_len + 1)\n",
684 | "        self.total_samples = self.num_ids * self.samples_per_id\n",
685 | "        # Precompute all valid (id, start_idx) pairs\n",
686 | "        self.indices = [(i, j) for i in range(self.num_ids) for j in range(self.samples_per_id)]\n",
687 | "\n",
688 | "    def __getitem__(self, index):\n",
689 | "        \"\"\"Return batch `index` as `(x, y)` NumPy arrays: feature windows and their target windows.\"\"\"\n",
690 | "        # Out-of-range access must raise so that plain `for batch in dataset` iteration terminates.\n",
691 | "        if not 0 <= index < len(self):\n",
692 | "            raise IndexError(f\"batch index {index} out of range for {len(self)} batches\")\n",
693 | "        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]\n",
694 | "        x_batch, y_batch = [], []\n",
695 | "        for id_idx, start_idx in batch_indices:\n",
696 | "            split = start_idx + self.train_seq_len  # first target step of this window\n",
697 | "            x_batch.append(self.data[id_idx, start_idx:split, 1:])\n",
698 | "            y_batch.append(self.data[id_idx, split:split + self.pred_seq_len, 0])\n",
699 | "        return np.nan_to_num(np.array(x_batch)), np.nan_to_num(np.array(y_batch))\n",
700 | "\n",
701 | "    def __len__(self):\n",
702 | "        \"\"\"Number of batches; the final batch may hold fewer than `batch_size` windows.\"\"\"\n",
703 | "        return int(np.ceil(len(self.indices) / self.batch_size))"
704 | ]
705 | },
706 | {
707 | "cell_type": "code",
708 | "execution_count": 9,
709 | "id": "ac417c3d",
710 | "metadata": {
711 | "execution": {
712 | "iopub.execute_input": "2025-05-16T02:50:11.153097Z",
713 | "iopub.status.busy": "2025-05-16T02:50:11.152458Z",
714 | "iopub.status.idle": "2025-05-16T02:50:11.162165Z",
715 | "shell.execute_reply": "2025-05-16T02:50:11.161518Z"
716 | },
717 | "papermill": {
718 | "duration": 0.014177,
719 | "end_time": "2025-05-16T02:50:11.163314",
720 | "exception": false,
721 | "start_time": "2025-05-16T02:50:11.149137",
722 | "status": "completed"
723 | },
724 | "tags": []
725 | },
726 | "outputs": [
727 | {
728 | "name": "stdout",
729 | "output_type": "stream",
730 | "text": [
731 | "(64, 12, 3)\n",
732 | "(64, 3)\n"
733 | ]
734 | }
735 | ],
736 | "source": [
737 | "# Hold out the final forecast horizon for validation so it is not scored on windows the model was trained on.\n",
738 | "train_dataset = TimeDataset(data_3d[:, :-CFG.predict_sequence_length], CFG.train_sequence_length, CFG.predict_sequence_length)\n",
739 | "valid_dataset = TimeDataset(data_3d[:, -(CFG.train_sequence_length + CFG.predict_sequence_length):], CFG.train_sequence_length, CFG.predict_sequence_length)\n",
740 | "print(train_dataset[0][0].shape)\n",
741 | "print(train_dataset[0][1].shape)"
742 | ]
743 | },
744 | {
745 | "cell_type": "code",
746 | "execution_count": 10,
747 | "id": "6a3cc9ee",
748 | "metadata": {
749 | "execution": {
750 | "iopub.execute_input": "2025-05-16T02:50:11.170327Z",
751 | "iopub.status.busy": "2025-05-16T02:50:11.169583Z",
752 | "iopub.status.idle": "2025-05-16T02:50:12.527408Z",
753 | "shell.execute_reply": "2025-05-16T02:50:12.526739Z"
754 | },
755 | "papermill": {
756 | "duration": 1.36238,
757 | "end_time": "2025-05-16T02:50:12.528531",
758 | "exception": false,
759 | "start_time": "2025-05-16T02:50:11.166151",
760 | "status": "completed"
761 | },
762 | "tags": []
763 | },
764 | "outputs": [
765 | {
766 | "data": {
767 | "text/html": [
768 | "Model: \"functional\"\n",
769 | "
\n"
770 | ],
771 | "text/plain": [
772 | "\u001b[1mModel: \"functional\"\u001b[0m\n"
773 | ]
774 | },
775 | "metadata": {},
776 | "output_type": "display_data"
777 | },
778 | {
779 | "data": {
780 | "text/html": [
781 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
782 | "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
783 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
784 | "│ input_layer (InputLayer) │ (None, 12, 3) │ 0 │\n",
785 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
786 | "│ encoder (Encoder) │ [(None, 12, 64), (None, │ 0 │\n",
787 | "│ │ 128)] │ │\n",
788 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
789 | "│ dense (Dense) │ (None, 128) │ 16,512 │\n",
790 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
791 | "│ dense_1 (Dense) │ (None, 128) │ 16,512 │\n",
792 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
793 | "│ dense_2 (Dense) │ (None, 1) │ 129 │\n",
794 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
795 | "│ reshape (Reshape) │ (None, 1, 1) │ 0 │\n",
796 | "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
797 | "
\n"
798 | ],
799 | "text/plain": [
800 | "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
801 | "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
802 | "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
803 | "│ input_layer (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
804 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
805 | "│ encoder (\u001b[38;5;33mEncoder\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m, \u001b[38;5;34m64\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m0\u001b[0m │\n",
806 | "│ │ \u001b[38;5;34m128\u001b[0m)] │ │\n",
807 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
808 | "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
809 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
810 | "│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m16,512\u001b[0m │\n",
811 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
812 | "│ dense_2 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │\n",
813 | "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
814 | "│ reshape (\u001b[38;5;33mReshape\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n",
815 | "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n"
816 | ]
817 | },
818 | "metadata": {},
819 | "output_type": "display_data"
820 | },
821 | {
822 | "data": {
823 | "text/html": [
824 | " Total params: 33,153 (129.50 KB)\n",
825 | "
\n"
826 | ],
827 | "text/plain": [
828 | "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m33,153\u001b[0m (129.50 KB)\n"
829 | ]
830 | },
831 | "metadata": {},
832 | "output_type": "display_data"
833 | },
834 | {
835 | "data": {
836 | "text/html": [
837 | " Trainable params: 33,153 (129.50 KB)\n",
838 | "
\n"
839 | ],
840 | "text/plain": [
841 | "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m33,153\u001b[0m (129.50 KB)\n"
842 | ]
843 | },
844 | "metadata": {},
845 | "output_type": "display_data"
846 | },
847 | {
848 | "data": {
849 | "text/html": [
850 | " Non-trainable params: 0 (0.00 B)\n",
851 | "
\n"
852 | ],
853 | "text/plain": [
854 | "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
855 | ]
856 | },
857 | "metadata": {},
858 | "output_type": "display_data"
859 | }
860 | ],
861 | "source": [
862 | "def build_model():\n",
863 | " inputs = tf.keras.Input(shape=(CFG.train_sequence_length, 3))\n",
864 | " \n",
865 | " config = AutoConfig()(\"rnn\")\n",
866 | " config.rnn_type = \"lstm\"\n",
867 | " backbone = AutoModel.from_config(config=config)\n",
868 | " \n",
869 | " outputs = backbone(inputs)\n",
870 | " model = tf.keras.Model(inputs=inputs, outputs=outputs)\n",
871 | " model.compile(loss=tf.keras.losses.MeanAbsoluteError(), optimizer=tf.keras.optimizers.Adam(), metrics = ['mae'])\n",
872 | " return model\n",
873 | "\n",
874 | "\n",
875 | "model = build_model()\n",
876 | "model.summary()"
877 | ]
878 | },
879 | {
880 | "cell_type": "code",
881 | "execution_count": 11,
882 | "id": "d13f0ce1",
883 | "metadata": {
884 | "execution": {
885 | "iopub.execute_input": "2025-05-16T02:50:12.536919Z",
886 | "iopub.status.busy": "2025-05-16T02:50:12.536488Z",
887 | "iopub.status.idle": "2025-05-16T02:50:40.889410Z",
888 | "shell.execute_reply": "2025-05-16T02:50:40.888816Z"
889 | },
890 | "papermill": {
891 | "duration": 28.358352,
892 | "end_time": "2025-05-16T02:50:40.890731",
893 | "exception": false,
894 | "start_time": "2025-05-16T02:50:12.532379",
895 | "status": "completed"
896 | },
897 | "tags": []
898 | },
899 | "outputs": [
900 | {
901 | "name": "stdout",
902 | "output_type": "stream",
903 | "text": [
904 | "Epoch 1/10\n"
905 | ]
906 | },
907 | {
908 | "name": "stderr",
909 | "output_type": "stream",
910 | "text": [
911 | "/usr/local/lib/python3.11/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
912 | " self._warn_if_super_not_called()\n",
913 | "I0000 00:00:1747363815.790123 65 cuda_dnn.cc:529] Loaded cuDNN version 90300\n"
914 | ]
915 | },
916 | {
917 | "name": "stdout",
918 | "output_type": "stream",
919 | "text": [
920 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 9ms/step - loss: 324.2088 - mae: 324.2088 - val_loss: 197.0513 - val_mae: 197.0513\n",
921 | "Epoch 2/10\n",
922 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 211.5362 - mae: 211.5362 - val_loss: 187.4890 - val_mae: 187.4890\n",
923 | "Epoch 3/10\n",
924 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 9ms/step - loss: 174.5251 - mae: 174.5251 - val_loss: 162.3955 - val_mae: 162.3955\n",
925 | "Epoch 4/10\n",
926 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 9ms/step - loss: 172.6773 - mae: 172.6773 - val_loss: 149.7986 - val_mae: 149.7986\n",
927 | "Epoch 5/10\n",
928 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 168.0907 - mae: 168.0907 - val_loss: 148.2648 - val_mae: 148.2648\n",
929 | "Epoch 6/10\n",
930 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 9ms/step - loss: 163.7792 - mae: 163.7792 - val_loss: 157.9741 - val_mae: 157.9741\n",
931 | "Epoch 7/10\n",
932 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 166.4666 - mae: 166.4666 - val_loss: 166.7117 - val_mae: 166.7117\n",
933 | "Epoch 8/10\n",
934 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 150.0828 - mae: 150.0828 - val_loss: 159.2332 - val_mae: 159.2332\n",
935 | "Epoch 9/10\n",
936 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 140.8101 - mae: 140.8101 - val_loss: 187.0695 - val_mae: 187.0695\n",
937 | "Epoch 10/10\n",
938 | "\u001b[1m282/282\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 8ms/step - loss: 154.2695 - mae: 154.2695 - val_loss: 137.9109 - val_mae: 137.9109\n"
939 | ]
940 | }
941 | ],
942 | "source": [
943 | "best_ckpt = tf.keras.callbacks.ModelCheckpoint('./sales_model.weights.h5', monitor='val_loss', save_best_only=True, save_weights_only=True)  # writes weights only on epochs where val_loss improves\n",
944 | "history = model.fit(train_dataset, validation_data=valid_dataset, epochs=10, callbacks=[best_ckpt])  # './sales_model.weights.h5' is produced by best_ckpt during training"
945 | ]
946 | }
947 | ],
948 | "metadata": {
949 | "kaggle": {
950 | "accelerator": "gpu",
951 | "dataSources": [
952 | {
953 | "datasetId": 7421883,
954 | "sourceId": 11816396,
955 | "sourceType": "datasetVersion"
956 | }
957 | ],
958 | "dockerImageVersionId": 31041,
959 | "isGpuEnabled": true,
960 | "isInternetEnabled": true,
961 | "language": "python",
962 | "sourceType": "notebook"
963 | },
964 | "kernelspec": {
965 | "display_name": "Python 3",
966 | "language": "python",
967 | "name": "python3"
968 | },
969 | "language_info": {
970 | "codemirror_mode": {
971 | "name": "ipython",
972 | "version": 3
973 | },
974 | "file_extension": ".py",
975 | "mimetype": "text/x-python",
976 | "name": "python",
977 | "nbconvert_exporter": "python",
978 | "pygments_lexer": "ipython3",
979 | "version": "3.11.11"
980 | },
981 | "papermill": {
982 | "default_parameters": {},
983 | "duration": 60.266353,
984 | "end_time": "2025-05-16T02:50:44.081675",
985 | "environment_variables": {},
986 | "exception": null,
987 | "input_path": "__notebook__.ipynb",
988 | "output_path": "__notebook__.ipynb",
989 | "parameters": {},
990 | "start_time": "2025-05-16T02:49:43.815322",
991 | "version": "2.6.0"
992 | }
993 | },
994 | "nbformat": 4,
995 | "nbformat_minor": 5
996 | }
997 |
--------------------------------------------------------------------------------