├── .gitignore ├── Multivariate_Long_Sequence_Time-Series_Forecasting_with_SOTA_Transformers.pdf ├── README.md ├── transformer_timeseries.yml └── patch-tst └── patchTST_WTH.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .ipynb_checkpoints/ 3 | *.csv 4 | *.xls 5 | *.ppt 6 | *.pth 7 | wandb 8 | tsai 9 | PatchTST 10 | *.png -------------------------------------------------------------------------------- /Multivariate_Long_Sequence_Time-Series_Forecasting_with_SOTA_Transformers.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neelblabla/transformers_for_time_series_forecasting/HEAD/Multivariate_Long_Sequence_Time-Series_Forecasting_with_SOTA_Transformers.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multivariate Long Sequence Time-Series Forecasting with SOTA Transformers 2 | 3 | We reproduce results from the following papers, which introduced two state-of-the-art Transformer architectures for long-sequence time-series forecasting: 4 | 5 | 1) PatchTST - "A Time Series is Worth 64 Words: Long-term Forecasting with Transformers" (https://doi.org/10.48550/arXiv.2211.14730) 6 | 2) Informer - "Expanding the prediction capacity in long sequence time-series forecasting" (https://doi.org/10.1016/j.artint.2023.103886) and (https://doi.org/10.48550/arXiv.2012.07436) 7 | 8 | Objective: to apply Transformer models to multivariate time-series forecasting, with a focus on improved efficiency and accuracy. 9 | 10 | Implementation details are documented in the PDF report (Multivariate_Long_Sequence_Time-Series_Forecasting_with_SOTA_Transformers.pdf) included in this repository.
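The data pipeline in `patch-tst/patchTST_WTH.ipynb` splits the WTH series chronologically into 70% train, 10% validation and 20% test rows. It starts the validation and test ranges `seq_len` rows early so that their first input window has a full lookback, and slices each sample into 336 input steps followed by 96 target steps. The block below is a minimal NumPy sketch of that indexing, not the notebook code itself; the helper names (`split_borders`, `num_windows`, `make_window`, `standardize_train_only`) are hypothetical. There is one deliberate difference: the notebook fits `StandardScaler` on all rows before splitting, whereas `standardize_train_only` takes the mean and standard deviation from the training rows only, so no validation or test statistics leak into the normalization.

```python
import numpy as np


def split_borders(n_rows, seq_len=336, train_frac=0.7, val_frac=0.1, test_frac=0.2):
    """Return (start, end) row ranges for train/val/test.

    Mirrors the notebook's border1s/border2s: validation and test start
    seq_len rows early so their first window has a full lookback.
    """
    n_train = int(round(train_frac * n_rows))
    n_val = int(round(val_frac * n_rows))
    n_test = int(round(test_frac * n_rows))
    starts = [0, n_train - seq_len, n_train + n_val - seq_len]
    ends = [n_train, n_train + n_val, n_train + n_val + n_test]
    return list(zip(starts, ends))


def num_windows(start, end, seq_len=336, pred_len=96):
    """Number of (input, target) samples in rows [start, end), as in __len__."""
    return (end - start) - seq_len - pred_len + 1


def make_window(data, index, seq_len=336, pred_len=96):
    """Return seq_len input rows and the pred_len rows that follow them."""
    x = data[index:index + seq_len]
    y = data[index + seq_len:index + seq_len + pred_len]
    return x, y


def standardize_train_only(data, n_train):
    """Z-score every column of a 2-D array using the first n_train rows only."""
    mean = data[:n_train].mean(axis=0)
    std = data[:n_train].std(axis=0)  # population std (ddof=0), like StandardScaler
    std[std == 0] = 1.0  # leave constant columns unscaled, as StandardScaler does
    return (data - mean) / std


borders = split_borders(35064)
print([num_windows(s, e) for s, e in borders])  # [24114, 3411, 6918]
```

For the 35,064-row WTH file, the printed window counts match the dataset sizes the notebook reports (24114, 3411 and 6918).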
11 | 12 | ### Results: 13 | Our results are presented below alongside the corresponding results reported in the 2023 PatchTST paper. 14 | 15 | ![Comparison of our results with the published results](https://github.com/neelblabla/transformers_for_time_series_forecasting/assets/114079228/0ddc430c-a130-4a60-825d-0d1220a76e27) 16 | 17 | As shown above, for some datasets our implementations reproduce results close to those reported by the authors of PatchTST and Informer. For Informer, the average deviation from the published results is only 0.13 for MSE and 0.07 for MAE; for PatchTST it is about 0.132 for MSE and 0.126 for MAE. Our PatchTST implementation uses a scaled-down architecture compared with the original paper, primarily because of computational constraints. An interesting observation from our experiments is that increasing the model size up to a certain point does not consistently lower the MSE or MAE. This suggests that a substantially larger model may be needed to improve performance, which our computational budget did not allow. 18 | 19 | Weights & Biases link for visualizations: 20 | https://wandb.ai/kirteshpatel98/transformer_timeseries 21 | 22 | ### Instructions for Demo: 23 | Before running the notebook, clone the original PatchTST repository into the patch-tst folder so that the model architecture can be imported:
24 | ``` 25 | git clone https://github.com/yuqinie98/PatchTST patch-tst/PatchTST 26 | ``` 27 | 28 | Install the required packages (the environment file was exported on Windows, so its pinned build strings may not resolve on other platforms): 29 | ``` 30 | conda env create -n myenv -f transformer_timeseries.yml 31 | ``` 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /transformer_timeseries.yml: -------------------------------------------------------------------------------- 1 | name: tf_gpu 2 | channels: 3 | - intel 4 | - pytorch 5 | - nvidia 6 | - defaults 7 | dependencies: 8 | - _tflow_select=2.1.0=gpu 9 | - abseil-cpp=20210324.2=hd77b12b_0 10 | - absl-py=1.4.0=py39haa95532_0 11 | - aiohttp=3.8.5=py39h2bbff1b_0 12 | - aiosignal=1.2.0=pyhd3eb1b0_0 13 | - astor=0.8.1=py39haa95532_0 14 | - asttokens=2.0.5=pyhd3eb1b0_0 15 | - astunparse=1.6.3=py_0 16 | - async-timeout=4.0.2=py39haa95532_0 17 | - attrs=23.1.0=py39haa95532_0 18 | - backcall=0.2.0=pyhd3eb1b0_0 19 | - blas=1.0=mkl 20 | - blinker=1.6.2=py39haa95532_0 21 | - bottleneck=1.3.5=py39h080aedc_0 22 | - brotli=1.0.9=h2bbff1b_7 23 | - brotli-bin=1.0.9=h2bbff1b_7 24 | - brotli-python=1.0.9=py39hd77b12b_7 25 | - brotlipy=0.7.0=py39h2bbff1b_1003 26 | - ca-certificates=2023.08.22=haa95532_0 27 | - cachetools=4.2.2=pyhd3eb1b0_0 28 | - catalogue=2.0.7=py39haa95532_0 29 | - certifi=2023.7.22=py39haa95532_0 30 | - cffi=1.15.1=py39h2bbff1b_3 31 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 32 | - click=8.1.7=py39haa95532_0 33 | - colorama=0.4.6=py39haa95532_0 34 | - comm=0.1.2=py39haa95532_0 35 | - confection=0.0.4=py39hd4e2768_0 36 | - cryptography=41.0.3=py39h3438e0d_0 37 | - cuda-cccl=12.1.109=0 38 | - cuda-cudart=11.8.89=0 39 | - cuda-cudart-dev=11.8.89=0 40 | - cuda-cupti=11.8.87=0 41 | - cuda-libraries=11.8.0=0 42 | - cuda-libraries-dev=11.8.0=0 43 | - cuda-nvrtc=11.8.89=0 44 | - cuda-nvrtc-dev=11.8.89=0 45 | - cuda-nvtx=11.8.86=0 46 | - cuda-profiler-api=12.1.105=0 47 | - cuda-runtime=11.8.0=0 48 | - cudatoolkit=11.3.1=h59b6b97_2 49 | - cudnn=8.2.1=cuda11.3_0 50 | - cymem=2.0.6=py39hd77b12b_0 51 | 
- cython-blis=0.7.9=py39h080aedc_0 52 | - daal4py=2023.1.1=py39h757b272_0 53 | - dal=2023.1.1=h59b6b97_48681 54 | - debugpy=1.6.7=py39hd77b12b_0 55 | - decorator=5.1.1=pyhd3eb1b0_0 56 | - et_xmlfile=1.1.0=py39haa95532_0 57 | - exceptiongroup=1.0.4=py39haa95532_0 58 | - executing=0.8.3=pyhd3eb1b0_0 59 | - filelock=3.9.0=py39haa95532_0 60 | - flatbuffers=2.0.0=h6c2663c_0 61 | - freetype=2.12.1=ha860e81_0 62 | - frozenlist=1.4.0=py39h2bbff1b_0 63 | - gast=0.4.0=pyhd3eb1b0_0 64 | - giflib=5.2.1=h8cc25b3_3 65 | - google-auth=2.22.0=py39haa95532_0 66 | - google-auth-oauthlib=0.4.1=py_2 67 | - google-pasta=0.2.0=pyhd3eb1b0_0 68 | - graphviz=2.38=hfd603c8_2 69 | - h5py=3.9.0=py39hfc34f40_0 70 | - hdf5=1.12.1=h51c971a_3 71 | - icc_rt=2022.1.0=h6049295_2 72 | - icu=68.1=h6c2663c_0 73 | - importlib-metadata=6.0.0=py39haa95532_0 74 | - importlib_metadata=6.0.0=hd3eb1b0_0 75 | - intel-openmp=2023.1.0=h59b6b97_46319 76 | - jedi=0.18.1=py39haa95532_1 77 | - jinja2=3.1.2=py39haa95532_0 78 | - joblib=1.2.0=py39haa95532_0 79 | - jpeg=9e=h2bbff1b_1 80 | - jupyter_client=8.1.0=py39haa95532_0 81 | - jupyter_core=5.5.0=py39haa95532_0 82 | - keras=2.6.0=pyhd3eb1b0_0 83 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 84 | - langcodes=3.3.0=pyhd3eb1b0_0 85 | - lerc=3.0=hd77b12b_0 86 | - libbrotlicommon=1.0.9=h2bbff1b_7 87 | - libbrotlidec=1.0.9=h2bbff1b_7 88 | - libbrotlienc=1.0.9=h2bbff1b_7 89 | - libcublas=11.11.3.6=0 90 | - libcublas-dev=11.11.3.6=0 91 | - libcufft=10.9.0.58=0 92 | - libcufft-dev=10.9.0.58=0 93 | - libcurand=10.3.2.106=0 94 | - libcurand-dev=10.3.2.106=0 95 | - libcurl=8.4.0=h86230a5_0 96 | - libcusolver=11.4.1.48=0 97 | - libcusolver-dev=11.4.1.48=0 98 | - libcusparse=11.7.5.86=0 99 | - libcusparse-dev=11.7.5.86=0 100 | - libdeflate=1.17=h2bbff1b_1 101 | - libnpp=11.8.0.86=0 102 | - libnpp-dev=11.8.0.86=0 103 | - libnvjpeg=11.9.0.86=0 104 | - libnvjpeg-dev=11.9.0.86=0 105 | - libpng=1.6.39=h8cc25b3_0 106 | - libprotobuf=3.17.2=h23ce68f_1 107 | - libsodium=1.0.18=h62dcd97_0 
108 | - libssh2=1.10.0=hcd4344a_2 109 | - libtiff=4.5.1=hd77b12b_0 110 | - libuv=1.44.2=h2bbff1b_0 111 | - libwebp=1.3.2=hbc33d0d_0 112 | - libwebp-base=1.3.2=h2bbff1b_0 113 | - lz4-c=1.9.4=h2bbff1b_0 114 | - markdown=3.4.1=py39haa95532_0 115 | - markupsafe=2.1.1=py39h2bbff1b_0 116 | - mkl=2023.1.0=h6b88ed4_46357 117 | - mkl-service=2.4.0=py39h2bbff1b_1 118 | - mkl_fft=1.3.8=py39h2bbff1b_0 119 | - mkl_random=1.2.4=py39h59b6b97_0 120 | - mpmath=1.3.0=py39haa95532_0 121 | - multidict=6.0.2=py39h2bbff1b_0 122 | - munkres=1.1.4=py_0 123 | - murmurhash=1.0.7=py39hd77b12b_0 124 | - nest-asyncio=1.5.6=py39haa95532_0 125 | - networkx=3.1=py39haa95532_0 126 | - nltk=3.8.1=py39haa95532_0 127 | - numexpr=2.8.7=py39h2cd9be0_0 128 | - numpy-base=1.26.0=py39h65a83cf_0 129 | - oauthlib=3.2.2=py39haa95532_0 130 | - openjpeg=2.4.0=h4fc8c34_0 131 | - openssl=1.1.1w=h2bbff1b_0 132 | - opt_einsum=3.3.0=pyhd3eb1b0_1 133 | - pandas=1.5.2=py39hf11a4ad_0 134 | - pandas-stubs=1.5.3.230203=py39haa95532_0 135 | - parso=0.8.3=pyhd3eb1b0_0 136 | - pathy=0.10.1=py39haa95532_0 137 | - pickleshare=0.7.5=pyhd3eb1b0_1003 138 | - platformdirs=3.10.0=py39haa95532_0 139 | - plotly=5.9.0=py39haa95532_0 140 | - preshed=3.0.6=py39h6c2663c_0 141 | - prompt-toolkit=3.0.36=py39haa95532_0 142 | - psutil=5.9.0=py39h2bbff1b_0 143 | - pure_eval=0.2.2=pyhd3eb1b0_0 144 | - pyasn1=0.4.8=pyhd3eb1b0_0 145 | - pyasn1-modules=0.2.8=py_0 146 | - pycparser=2.21=pyhd3eb1b0_0 147 | - pydantic=1.10.12=py39h2bbff1b_1 148 | - pygments=2.15.1=py39haa95532_1 149 | - pyjwt=2.4.0=py39haa95532_0 150 | - pyopenssl=23.2.0=py39haa95532_0 151 | - pysocks=1.7.1=py39haa95532_0 152 | - python=3.9.18=h6244533_0 153 | - python-dateutil=2.8.2=pyhd3eb1b0_0 154 | - python-flatbuffers=1.12=pyhd3eb1b0_0 155 | - pytorch-cuda=11.8=h24eeafa_5 156 | - pytorch-mutex=1.0=cuda 157 | - pytz=2023.3.post1=py39haa95532_0 158 | - pywin32=305=py39h2bbff1b_0 159 | - pyzmq=25.1.0=py39hd77b12b_0 160 | - requests=2.31.0=py39haa95532_0 161 | - 
requests-oauthlib=1.3.0=py_0 162 | - rsa=4.7.2=pyhd3eb1b0_1 163 | - scikit-learn=1.2.1=py39hd77b12b_0 164 | - scikit-learn-intelex=2023.1.1=py39_intel_48681 165 | - scipy=1.11.3=py39h309d312_0 166 | - setuptools=68.0.0=py39haa95532_0 167 | - shellingham=1.5.0=py39haa95532_0 168 | - six=1.16.0=pyhd3eb1b0_1 169 | - smart_open=5.2.1=py39haa95532_0 170 | - snappy=1.1.9=h6c2663c_0 171 | - spacy=3.5.3=py39hef0f399_0 172 | - spacy-legacy=3.0.12=py39haa95532_0 173 | - spacy-loggers=1.0.4=py39haa95532_0 174 | - sqlite=3.41.2=h2bbff1b_0 175 | - srsly=2.4.8=py39hd77b12b_0 176 | - stack_data=0.2.0=pyhd3eb1b0_0 177 | - sympy=1.11.1=py39haa95532_0 178 | - tbb=2021.9.0=vc14_intel_43574 179 | - tenacity=8.2.2=py39haa95532_0 180 | - tensorboard-data-server=0.6.1=py39haa95532_0 181 | - tensorboard-plugin-wit=1.8.1=py39haa95532_0 182 | - tensorflow-base=2.6.0=gpu_py39hb3da07e_0 183 | - tensorflow-gpu=2.6.0=h17022bd_0 184 | - termcolor=2.1.0=py39haa95532_0 185 | - thinc=8.1.10=py39hf497b98_0 186 | - threadpoolctl=2.2.0=pyh0d69192_0 187 | - tk=8.6.12=h2bbff1b_0 188 | - tornado=6.3.3=py39h2bbff1b_0 189 | - tqdm=4.65.0=py39hd4e2768_0 190 | - traitlets=5.7.1=py39haa95532_0 191 | - typer=0.4.1=py39haa95532_0 192 | - types-pytz=2022.4.0.0=py39haa95532_1 193 | - typing-extensions=4.7.1=py39haa95532_0 194 | - typing_extensions=4.7.1=py39haa95532_0 195 | - tzdata=2023c=h04d1e81_0 196 | - urllib3=1.26.18=py39haa95532_0 197 | - vc=14.2=h21ff451_1 198 | - vs2015_runtime=14.27.29016=h5e58377_2 199 | - wasabi=0.9.1=py39haa95532_0 200 | - wcwidth=0.2.5=pyhd3eb1b0_0 201 | - werkzeug=2.2.3=py39haa95532_0 202 | - wheel=0.35.1=pyhd3eb1b0_0 203 | - win_inet_pton=1.1.0=py39haa95532_0 204 | - wrapt=1.14.1=py39h2bbff1b_0 205 | - xz=5.4.2=h8cc25b3_0 206 | - yarl=1.8.1=py39h2bbff1b_0 207 | - zeromq=4.3.4=hd77b12b_0 208 | - zlib=1.2.13=h8cc25b3_0 209 | - zstd=1.5.5=hd43e919_0 210 | - pip: 211 | - alembic==1.12.0 212 | - anyio==4.0.0 213 | - appdirs==1.4.4 214 | - bardapi==0.1.34 215 | - beautifulsoup4==4.11.2 
216 | - bokeh==3.1.0 217 | - browser-cookie3==0.19.1 218 | - chardet==3.0.4 219 | - cloudpickle==1.3.0 220 | - colorlog==6.7.0 221 | - contourpy==1.1.1 222 | - cycler==0.12.1 223 | - deep-translator==1.11.4 224 | - dm-tree==0.1.8 225 | - docker-pycreds==0.4.0 226 | - en-core-web-sm==3.5.0 227 | - et-xmlfile==1.1.0 228 | - execnb==0.1.5 229 | - fastai==2.7.13 230 | - fastcore==1.5.29 231 | - fastdownload==0.0.7 232 | - fastprogress==1.0.3 233 | - fonttools==4.43.1 234 | - frozendict==2.3.5 235 | - fsspec==2023.9.1 236 | - ghapi==1.0.4 237 | - gin-config==0.5.0 238 | - gitdb==4.0.11 239 | - gitpython==3.1.40 240 | - google-api-core==2.12.0 241 | - google-cloud-core==2.3.3 242 | - google-cloud-translate==3.12.0 243 | - googleapis-common-protos==1.60.0 244 | - googlebard==2.1.0 245 | - googletrans==4.0.0rc1 246 | - greenlet==3.0.0 247 | - grpcio==1.58.0 248 | - grpcio-status==1.58.0 249 | - h11==0.9.0 250 | - h2==3.2.0 251 | - hpack==3.0.0 252 | - hstspreload==2023.1.1 253 | - html5lib==1.1 254 | - httpcore==0.9.1 255 | - httpx==0.13.3 256 | - huggingface-hub==0.17.1 257 | - hyperframe==5.2.0 258 | - idna==2.10 259 | - imbalanced-learn==0.11.0 260 | - importlib-resources==6.1.0 261 | - ipykernel==6.25.2 262 | - ipython==8.16.1 263 | - ipywidgets==8.0.4 264 | - jsonschema==4.19.1 265 | - jsonschema-specifications==2023.7.1 266 | - jupyterlab-widgets==3.0.9 267 | - kiwisolver==1.4.5 268 | - langdetect==1.0.9 269 | - libclang==15.0.6.1 270 | - lightning-utilities==0.9.0 271 | - livelossplot==0.5.5 272 | - llvmlite==0.41.1 273 | - lxml==4.9.2 274 | - lz4==4.3.2 275 | - mako==1.2.4 276 | - markdown-it-py==3.0.0 277 | - matplotlib==3.7.0 278 | - matplotlib-inline==0.1.6 279 | - mdurl==0.1.2 280 | - msgpack==1.0.7 281 | - multitasking==0.0.11 282 | - nbdev==2.3.13 283 | - neuralforecast==1.6.4 284 | - numba==0.58.1 285 | - numpy==1.26.1 286 | - openpyxl==3.1.2 287 | - optuna==3.4.0 288 | - packaging==23.2 289 | - pathtools==0.1.2 290 | - patsy==0.5.3 291 | - pillow==10.1.0 
292 | - pip==23.3.1 293 | - plot-model==0.20 294 | - proto-plus==1.22.3 295 | - protobuf==3.20.3 296 | - pyarrow==6.0.1 297 | - pycryptodomex==3.19.0 298 | - pycuda==2022.2.2 299 | - pydot==1.4.2 300 | - pydot-ng==2.0.0 301 | - pyparsing==3.1.1 302 | - python-graphviz==0.20.1 303 | - pytools==2023.1 304 | - pytorch-lightning==2.1.0 305 | - pyts==0.13.0 306 | - pyyaml==6.0 307 | - ray==2.7.1 308 | - referencing==0.30.2 309 | - regex==2023.8.8 310 | - rfc3986==1.5.0 311 | - rich==13.5.3 312 | - rpds-py==0.10.6 313 | - safetensors==0.3.3 314 | - seaborn==0.13.0 315 | - sentencepiece==0.1.99 316 | - sentry-sdk==1.32.0 317 | - setproctitle==1.3.3 318 | - smmap==5.0.1 319 | - sniffio==1.3.0 320 | - socksio==1.0.0 321 | - soupsieve==2.4 322 | - sqlalchemy==2.0.22 323 | - statsmodels==0.13.5 324 | - tensorboard==2.11.2 325 | - tensorboardx==2.6.2.2 326 | - tensorflow==2.11.0 327 | - tensorflow-estimator==2.11.0 328 | - tensorflow-intel==2.11.0 329 | - tensorflow-io-gcs-filesystem==0.31.0 330 | - tensorflow-probability==0.20.0 331 | - textblob==0.17.1 332 | - tf-agents==0.6.0 333 | - tokenizers==0.13.3 334 | - torch==2.0.1+cu118 335 | - torchaudio==2.0.2+cu118 336 | - torchdata==0.6.1 337 | - torchmetrics==1.2.0 338 | - torchsummary==1.5.1 339 | - torchtext==0.15.2 340 | - torchvision==0.15.2+cu118 341 | - transformers==4.33.2 342 | - utilsforecast==0.0.12 343 | - wandb==0.15.12 344 | - watchdog==3.0.0 345 | - webencodings==0.5.1 346 | - widgetsnbextension==4.0.9 347 | - xyzservices==2023.2.0 348 | - yfinance==0.2.12 349 | - zipp==3.17.0 350 | prefix: C:\Users\hp\miniconda3\envs\tf_gpu 351 | -------------------------------------------------------------------------------- /patch-tst/patchTST_WTH.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f5fc39de", 6 | "metadata": {}, 7 | "source": [ 8 | "### Multivariate Time Series Forecasting by application of PatchTST 
(Transformer) on WTH Dataset\n", 9 | "- Lookback window = 336\n", 10 | "\n", 11 | "- Forecast window = 96" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 209, 17 | "id": "ed103c66", 18 | "metadata": { 19 | "colab": { 20 | "base_uri": "https://localhost:8080/" 21 | }, 22 | "id": "ed103c66", 23 | "outputId": "6d983b40-6a98-48fd-c1ef-875839b80915" 24 | }, 25 | "outputs": [ 26 | { 27 | "name": "stdout", 28 | "output_type": "stream", 29 | "text": [ 30 | "2.0.1+cu118\n" 31 | ] 32 | } 33 | ], 34 | "source": [ 35 | "import torch\n", 36 | "from torch.utils.data import Dataset\n", 37 | "\n", 38 | "import seaborn as sns\n", 39 | "import matplotlib.pyplot as plt\n", 40 | "import numpy as np\n", 41 | "import pandas as pd\n", 42 | "# !pip install wandb\n", 43 | "import wandb\n", 44 | "import random\n", 45 | "from tqdm import tqdm\n", 46 | "\n", 47 | "from sklearn.preprocessing import StandardScaler\n", 48 | "print(torch.__version__)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 210, 54 | "id": "3855ec0c", 55 | "metadata": { 56 | "colab": { 57 | "base_uri": "https://localhost:8080/", 58 | "height": 471 59 | }, 60 | "id": "3855ec0c", 61 | "outputId": "caf23e54-39a1-4de5-8a08-40a6b7c38400" 62 | }, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/html": [ 67 | "Finishing last run (ID:6fm8pdr3) before initializing another..." 68 | ], 69 | "text/plain": [ 70 | "" 71 | ] 72 | }, 73 | "metadata": {}, 74 | "output_type": "display_data" 75 | }, 76 | { 77 | "data": { 78 | "text/html": [ 79 | "Waiting for W&B process to finish... (success)." 80 | ], 81 | "text/plain": [ 82 | "" 83 | ] 84 | }, 85 | "metadata": {}, 86 | "output_type": "display_data" 87 | }, 88 | { 89 | "data": { 90 | "text/html": [ 91 | " View run decent-field-69 at: https://wandb.ai/kirteshpatel98/transformer_timeseries/runs/6fm8pdr3
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" 92 | ], 93 | "text/plain": [ 94 | "" 95 | ] 96 | }, 97 | "metadata": {}, 98 | "output_type": "display_data" 99 | }, 100 | { 101 | "data": { 102 | "text/html": [ 103 | "Find logs at: .\\wandb\\run-20231103_132403-6fm8pdr3\\logs" 104 | ], 105 | "text/plain": [ 106 | "" 107 | ] 108 | }, 109 | "metadata": {}, 110 | "output_type": "display_data" 111 | }, 112 | { 113 | "data": { 114 | "text/html": [ 115 | "Successfully finished last run (ID:6fm8pdr3). Initializing new run:
" 116 | ], 117 | "text/plain": [ 118 | "" 119 | ] 120 | }, 121 | "metadata": {}, 122 | "output_type": "display_data" 123 | }, 124 | { 125 | "data": { 126 | "application/vnd.jupyter.widget-view+json": { 127 | "model_id": "ad3adbc342b4413eae65a14dca105401", 128 | "version_major": 2, 129 | "version_minor": 0 130 | }, 131 | "text/plain": [ 132 | "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011111111111111112, max=1.0…" 133 | ] 134 | }, 135 | "metadata": {}, 136 | "output_type": "display_data" 137 | }, 138 | { 139 | "data": { 140 | "text/html": [ 141 | "Tracking run with wandb version 0.15.12" 142 | ], 143 | "text/plain": [ 144 | "" 145 | ] 146 | }, 147 | "metadata": {}, 148 | "output_type": "display_data" 149 | }, 150 | { 151 | "data": { 152 | "text/html": [ 153 | "Run data is saved locally in C:\\Users\\hp\\OneDrive - fs-students.de\\FS\\Sem 3\\Deep Learning\\transformers_for_time_series_forecasting\\patch-tst\\wandb\\run-20231103_132702-pnfd65sm" 154 | ], 155 | "text/plain": [ 156 | "" 157 | ] 158 | }, 159 | "metadata": {}, 160 | "output_type": "display_data" 161 | }, 162 | { 163 | "data": { 164 | "text/html": [ 165 | "Syncing run vocal-breeze-70 to Weights & Biases (docs)
" 166 | ], 167 | "text/plain": [ 168 | "" 169 | ] 170 | }, 171 | "metadata": {}, 172 | "output_type": "display_data" 173 | }, 174 | { 175 | "data": { 176 | "text/html": [ 177 | " View project at https://wandb.ai/kirteshpatel98/transformer_timeseries" 178 | ], 179 | "text/plain": [ 180 | "" 181 | ] 182 | }, 183 | "metadata": {}, 184 | "output_type": "display_data" 185 | }, 186 | { 187 | "data": { 188 | "text/html": [ 189 | " View run at https://wandb.ai/kirteshpatel98/transformer_timeseries/runs/pnfd65sm" 190 | ], 191 | "text/plain": [ 192 | "" 193 | ] 194 | }, 195 | "metadata": {}, 196 | "output_type": "display_data" 197 | }, 198 | { 199 | "data": { 200 | "text/html": [ 201 | "" 202 | ], 203 | "text/plain": [ 204 | "" 205 | ] 206 | }, 207 | "execution_count": 210, 208 | "metadata": {}, 209 | "output_type": "execute_result" 210 | } 211 | ], 212 | "source": [ 213 | "\n", 214 | "# start a new wandb run to track this script\n", 215 | "wandb.init(\n", 216 | " # set the wandb project where this run will be logged\n", 217 | " project=\"transformer_timeseries\",\n", 218 | ")\n" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 211, 224 | "id": "5522d111", 225 | "metadata": { 226 | "colab": { 227 | "base_uri": "https://localhost:8080/", 228 | "height": 634 229 | }, 230 | "id": "5522d111", 231 | "outputId": "e7970d24-0ff5-4827-aa0c-b9fb80ce9051" 232 | }, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "text/plain": [ 237 | "\"\\nfrom google.colab import drive\\ndrive.mount('/content/drive')\\n\\nimport os\\nos.chdir('/content/drive/My Drive/')\\n\\n\\ndf=pd.read_csv('WTH.csv')\\n# df.drop('Unnamed: 0',axis=1,inplace=True)\\ndf\\n\"" 238 | ] 239 | }, 240 | "execution_count": 211, 241 | "metadata": {}, 242 | "output_type": "execute_result" 243 | } 244 | ], 245 | "source": [ 246 | "'''\n", 247 | "from google.colab import drive\n", 248 | "drive.mount('/content/drive')\n", 249 | "\n", 250 | "import os\n", 251 | "os.chdir('/content/drive/My Drive/')\n", 
252 | "\n", 253 | "\n", 254 | "df=pd.read_csv('WTH.csv')\n", 255 | "# df.drop('Unnamed: 0',axis=1,inplace=True)\n", 256 | "df\n", 257 | "'''" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 212, 263 | "id": "bffcb3d9", 264 | "metadata": { 265 | "colab": { 266 | "base_uri": "https://localhost:8080/", 267 | "height": 36 268 | }, 269 | "id": "bffcb3d9", 270 | "outputId": "0583a406-d5e9-4687-baf9-6e00500e03d9" 271 | }, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/html": [ 276 | "
\n", 277 | "\n", 290 | "\n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " 
\n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | "
\n", 476 | "
\n", 477 | "
" 478 | ], 479 | "text/plain": [ 480 | " Visibility DryBulbFarenheit DryBulbCelsius WetBulbFarenheit \\\n", 481 | "0 10.0 16 -9 13 \n", 482 | "1 10.0 16 -9 13 \n", 483 | "2 10.0 16 -9 13 \n", 484 | "3 10.0 16 -9 13 \n", 485 | "4 10.0 16 -9 14 \n", 486 | "... ... ... ... ... \n", 487 | "35059 10.0 32 0 0 \n", 488 | "35060 7.0 30 -1 0 \n", 489 | "35061 5.0 30 -1 0 \n", 490 | "35062 10.0 30 -1 0 \n", 491 | "35063 10.0 30 -1 29 \n", 492 | "\n", 493 | " DewPointFarenheit DewPointCelsius RelativeHumidity WindSpeed \\\n", 494 | "0 7 -14 67 7 \n", 495 | "1 7 -14 67 5 \n", 496 | "2 7 -14 67 5 \n", 497 | "3 7 -14 67 7 \n", 498 | "4 9 -13 74 6 \n", 499 | "... ... ... ... ... \n", 500 | "35059 23 -5 0 0 \n", 501 | "35060 25 -4 0 5 \n", 502 | "35061 28 -2 0 0 \n", 503 | "35062 28 -2 0 5 \n", 504 | "35063 28 -2 92 5 \n", 505 | "\n", 506 | " WindDirection StationPressure Altimeter WetBulbCelsius \n", 507 | "0 130 21.650000 30.35 -10.3 \n", 508 | "1 150 21.640000 30.34 -10.3 \n", 509 | "2 190 21.650000 30.35 -10.3 \n", 510 | "3 180 21.650000 30.35 -10.3 \n", 511 | "4 120 21.640000 30.34 -10.0 \n", 512 | "... ... ... ... ... \n", 513 | "35059 0 21.478686 30.21 0.0 \n", 514 | "35060 110 21.478686 30.21 0.0 \n", 515 | "35061 0 21.478686 30.20 0.0 \n", 516 | "35062 140 21.478686 30.18 0.0 \n", 517 | "35063 110 21.520000 30.18 -1.5 \n", 518 | "\n", 519 | "[35064 rows x 12 columns]" 520 | ] 521 | }, 522 | "execution_count": 212, 523 | "metadata": {}, 524 | "output_type": "execute_result" 525 | } 526 | ], 527 | "source": [ 528 | "df=pd.read_csv('WTH.csv',index_col=None)\n", 529 | "# df.drop('Unnamed: 0',axis=1,inplace=True)\n", 530 | "df.iloc[:,1:]" 531 | ] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "execution_count": 213, 536 | "id": "273ef681", 537 | "metadata": {}, 538 | "outputs": [ 539 | { 540 | "data": { 541 | "text/html": [ 542 | "
\n", 543 | "\n", 556 | "\n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " 
\n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | "
\n", 754 | "
\n", 755 | "
" 756 | ], 757 | "text/plain": [ 758 | " date Visibility DryBulbFarenheit DryBulbCelsius \\\n", 759 | "0 1/1/2010 0:00 0.301781 -1.580651 -1.605506 \n", 760 | "1 1/1/2010 1:00 0.301781 -1.580651 -1.605506 \n", 761 | "2 1/1/2010 2:00 0.301781 -1.580651 -1.605506 \n", 762 | "3 1/1/2010 3:00 0.301781 -1.580651 -1.605506 \n", 763 | "4 1/1/2010 4:00 0.301781 -1.580651 -1.605506 \n", 764 | "... ... ... ... ... \n", 765 | "35059 12/31/2013 19:00 0.301781 -0.643255 -0.650163 \n", 766 | "35060 12/31/2013 20:00 -1.172466 -0.760429 -0.756313 \n", 767 | "35061 12/31/2013 21:00 -2.155297 -0.760429 -0.756313 \n", 768 | "35062 12/31/2013 22:00 0.301781 -0.760429 -0.756313 \n", 769 | "35063 12/31/2013 23:00 0.301781 -0.760429 -0.756313 \n", 770 | "\n", 771 | " WetBulbFarenheit DewPointFarenheit DewPointCelsius RelativeHumidity \\\n", 772 | "0 -1.632744 -1.014166 -1.035522 0.682362 \n", 773 | "1 -1.632744 -1.014166 -1.035522 0.682362 \n", 774 | "2 -1.632744 -1.014166 -1.035522 0.682362 \n", 775 | "3 -1.632744 -1.014166 -1.035522 0.682362 \n", 776 | "4 -1.553935 -0.876306 -0.911101 0.960769 \n", 777 | "... ... ... ... ... \n", 778 | "35059 -2.657250 0.088708 0.084264 -1.982391 \n", 779 | "35060 -2.657250 0.226567 0.208684 -1.982391 \n", 780 | "35061 -2.657250 0.433356 0.457525 -1.982391 \n", 781 | "35062 -2.657250 0.433356 0.457525 -1.982391 \n", 782 | "35063 -0.371813 0.433356 0.457525 1.676673 \n", 783 | "\n", 784 | " WindSpeed WindDirection StationPressure Altimeter WetBulbCelsius \n", 785 | "0 0.278284 -0.025625 0.681993 0.652382 -1.641876 \n", 786 | "1 -0.134472 0.170172 0.624802 0.608931 -1.641876 \n", 787 | "2 -0.134472 0.561768 0.681993 0.652382 -1.641876 \n", 788 | "3 0.278284 0.463869 0.681993 0.652382 -1.641876 \n", 789 | "4 0.071906 -0.123524 0.624802 0.608931 -1.598387 \n", 790 | "... ... ... ... ... ... 
\n", 791 | "35059 -1.166361 -1.298311 -0.297782 0.044066 -0.148745 \n", 792 | "35060 -0.134472 -0.221423 -0.297782 0.044066 -0.148745 \n", 793 | "35061 -1.166361 -1.298311 -0.297782 0.000615 -0.148745 \n", 794 | "35062 -0.134472 0.072274 -0.297782 -0.086287 -0.148745 \n", 795 | "35063 -0.134472 -0.221423 -0.061499 -0.086287 -0.366191 \n", 796 | "\n", 797 | "[35064 rows x 13 columns]" 798 | ] 799 | }, 800 | "execution_count": 213, 801 | "metadata": {}, 802 | "output_type": "execute_result" 803 | } 804 | ], 805 | "source": [ 806 | "scaler = StandardScaler()\n", 807 | "df.iloc[:,1:] = scaler.fit_transform(df.iloc[:,1:])\n", 808 | "df" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": 214, 814 | "id": "IMhozIdUX4fg", 815 | "metadata": { 816 | "id": "IMhozIdUX4fg" 817 | }, 818 | "outputs": [], 819 | "source": [ 820 | "class weather_data(torch.utils.data.Dataset):\n", 821 | " def __init__(self, df, mode=\"train\", seq_len=336, pred_len=96,num_feat=12):\n", 822 | " super().__init__()\n", 823 | " self.num_feat=num_feat\n", 824 | " self.df = df.iloc[:,1:num_feat+1]\n", 825 | " # time_stamp = df.iloc[:,0]\n", 826 | "\n", 827 | " assert mode in ['train', 'test', 'val']\n", 828 | " type_map = {'train': 0, 'val': 1, 'test': 2}\n", 829 | " self.set_type = type_map[mode]\n", 830 | "\n", 831 | " self.seq_len = seq_len\n", 832 | " self.pred_len = pred_len\n", 833 | " dataset_len=len(df)\n", 834 | "\n", 835 | " border1s = [0, int(round(0.7*dataset_len,0)) - self.seq_len, int(round(0.7*dataset_len,0)) + int(round(0.1*dataset_len,0)) - self.seq_len]\n", 836 | " border2s = [int(round(0.7*dataset_len,0)), int(round(0.7*dataset_len,0)) + int(round(0.1*dataset_len,0)), int(round(0.7*dataset_len,0)) + int(round(0.1*dataset_len,0)) + int(round(0.2*dataset_len,0))]\n", 837 | " border1 = border1s[self.set_type]\n", 838 | " border2 = border2s[self.set_type]\n", 839 | "\n", 840 | "\n", 841 | "\n", 842 | "\n", 843 | " self.df = self.df.to_numpy(dtype=np.float32)\n", 
844 | " # time_stamp = time_stamp.to_numpy()\n", 845 | "\n", 846 | " self.data_x = self.df[border1: border2, :]\n", 847 | " self.data_y = self.df[border1: border2, :]\n", 848 | "\n", 849 | " # self.data_stamp = time_stamp[border1: border2]\n", 850 | "\n", 851 | " def __getitem__(self, index):\n", 852 | " s_begin = index\n", 853 | " s_end = s_begin + self.seq_len\n", 854 | " r_begin = s_end\n", 855 | " r_end = r_begin + self.pred_len\n", 856 | "\n", 857 | " seq_x = self.data_x[s_begin:s_end]\n", 858 | " seq_y = self.data_y[r_begin:r_end]\n", 859 | " return seq_x, seq_y\n", 860 | "\n", 861 | "\n", 862 | " def __len__(self):\n", 863 | " return len(self.data_x) - self.seq_len - self.pred_len + 1" 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": 215, 869 | "id": "Ds8C5Wv1TZJy", 870 | "metadata": { 871 | "colab": { 872 | "base_uri": "https://localhost:8080/" 873 | }, 874 | "id": "Ds8C5Wv1TZJy", 875 | "outputId": "3523cc60-dc4f-4830-acb8-8140b52c92dd" 876 | }, 877 | "outputs": [ 878 | { 879 | "name": "stdout", 880 | "output_type": "stream", 881 | "text": [ 882 | "24114 3411 6918\n" 883 | ] 884 | } 885 | ], 886 | "source": [ 887 | "from torch.utils.data import DataLoader\n", 888 | "\n", 889 | "train_dataset = weather_data(df=df)\n", 890 | "valid_dataset = weather_data(df=df,mode=\"val\")\n", 891 | "test_dataset = weather_data(df=df,mode=\"test\")\n", 892 | "\n", 893 | "print(len(train_dataset),len(valid_dataset),len(test_dataset))" 894 | ] 895 | }, 896 | { 897 | "cell_type": "code", 898 | "execution_count": 216, 899 | "id": "70a18f29", 900 | "metadata": { 901 | "id": "70a18f29" 902 | }, 903 | "outputs": [ 904 | { 905 | "data": { 906 | "text/plain": [ 907 | "(96, 12)" 908 | ] 909 | }, 910 | "execution_count": 216, 911 | "metadata": {}, 912 | "output_type": "execute_result" 913 | } 914 | ], 915 | "source": [ 916 | "train_dataset[1][1].shape" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": 217, 922 | "id": "a2c85a51", 
923 | "metadata": {}, 924 | "outputs": [ 925 | { 926 | "data": { 927 | "text/plain": [ 928 | "PatchTST.PatchTST_supervised.models.PatchTST.Model" 929 | ] 930 | }, 931 | "execution_count": 217, 932 | "metadata": {}, 933 | "output_type": "execute_result" 934 | } 935 | ], 936 | "source": [ 937 | "import os\n", 938 | "current_directory = os.getcwd()\n", 939 | "os.chdir(os.path.join(current_directory, \"PatchTST\", \"PatchTST_supervised\"))\n", 940 | "from PatchTST.PatchTST_supervised.models.PatchTST import Model\n", 941 | "os.chdir(current_directory)\n", 942 | "Model" 943 | ] 944 | }, 945 | { 946 | "cell_type": "code", 947 | "execution_count": 218, 948 | "id": "322a0ff3", 949 | "metadata": {}, 950 | "outputs": [ 951 | { 952 | "data": { 953 | "text/plain": [ 954 | "'\\nclass Configs:\\n def __init__(self):\\n self.enc_in = 12\\n self.seq_len = 336\\n self.pred_len = 96\\n self.e_layers = 3\\n self.n_heads = 8\\n self.d_model = 128\\n self.d_ff = 512\\n self.dropout = 0.4\\n self.fc_dropout = 0.\\n self.head_dropout = 0.\\n self.individual = True\\n self.patch_len = 24\\n self.stride = 12\\n self.padding_patch = True\\n self.revin = True\\n self.affine = False\\n self.subtract_last = False\\n self.decomposition = False\\n self.kernel_size = 25\\n \\nmy_configs = Configs()\\nmodel=Model(configs=my_configs)\\n'" 955 | ] 956 | }, 957 | "execution_count": 218, 958 | "metadata": {}, 959 | "output_type": "execute_result" 960 | } 961 | ], 962 | "source": [ 963 | "'''\n", 964 | "class Configs:\n", 965 | " def __init__(self):\n", 966 | " self.enc_in = 12\n", 967 | " self.seq_len = 336\n", 968 | " self.pred_len = 96\n", 969 | " self.e_layers = 3\n", 970 | " self.n_heads = 8\n", 971 | " self.d_model = 128\n", 972 | " self.d_ff = 512\n", 973 | " self.dropout = 0.4\n", 974 | " self.fc_dropout = 0.\n", 975 | " self.head_dropout = 0.\n", 976 | " self.individual = True\n", 977 | " self.patch_len = 24\n", 978 | " self.stride = 12\n", 979 | " self.padding_patch = True\n", 980 | " self.revin = True\n", 
981 | " self.affine = False\n", 982 | " self.subtract_last = False\n", 983 | " self.decomposition = False\n", 984 | " self.kernel_size = 25\n", 985 | " \n", 986 | "my_configs = Configs()\n", 987 | "model=Model(configs=my_configs)\n", 988 | "'''" 989 | ] 990 | }, 991 | { 992 | "cell_type": "code", 993 | "execution_count": 219, 994 | "id": "7152de2e", 995 | "metadata": {}, 996 | "outputs": [], 997 | "source": [ 998 | "\n", 999 | "class Configs:\n", 1000 | " def __init__(self):\n", 1001 | " self.enc_in = 12\n", 1002 | " self.seq_len = 336\n", 1003 | " self.pred_len = 96\n", 1004 | " self.e_layers = 3\n", 1005 | " self.n_heads = 4\n", 1006 | " self.d_model = 16\n", 1007 | " self.d_ff = 128\n", 1008 | " self.dropout = 0.4\n", 1009 | " self.fc_dropout = 0.\n", 1010 | " self.head_dropout = 0.\n", 1011 | " self.individual = True\n", 1012 | " self.patch_len = 24\n", 1013 | " self.stride = 2\n", 1014 | " self.padding_patch = True\n", 1015 | " self.revin = True\n", 1016 | " self.affine = False\n", 1017 | " self.subtract_last = False\n", 1018 | " self.decomposition = False\n", 1019 | " self.kernel_size = 25\n", 1020 | " \n", 1021 | "my_configs = Configs()\n", 1022 | "model=Model(configs=my_configs)\n" 1023 | ] 1024 | }, 1025 | { 1026 | "cell_type": "code", 1027 | "execution_count": 220, 1028 | "id": "94fdb9c2", 1029 | "metadata": {}, 1030 | "outputs": [ 1031 | { 1032 | "data": { 1033 | "text/plain": [ 1034 | "torch.Size([1, 96, 12])" 1035 | ] 1036 | }, 1037 | "execution_count": 220, 1038 | "metadata": {}, 1039 | "output_type": "execute_result" 1040 | } 1041 | ], 1042 | "source": [ 1043 | "dat=torch.from_numpy(train_dataset[0][0])\n", 1044 | "dat=dat.unsqueeze(dim=0)\n", 1045 | "dat.shape\n", 1046 | "model(dat).shape" 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "code", 1051 | "execution_count": 222, 1052 | "id": "4e2b4427", 1053 | "metadata": { 1054 | "colab": { 1055 | "base_uri": "https://localhost:8080/" 1056 | }, 1057 | "id": "4e2b4427", 1058 | "outputId": 
"be9af307-b172-4a69-ffd9-61220cb69e4e" 1059 | }, 1060 | "outputs": [ 1061 | { 1062 | "data": { 1063 | "text/plain": [ 1064 | "0" 1065 | ] 1066 | }, 1067 | "execution_count": 222, 1068 | "metadata": {}, 1069 | "output_type": "execute_result" 1070 | } 1071 | ], 1072 | "source": [ 1073 | "import gc\n", 1074 | "torch.cuda.empty_cache()\n", 1075 | "gc.collect()" 1076 | ] 1077 | }, 1078 | { 1079 | "cell_type": "code", 1080 | "execution_count": 223, 1081 | "id": "032b97a3", 1082 | "metadata": { 1083 | "id": "032b97a3" 1084 | }, 1085 | "outputs": [ 1086 | { 1087 | "data": { 1088 | "text/plain": [ 1089 | "'\\n # plotting\\n \\n for i,j in zip(test_pred[0],test_y[0]):\\n pass\\n # print(i[ind].item())\\n test_pred_plot.append(i[ind].item())\\n # print(j[ind].item())\\n test_y_plot.append(j[ind].item())\\n sns.lineplot(x=list(range(len(test_pred_plot)))[:96], y=test_pred_plot[:96], label=\"predicted\")\\n sns.lineplot(x=list(range(len(test_y_plot)))[:96], y=test_y_plot[:96], label=\"actual\")\\n plt.legend(title=\\'Lines\\')\\n plt.show()\\n test_total_loss/= test_iter_count\\n print(\"MSE test loss: \",test_total_loss)\\n\\n '" 1090 | ] 1091 | }, 1092 | "execution_count": 223, 1093 | "metadata": {}, 1094 | "output_type": "execute_result" 1095 | } 1096 | ], 1097 | "source": [ 1098 | "class model_run:\n", 1099 | " def __init__(self,model=model):\n", 1100 | " self.patchtst_model = model\n", 1101 | " \n", 1102 | " def model_architecture(self):\n", 1103 | " n=0\n", 1104 | " for x in self.patchtst_model.state_dict():\n", 1105 | " n=n+1\n", 1106 | " print(x)\n", 1107 | " print(\"layers= \",n)\n", 1108 | "\n", 1109 | " def load_datasets(self,train_dataset=train_dataset,valid_dataset=valid_dataset,test_dataset=test_dataset):\n", 1110 | " self.train_dataset=train_dataset\n", 1111 | " self.valid_dataset=valid_dataset\n", 1112 | " self.test_dataset=test_dataset\n", 1113 | "\n", 1114 | "\n", 1115 | "\n", 1116 | " def 
model_hyperparameters(self,batch_size=16,lr=0.0001,epochs=20,cuda=True,Dataloader=DataLoader):\n", 1117 | " self.batch_size = batch_size\n", 1118 | " self.lr=lr\n", 1119 | " self.epochs=epochs\n", 1120 | "\n", 1121 | " self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)\n", 1122 | " self.test_dataloader = DataLoader(self.test_dataset, batch_size=1, shuffle=False)\n", 1123 | " self.valid_dataloader = DataLoader(self.valid_dataset, batch_size=batch_size, shuffle=True)\n", 1124 | "\n", 1125 | "\n", 1126 | " if cuda:\n", 1127 | " self.patchtst_model=self.patchtst_model.to(\"cuda\")\n", 1128 | "\n", 1129 | " self.optimizer = torch.optim.Adam(self.patchtst_model.parameters(), lr=self.lr)\n", 1130 | " self.loss = torch.nn.MSELoss()\n", 1131 | " self.loss_mae=torch.nn.L1Loss()\n", 1132 | "\n", 1133 | " def model_execute(self,):\n", 1134 | " self.load_datasets()\n", 1135 | " self.model_hyperparameters()\n", 1136 | "\n", 1137 | " for epoch in range(self.epochs):\n", 1138 | " iter_count = 0\n", 1139 | " total_loss = 0\n", 1140 | " train_steps=0\n", 1141 | " total_train_mae_loss=0\n", 1142 | "\n", 1143 | " for train_x, train_y in tqdm(self.train_dataloader):\n", 1144 | " train_x = train_x.to(\"cuda\")\n", 1145 | " train_y = train_y.to(\"cuda\")\n", 1146 | " \n", 1147 | " self.i=train_x\n", 1148 | " pred_y = self.patchtst_model(train_x)\n", 1149 | " # print(train_x.shape)\n", 1150 | " # print(train_y.shape)\n", 1151 | " # print(pred_y.shape)\n", 1152 | " train_mae_loss=self.loss_mae(pred_y, train_y)\n", 1153 | "\n", 1154 | " loss_t = self.loss(pred_y, train_y)\n", 1155 | " self.optimizer.zero_grad()\n", 1156 | " loss_t.backward()\n", 1157 | " self.optimizer.step()\n", 1158 | " total_loss += loss_t.item()\n", 1159 | " total_train_mae_loss += train_mae_loss.item()\n", 1160 | " iter_count += 1\n", 1161 | " train_steps += 1\n", 1162 | "\n", 1163 | " valid_iter_count = 0\n", 1164 | " valid_total_loss = 0\n", 1165 | " valid_total_mae = 0\n", 
1166 | " with torch.no_grad():\n", 1167 | " for valid_x, valid_y in self.valid_dataloader:\n", 1168 | " valid_x = valid_x.to(\"cuda\")\n", 1169 | " valid_y = valid_y.to(\"cuda\")\n", 1170 | " pred_y = self.patchtst_model(valid_x)\n", 1171 | " loss_v = self.loss(pred_y, valid_y)\n", 1172 | " valid_loss_mae=self.loss_mae(pred_y, valid_y)\n", 1173 | "\n", 1174 | " valid_total_loss += loss_v.item()\n", 1175 | " valid_iter_count += 1\n", 1176 | "\n", 1177 | " valid_total_mae+=valid_loss_mae.item()\n", 1178 | "\n", 1179 | " total_loss /= iter_count\n", 1180 | " total_train_mae_loss /= iter_count\n", 1181 | "\n", 1182 | " valid_total_loss /= valid_iter_count\n", 1183 | " valid_total_mae /= valid_iter_count\n", 1184 | " wandb.log({\"MSE Train\": total_loss, \"MSE Test\": valid_total_loss,\"MAE Train\": total_train_mae_loss,\"MAE Test\": valid_total_mae})\n", 1185 | "\n", 1186 | " print(\"epoch: {} MSE loss: {:.4f} MSE valid loss: {:.4f}\".format(epoch, total_loss, valid_total_loss))\n", 1187 | " print(\" MAE loss: {:.4f} MAE valid loss: {:.4f}\".format(total_train_mae_loss, valid_total_mae))\n", 1188 | "\n", 1189 | " def save_model(self,name):\n", 1190 | " torch.save(self.patchtst_model.state_dict(), name+'.pth')\n", 1191 | "\n", 1192 | "\n", 1193 | " def test_plots(self,ind=3,df_columns=None,scaler=None,column=None):\n", 1194 | " test_total_loss = 0\n", 1195 | " test_total_loss_mae=0\n", 1196 | " test_iter_count = 0\n", 1197 | " test_pred_plot=[]\n", 1198 | " test_y_plot=[]\n", 1199 | " \n", 1200 | " df_y_final=pd.DataFrame(columns=df_columns)\n", 1201 | " df_pred_final=pd.DataFrame(columns=df_columns)\n", 1202 | " \n", 1203 | " for test_x, test_y in self.test_dataloader:\n", 1204 | " test_x = test_x.to(\"cuda\")\n", 1205 | " test_y = test_y.to(\"cuda\")\n", 1206 | " test_pred=self.patchtst_model(test_x.to(\"cuda\"))\n", 1207 | " # print(test_pred.shape)\n", 1208 | " # print(test_y.shape)\n", 1209 | " loss_test=self.loss(test_pred, test_y)\n", 1210 | " test_total_loss += 
loss_test.item()\n", 1211 | " \n", 1212 | " loss_test_mae=self.loss_mae(test_pred, test_y)\n", 1213 | " test_total_loss_mae+=loss_test_mae.item()\n", 1214 | " \n", 1215 | " test_iter_count += 1\n", 1216 | " \n", 1217 | " df_y=pd.DataFrame(test_y[0].to(\"cpu\").numpy(),columns=df_columns)\n", 1218 | " df_pred=pd.DataFrame(test_pred[0].to(\"cpu\").detach().numpy(),columns=df_columns)\n", 1219 | " \n", 1220 | " df_y_final = pd.concat([df_y_final, df_y], axis=0)\n", 1221 | " df_pred_final = pd.concat([df_pred_final, df_pred], axis=0)\n", 1222 | " df_pred_final=scaler.inverse_transform(df_pred_final) \n", 1223 | " df_y_final=scaler.inverse_transform(df_y_final)\n", 1224 | " \n", 1225 | " df_pred_final=pd.DataFrame(df_pred_final,columns=df_columns)\n", 1226 | " df_y_final=pd.DataFrame(df_y_final,columns=df_columns)\n", 1227 | " \n", 1228 | " \n", 1229 | " sns.lineplot(x=list(range(len(df_y_final)))[:96], y=df_y_final[column][:96], label=\"GroundTruth\")\n", 1230 | " sns.lineplot(x=list(range(len(df_pred_final)))[:96], y=df_pred_final[column][:96], label=\"Prediction\")\n", 1231 | " plt.legend(title='Type')\n", 1232 | " plt.xlabel('Time Step')\n", 1233 | " plt.title(\"Predicted vs Actual\")\n", 1234 | " plt.show()\n", 1235 | " \n", 1236 | " \n", 1237 | " test_total_loss/= test_iter_count\n", 1238 | " test_total_loss_mae/=test_iter_count\n", 1239 | " print(\"MSE test loss: \",test_total_loss)\n", 1240 | " print(\"MAE test loss: \",test_total_loss_mae)\n", 1241 | " \n", 1242 | "'''\n", 1243 | " # plotting\n", 1244 | " \n", 1245 | " for i,j in zip(test_pred[0],test_y[0]):\n", 1246 | " pass\n", 1247 | " # print(i[ind].item())\n", 1248 | " test_pred_plot.append(i[ind].item())\n", 1249 | " # print(j[ind].item())\n", 1250 | " test_y_plot.append(j[ind].item())\n", 1251 | " sns.lineplot(x=list(range(len(test_pred_plot)))[:96], y=test_pred_plot[:96], label=\"predicted\")\n", 1252 | " sns.lineplot(x=list(range(len(test_y_plot)))[:96], y=test_y_plot[:96], label=\"actual\")\n", 1253 | 
" plt.legend(title='Lines')\n", 1254 | " plt.show()\n", 1255 | " test_total_loss/= test_iter_count\n", 1256 | " print(\"MSE test loss: \",test_total_loss)\n", 1257 | "\n", 1258 | " '''\n" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "code", 1263 | "execution_count": 224, 1264 | "id": "c4be2e6f", 1265 | "metadata": { 1266 | "id": "c4be2e6f" 1267 | }, 1268 | "outputs": [], 1269 | "source": [ 1270 | "x=model_run(model=model)" 1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "execution_count": 225, 1276 | "id": "PRq_t36mOl4T", 1277 | "metadata": { 1278 | "colab": { 1279 | "base_uri": "https://localhost:8080/", 1280 | "height": 1000 1281 | }, 1282 | "id": "PRq_t36mOl4T", 1283 | "outputId": "36d03252-4ea5-48cc-b8ce-66ec139f0b65", 1284 | "scrolled": true 1285 | }, 1286 | "outputs": [ 1287 | { 1288 | "name": "stderr", 1289 | "output_type": "stream", 1290 | "text": [ 1291 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:00<00:00, 24.77it/s]\n" 1292 | ] 1293 | }, 1294 | { 1295 | "name": "stdout", 1296 | "output_type": "stream", 1297 | "text": [ 1298 | "epoch: 0 MSE loss: 0.5659 MSE valid loss: 0.6731\n", 1299 | " MAE loss: 0.5382 MAE valid loss: 0.6013\n" 1300 | ] 1301 | }, 1302 | { 1303 | "name": "stderr", 1304 | "output_type": "stream", 1305 | "text": [ 1306 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:00<00:00, 24.96it/s]\n" 1307 | ] 1308 | }, 1309 | { 1310 | "name": "stdout", 1311 | "output_type": "stream", 1312 | "text": [ 1313 | "epoch: 1 MSE loss: 0.5202 MSE valid loss: 0.6677\n", 1314 | " MAE loss: 0.5102 MAE valid loss: 0.5951\n" 1315 | ] 1316 | }, 1317 | { 1318 | "name": "stderr", 1319 | "output_type": "stream", 1320 | "text": [ 1321 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:00<00:00, 24.94it/s]\n" 1322 | ] 1323 | }, 1324 | { 1325 | "name": "stdout", 1326 | "output_type": "stream", 
1327 | "text": [ 1328 | "epoch: 2 MSE loss: 0.5089 MSE valid loss: 0.6622\n", 1329 | " MAE loss: 0.5036 MAE valid loss: 0.5908\n" 1330 | ] 1331 | }, 1332 | { 1333 | "name": "stderr", 1334 | "output_type": "stream", 1335 | "text": [ 1336 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:00<00:00, 24.92it/s]\n" 1337 | ] 1338 | }, 1339 | { 1340 | "name": "stdout", 1341 | "output_type": "stream", 1342 | "text": [ 1343 | "epoch: 3 MSE loss: 0.5026 MSE valid loss: 0.6561\n", 1344 | " MAE loss: 0.5002 MAE valid loss: 0.5874\n" 1345 | ] 1346 | }, 1347 | { 1348 | "name": "stderr", 1349 | "output_type": "stream", 1350 | "text": [ 1351 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:00<00:00, 24.74it/s]\n" 1352 | ] 1353 | }, 1354 | { 1355 | "name": "stdout", 1356 | "output_type": "stream", 1357 | "text": [ 1358 | "epoch: 4 MSE loss: 0.4972 MSE valid loss: 0.6522\n", 1359 | " MAE loss: 0.4972 MAE valid loss: 0.5857\n" 1360 | ] 1361 | }, 1362 | { 1363 | "name": "stderr", 1364 | "output_type": "stream", 1365 | "text": [ 1366 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.48it/s]\n" 1367 | ] 1368 | }, 1369 | { 1370 | "name": "stdout", 1371 | "output_type": "stream", 1372 | "text": [ 1373 | "epoch: 5 MSE loss: 0.4929 MSE valid loss: 0.6491\n", 1374 | " MAE loss: 0.4950 MAE valid loss: 0.5856\n" 1375 | ] 1376 | }, 1377 | { 1378 | "name": "stderr", 1379 | "output_type": "stream", 1380 | "text": [ 1381 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.43it/s]\n" 1382 | ] 1383 | }, 1384 | { 1385 | "name": "stdout", 1386 | "output_type": "stream", 1387 | "text": [ 1388 | "epoch: 6 MSE loss: 0.4896 MSE valid loss: 0.6478\n", 1389 | " MAE loss: 0.4933 MAE valid loss: 0.5828\n" 1390 | ] 1391 | }, 1392 | { 1393 | "name": "stderr", 1394 | "output_type": 
"stream", 1395 | "text": [ 1396 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.61it/s]\n" 1397 | ] 1398 | }, 1399 | { 1400 | "name": "stdout", 1401 | "output_type": "stream", 1402 | "text": [ 1403 | "epoch: 7 MSE loss: 0.4868 MSE valid loss: 0.6556\n", 1404 | " MAE loss: 0.4918 MAE valid loss: 0.5835\n" 1405 | ] 1406 | }, 1407 | { 1408 | "name": "stderr", 1409 | "output_type": "stream", 1410 | "text": [ 1411 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.34it/s]\n" 1412 | ] 1413 | }, 1414 | { 1415 | "name": "stdout", 1416 | "output_type": "stream", 1417 | "text": [ 1418 | "epoch: 8 MSE loss: 0.4842 MSE valid loss: 0.6545\n", 1419 | " MAE loss: 0.4904 MAE valid loss: 0.5844\n" 1420 | ] 1421 | }, 1422 | { 1423 | "name": "stderr", 1424 | "output_type": "stream", 1425 | "text": [ 1426 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.47it/s]\n" 1427 | ] 1428 | }, 1429 | { 1430 | "name": "stdout", 1431 | "output_type": "stream", 1432 | "text": [ 1433 | "epoch: 9 MSE loss: 0.4818 MSE valid loss: 0.6574\n", 1434 | " MAE loss: 0.4892 MAE valid loss: 0.5859\n" 1435 | ] 1436 | }, 1437 | { 1438 | "name": "stderr", 1439 | "output_type": "stream", 1440 | "text": [ 1441 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:02<00:00, 24.06it/s]\n" 1442 | ] 1443 | }, 1444 | { 1445 | "name": "stdout", 1446 | "output_type": "stream", 1447 | "text": [ 1448 | "epoch: 10 MSE loss: 0.4795 MSE valid loss: 0.6614\n", 1449 | " MAE loss: 0.4882 MAE valid loss: 0.5871\n" 1450 | ] 1451 | }, 1452 | { 1453 | "name": "stderr", 1454 | "output_type": "stream", 1455 | "text": [ 1456 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:03<00:00, 23.64it/s]\n" 1457 | ] 1458 | }, 1459 | { 1460 | "name": 
"stdout", 1461 | "output_type": "stream", 1462 | "text": [ 1463 | "epoch: 11 MSE loss: 0.4763 MSE valid loss: 0.6616\n", 1464 | " MAE loss: 0.4865 MAE valid loss: 0.5881\n" 1465 | ] 1466 | }, 1467 | { 1468 | "name": "stderr", 1469 | "output_type": "stream", 1470 | "text": [ 1471 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:02<00:00, 24.26it/s]\n" 1472 | ] 1473 | }, 1474 | { 1475 | "name": "stdout", 1476 | "output_type": "stream", 1477 | "text": [ 1478 | "epoch: 12 MSE loss: 0.4738 MSE valid loss: 0.6610\n", 1479 | " MAE loss: 0.4851 MAE valid loss: 0.5852\n" 1480 | ] 1481 | }, 1482 | { 1483 | "name": "stderr", 1484 | "output_type": "stream", 1485 | "text": [ 1486 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.56it/s]\n" 1487 | ] 1488 | }, 1489 | { 1490 | "name": "stdout", 1491 | "output_type": "stream", 1492 | "text": [ 1493 | "epoch: 13 MSE loss: 0.4724 MSE valid loss: 0.6614\n", 1494 | " MAE loss: 0.4843 MAE valid loss: 0.5864\n" 1495 | ] 1496 | }, 1497 | { 1498 | "name": "stderr", 1499 | "output_type": "stream", 1500 | "text": [ 1501 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.51it/s]\n" 1502 | ] 1503 | }, 1504 | { 1505 | "name": "stdout", 1506 | "output_type": "stream", 1507 | "text": [ 1508 | "epoch: 14 MSE loss: 0.4708 MSE valid loss: 0.6660\n", 1509 | " MAE loss: 0.4835 MAE valid loss: 0.5877\n" 1510 | ] 1511 | }, 1512 | { 1513 | "name": "stderr", 1514 | "output_type": "stream", 1515 | "text": [ 1516 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.48it/s]\n" 1517 | ] 1518 | }, 1519 | { 1520 | "name": "stdout", 1521 | "output_type": "stream", 1522 | "text": [ 1523 | "epoch: 15 MSE loss: 0.4689 MSE valid loss: 0.6636\n", 1524 | " MAE loss: 0.4825 MAE valid loss: 0.5866\n" 1525 | ] 1526 | }, 1527 | { 
1528 | "name": "stderr", 1529 | "output_type": "stream", 1530 | "text": [ 1531 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.50it/s]\n" 1532 | ] 1533 | }, 1534 | { 1535 | "name": "stdout", 1536 | "output_type": "stream", 1537 | "text": [ 1538 | "epoch: 16 MSE loss: 0.4681 MSE valid loss: 0.6668\n", 1539 | " MAE loss: 0.4821 MAE valid loss: 0.5870\n" 1540 | ] 1541 | }, 1542 | { 1543 | "name": "stderr", 1544 | "output_type": "stream", 1545 | "text": [ 1546 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.45it/s]\n" 1547 | ] 1548 | }, 1549 | { 1550 | "name": "stdout", 1551 | "output_type": "stream", 1552 | "text": [ 1553 | "epoch: 17 MSE loss: 0.4667 MSE valid loss: 0.6606\n", 1554 | " MAE loss: 0.4813 MAE valid loss: 0.5838\n" 1555 | ] 1556 | }, 1557 | { 1558 | "name": "stderr", 1559 | "output_type": "stream", 1560 | "text": [ 1561 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.54it/s]\n" 1562 | ] 1563 | }, 1564 | { 1565 | "name": "stdout", 1566 | "output_type": "stream", 1567 | "text": [ 1568 | "epoch: 18 MSE loss: 0.4657 MSE valid loss: 0.6631\n", 1569 | " MAE loss: 0.4809 MAE valid loss: 0.5850\n" 1570 | ] 1571 | }, 1572 | { 1573 | "name": "stderr", 1574 | "output_type": "stream", 1575 | "text": [ 1576 | "100%|██████████████████████████████████████████████████████████████████████████████| 1508/1508 [01:01<00:00, 24.33it/s]\n" 1577 | ] 1578 | }, 1579 | { 1580 | "name": "stdout", 1581 | "output_type": "stream", 1582 | "text": [ 1583 | "epoch: 19 MSE loss: 0.4644 MSE valid loss: 0.6642\n", 1584 | " MAE loss: 0.4801 MAE valid loss: 0.5872\n" 1585 | ] 1586 | } 1587 | ], 1588 | "source": [ 1589 | "x.model_execute()" 1590 | ] 1591 | }, 1592 | { 1593 | "cell_type": "code", 1594 | "execution_count": 231, 1595 | "id": "24989d45", 1596 | "metadata": {}, 1597 | "outputs": [ 
1598 | { 1599 | "data": { 1600 | "image/png": "", 1601 | "text/plain": [ 1602 | "
" 1603 | ] 1604 | }, 1605 | "metadata": {}, 1606 | "output_type": "display_data" 1607 | }, 1608 | { 1609 | "name": "stdout", 1610 | "output_type": "stream", 1611 | "text": [ 1612 | "MSE test loss: 0.485170769399389\n", 1613 | "MAE test loss: 0.48576957206008886\n" 1614 | ] 1615 | } 1616 | ], 1617 | "source": [ 1618 | "x.test_plots(df_columns=df.columns[1:],scaler=scaler,column='WindDirection')" 1619 | ] 1620 | }, 1621 | { 1622 | "cell_type": "code", 1623 | "execution_count": null, 1624 | "id": "c137bc01", 1625 | "metadata": {}, 1626 | "outputs": [], 1627 | "source": [ 1628 | "'''\n", 1629 | "dat=torch.from_numpy(train_dataset[0][0])\n", 1630 | "dat=dat.unsqueeze(dim=0)\n", 1631 | "dat.shape\n", 1632 | "model(dat).shape\n", 1633 | "'''" 1634 | ] 1635 | }, 1636 | { 1637 | "cell_type": "code", 1638 | "execution_count": null, 1639 | "id": "1bfe07c7", 1640 | "metadata": {}, 1641 | "outputs": [], 1642 | "source": [] 1643 | } 1644 | ], 1645 | "metadata": { 1646 | "accelerator": "GPU", 1647 | "colab": { 1648 | "gpuType": "V100", 1649 | "machine_shape": "hm", 1650 | "provenance": [] 1651 | }, 1652 | "kernelspec": { 1653 | "display_name": "gpu1", 1654 | "language": "python", 1655 | "name": "gpu1" 1656 | }, 1657 | "language_info": { 1658 | "codemirror_mode": { 1659 | "name": "ipython", 1660 | "version": 3 1661 | }, 1662 | "file_extension": ".py", 1663 | "mimetype": "text/x-python", 1664 | "name": "python", 1665 | "nbconvert_exporter": "python", 1666 | "pygments_lexer": "ipython3", 1667 | "version": "3.9.18" 1668 | } 1669 | }, 1670 | "nbformat": 4, 1671 | "nbformat_minor": 5 1672 | } 1673 | --------------------------------------------------------------------------------