├── .gitignore ├── ConvTSMixer-elec.ipynb ├── ConvTSMixer-exchange.ipynb ├── ConvTSMixer-hyperparameter_tuning.ipynb ├── ConvTSMixer-solar.ipynb ├── ConvTSMixer-traffic.ipynb ├── ConvTSMixer ├── __init__.py ├── estimator.py ├── lightning_module.py └── module.py ├── LICENSE ├── LagTST-hyperparameter_tuning.ipynb ├── LagTST-uni-solar.ipynb ├── LagTST ├── __init__.py ├── estimator.py ├── lightning_module.py └── module.py ├── Linear-hyperparameter_tuning.ipynb ├── Linear-uni-solar.ipynb ├── Linear ├── __init__.py ├── estimator.py ├── lightning_module.py └── module.py ├── MlpTSMixer-elec.ipynb ├── MlpTSMixer-exchange.ipynb ├── MlpTSMixer-hyperparameter_tuning.ipynb ├── MlpTSMixer-solar.ipynb ├── MlpTSMixer-traffic.ipynb ├── MlpTSMixer ├── __init__.py ├── estimator.py ├── lightning_module.py └── module.py ├── PatchTST ├── __init__.py ├── estimator.py ├── lightning_module.py └── module.py ├── README.md ├── TSM-hyperparameter_tuning.ipynb ├── TSMixer-elec.ipynb ├── TSMixer-exchange.ipynb ├── TSMixer-solar.ipynb ├── TSMixer ├── __init__.py ├── estimator.py ├── lightning_module.py ├── module.py ├── module_conv.py └── version_old │ ├── __init__.py │ ├── estimator.py │ ├── lightning_module.py │ ├── model.py │ ├── model_auxiliary.py │ └── module.py ├── Transformer-MV-Solar.ipynb ├── Transformer-exchange.ipynb ├── TsT-elec.ipynb ├── TsT-exchange.ipynb ├── TsT-hyperparameter_tuning.ipynb ├── TsT-solar.ipynb ├── TsT-traffic.ipynb ├── TsT ├── __init__.py ├── estimator.py ├── lightning_module.py └── module.py ├── deepVAR-Solar.ipynb ├── deepVAR-exchange.ipynb ├── examples ├── TS-Mixer Auxiliary Example.ipynb ├── TS-Mixer Base Example.ipynb └── __init__.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | 
downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # logs 132 | lightning_logs 133 | 134 | # vscode 135 | .vscode/ -------------------------------------------------------------------------------- /ConvTSMixer-hyperparameter_tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "86ff733d-8ea7-4d77-8d2a-cd329ab8f385", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "import torch\n", 17 | "\n", 18 | "from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n", 19 | "from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n", 20 | "from gluonts.evaluation.backtest import make_evaluation_predictions\n", 21 | "from gluonts.evaluation import MultivariateEvaluator\n", 22 | "\n", 23 | "# from pts.modules import StudentTOutput\n", 24 | "\n", 25 | "from ConvTSMixer import ConvTSMixerEstimator\n", 26 | "import random\n", 27 | "import numpy as np\n", 28 | "import time\n", 29 | "import optuna\n", 30 | "from optuna.samplers import TPESampler" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "944676b0-2a9f-4301-92e6-f382f5693639", 37 | "metadata": { 38 | "tags": [] 39 | }, 40 | "outputs": [], 41 | 
"source": [ 42 | "class ConvTSMixerObjective:\n", 43 | " def __init__(\n", 44 | " self,\n", 45 | " dataset,\n", 46 | " train_grouper,\n", 47 | " test_grouper,\n", 48 | " metric_type=\"m_sum_mean_wQuantileLoss\",\n", 49 | " ):\n", 50 | " self.metric_type = metric_type\n", 51 | " self.dataset = dataset\n", 52 | " self.dataset_train = train_grouper(self.dataset.train)\n", 53 | " self.dataset_test = test_grouper(self.dataset.test)\n", 54 | "\n", 55 | " def get_params(self, trial) -> dict:\n", 56 | " return {\n", 57 | " \"context_length\": trial.suggest_int(\n", 58 | " \"context_length\",\n", 59 | " dataset.metadata.prediction_length,\n", 60 | " dataset.metadata.prediction_length * 5,\n", 61 | " 1,\n", 62 | " ),\n", 63 | " \"batch_size\": trial.suggest_int(\"batch_size\", 32, 256, 32),\n", 64 | " \"depth\": trial.suggest_int(\"depth\", 2, 16, 4),\n", 65 | " \"dim\": trial.suggest_int(\"dim\", 16, 64, 16),\n", 66 | " \"patch_size\": trial.suggest_int(\"patch_size\", 2, 16, 4),\n", 67 | " \"kernel_size\": trial.suggest_int(\"kernel_size\", 2, 8, 2),\n", 68 | " }\n", 69 | "\n", 70 | " def __call__(self, trial):\n", 71 | " params = self.get_params(trial)\n", 72 | " estimator = ConvTSMixerEstimator(\n", 73 | " # distr_output=StudentTOutput(dim=int(dataset.metadata.feat_static_cat[0].cardinality)),\n", 74 | " input_size=int(self.dataset.metadata.feat_static_cat[0].cardinality),\n", 75 | " prediction_length=self.dataset.metadata.prediction_length,\n", 76 | " context_length=params[\"context_length\"],\n", 77 | " freq=self.dataset.metadata.freq,\n", 78 | " scaling=\"std\",\n", 79 | " depth=params[\"depth\"],\n", 80 | " patch_size=(params[\"patch_size\"], params[\"patch_size\"]),\n", 81 | " kernel_size=params[\"kernel_size\"],\n", 82 | " dim=params[\"dim\"],\n", 83 | " batch_size=params[\"batch_size\"],\n", 84 | " num_batches_per_epoch=100,\n", 85 | " patch_reverse_mapping_layer=\"mlp\",\n", 86 | " trainer_kwargs=dict(accelerator=\"cuda\", max_epochs=30),\n", 87 | " )\n", 88 | 
"\n", 89 | " predictor = estimator.train(\n", 90 | " training_data=self.dataset_train, num_workers=8, shuffle_buffer_length=1024\n", 91 | " )\n", 92 | "\n", 93 | " forecast_it, ts_it = make_evaluation_predictions(\n", 94 | " dataset=self.dataset_test, predictor=predictor, num_samples=100\n", 95 | " )\n", 96 | " forecasts = list(forecast_it)\n", 97 | " tss = list(ts_it)\n", 98 | " evaluator = MultivariateEvaluator(\n", 99 | " quantiles=(np.arange(20) / 20.0)[1:], target_agg_funcs={\"sum\": np.sum}\n", 100 | " )\n", 101 | " agg_metrics, _ = evaluator(iter(tss), iter(forecasts))\n", 102 | " return agg_metrics[self.metric_type]" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "d6c1da22-755e-44d5-86b6-10465d8c25e8", 109 | "metadata": { 110 | "tags": [] 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "dataset = get_dataset(\"exchange_rate_nips\", regenerate=False)\n", 115 | "train_grouper = MultivariateGrouper(\n", 116 | " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)\n", 117 | ")\n", 118 | "\n", 119 | "test_grouper = MultivariateGrouper(\n", 120 | " num_test_dates=int(len(dataset.test) / len(dataset.train)),\n", 121 | " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n", 122 | ")\n", 123 | "dataset_train = train_grouper(dataset.train)\n", 124 | "dataset_test = test_grouper(dataset.test)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "9ded5c21-ea78-4f99-98e3-1573f5abfbf2", 131 | "metadata": { 132 | "tags": [] 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "seed = 42\n", 137 | "random.seed(seed)\n", 138 | "torch.manual_seed(seed)\n", 139 | "start_time = time.time()\n", 140 | "sampler = TPESampler(seed=seed)\n", 141 | "study = optuna.create_study(sampler=sampler, direction=\"minimize\")\n", 142 | "study.optimize(ConvTSMixerObjective(dataset, train_grouper, test_grouper), n_trials=10)\n", 143 | "\n", 144 | "print(\"Number of 
finished trials: {}\".format(len(study.trials)))\n", 145 | "\n", 146 | "print(\"Best trial:\")\n", 147 | "trial = study.best_trial\n", 148 | "\n", 149 | "print(\" Value: {}\".format(trial.value))\n", 150 | "print(\" Params: \")\n", 151 | "for key, value in trial.params.items():\n", 152 | " print(\" {}: {}\".format(key, value))\n", 153 | "print(time.time() - start_time)" 154 | ] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3 (ipykernel)", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.9.0" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 5 178 | } 179 | -------------------------------------------------------------------------------- /ConvTSMixer/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import ConvTSMixerModel 2 | from .lightning_module import ConvTSMixerLightningModule 3 | from .estimator import ConvTSMixerEstimator 4 | 5 | __all__ = ["ConvTSMixerModel", "ConvTSMixerLightningModule", "ConvTSMixerEstimator"] 6 | -------------------------------------------------------------------------------- /ConvTSMixer/lightning_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. 
This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import lightning.pytorch as pl 15 | import torch 16 | 17 | from gluonts.core.component import validated 18 | 19 | from .module import ConvTSMixerModel 20 | 21 | 22 | class ConvTSMixerLightningModule(pl.LightningModule): 23 | """ 24 | A ``pl.LightningModule`` class that can be used to train a 25 | ``ConvTSMixerModel`` with PyTorch Lightning. 26 | 27 | This is a thin layer around a (wrapped) ``ConvTSMixerModel`` object, 28 | that exposes the methods to evaluate training and validation loss. 29 | 30 | Parameters 31 | ---------- 32 | model_kwargs 33 | Keyword arguments used to construct the ``ConvTSMixerModel`` to be trained. 34 | lr 35 | Learning rate, default: ``1e-3``. 36 | weight_decay 37 | Weight decay regularization parameter, default: ``1e-8``. 38 | """ 39 | 40 | @validated() 41 | def __init__( 42 | self, 43 | model_kwargs: dict, 44 | lr: float = 1e-3, 45 | weight_decay: float = 1e-8, 46 | ): 47 | super().__init__() 48 | self.save_hyperparameters() 49 | self.model = ConvTSMixerModel(**model_kwargs) 50 | self.lr = lr 51 | self.weight_decay = weight_decay 52 | 53 | def forward(self, *args, **kwargs): 54 | distr_args, loc, scale = self.model.forward(*args, **kwargs) 55 | distr = self.model.distr_output.distribution(distr_args, loc, scale) 56 | return distr.sample((self.model.num_parallel_samples,)).reshape( 57 | -1, 58 | self.model.num_parallel_samples, 59 | self.model.prediction_length, 60 | self.model.input_size, 61 | ) 62 | 63 | def _compute_loss(self, batch): 64 | past_target = batch["past_target"] 65 | past_observed_values = batch["past_observed_values"] 66 | target = batch["future_target"] 67 | observed_target = batch["future_observed_values"] 68 | 69 | assert past_target.shape[1] == self.model.context_length 70 | assert target.shape[1] == 
self.model.prediction_length 71 | 72 | distr_args, loc, scale = self.model( 73 | past_target=past_target, 74 | past_observed_values=past_observed_values, 75 | past_time_feat=batch["past_time_feat"], 76 | future_time_feat=batch["future_time_feat"], 77 | ) 78 | loss_values = self.model.distr_output.loss( 79 | target=target, distr_args=distr_args, loc=loc, scale=scale 80 | ) 81 | return (loss_values * observed_target).sum() / torch.maximum( 82 | torch.tensor(1.0), observed_target.sum() 83 | ) 84 | 85 | # distr = self.model.distr_output.distribution(distr_args, loc, scale) 86 | 87 | # return (self.loss(distr, target) * observed_target).sum() / torch.maximum( 88 | # torch.tensor(1.0), observed_target.sum() 89 | # ) 90 | 91 | def training_step(self, batch, batch_idx: int): # type: ignore 92 | """ 93 | Execute training step. 94 | """ 95 | train_loss = self._compute_loss(batch) 96 | self.log( 97 | "train_loss", 98 | train_loss, 99 | on_epoch=True, 100 | on_step=False, 101 | prog_bar=True, 102 | ) 103 | return train_loss 104 | 105 | def validation_step(self, batch, batch_idx: int): # type: ignore 106 | """ 107 | Execute validation step. 108 | """ 109 | val_loss = self._compute_loss(batch) 110 | self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) 111 | return val_loss 112 | 113 | def configure_optimizers(self): 114 | """ 115 | Returns the optimizer to use. 116 | """ 117 | return torch.optim.Adam( 118 | self.model.parameters(), 119 | lr=self.lr, 120 | weight_decay=self.weight_decay, 121 | ) 122 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LagTST-hyperparameter_tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "86ff733d-8ea7-4d77-8d2a-cd329ab8f385", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "import torch\n", 17 | "\n", 18 | "from gluonts.dataset.repository.datasets import get_dataset\n", 19 | "from gluonts.dataset.common import ListDataset\n", 20 | "from gluonts.evaluation import make_evaluation_predictions, Evaluator\n", 21 | "\n", 22 | "# from gluonts.torch.model.lag_tst.estimator import LagTSTEstimator\n", 23 | "from gluonts.torch.distributions import NegativeBinomialOutput\n", 24 | "from gluonts.torch.modules.loss import NegativeLogLikelihood\n", 25 | "\n", 26 | "from LagTST import LagTSTEstimator\n", 27 | "\n", 28 | "# from pts.modules import StudentTOutput\n", 29 | "\n", 30 | "import 
random\n", 31 | "import numpy as np\n", 32 | "import time\n", 33 | "import optuna\n", 34 | "from optuna.samplers import TPESampler" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "id": "944676b0-2a9f-4301-92e6-f382f5693639", 41 | "metadata": { 42 | "tags": [] 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "class LagTSTObjective:\n", 47 | " def __init__(self, dataset, metric_type=\"mean_wQuantileLoss\"):\n", 48 | " self.dataset = dataset\n", 49 | " self.metric_type = metric_type\n", 50 | "\n", 51 | " def get_params(self, trial) -> dict:\n", 52 | " return {\n", 53 | " \"context_length\": trial.suggest_int(\n", 54 | " \"context_length\",\n", 55 | " dataset.metadata.prediction_length,\n", 56 | " dataset.metadata.prediction_length * 5,\n", 57 | " 1,\n", 58 | " ),\n", 59 | " \"batch_size\": trial.suggest_int(\"batch_size\", 32, 256, 32),\n", 60 | " \"d_model\": trial.suggest_int(\"d_model\", 16, 64, 16),\n", 61 | " \"dim\": trial.suggest_int(\"dim\", 16, 64, 16),\n", 62 | " \"patch_size\": trial.suggest_int(\"patch_size\", 2, 16, 4),\n", 63 | " \"kernel_size\": trial.suggest_int(\"kernel_size\", 9, 18, 3),\n", 64 | " }\n", 65 | "\n", 66 | " def __call__(self, trial):\n", 67 | " params = self.get_params(trial)\n", 68 | "\n", 69 | " estimator = LagTSTEstimator(\n", 70 | " prediction_length=self.dataset.metadata.prediction_length,\n", 71 | " context_length=params[\"context_length\"],\n", 72 | " freq=dataset.metadata.freq,\n", 73 | " scaling=\"std\",\n", 74 | " # distr_output=NegativeBinomialOutput(),\n", 75 | " # loss=NegativeLogLikelihood(beta=0.2),\n", 76 | " d_model=params[\"d_model\"],\n", 77 | " dim_feedforward=params[\"dim\"],\n", 78 | " batch_size=params[\"batch_size\"],\n", 79 | " patch_reverse_mapping_layer=\"mlp\",\n", 80 | " num_batches_per_epoch=100,\n", 81 | " trainer_kwargs=dict(accelerator=\"gpu\", max_epochs=30),\n", 82 | " )\n", 83 | " predictor = estimator.train(\n", 84 | " training_data=self.dataset.train,\n", 85 | " 
cache_data=True,\n", 86 | " shuffle_buffer_length=1024,\n", 87 | " validation_data=self.dataset.test,\n", 88 | " )\n", 89 | "\n", 90 | " forecast_it, ts_it = make_evaluation_predictions(\n", 91 | " dataset=dataset.test,\n", 92 | " predictor=predictor,\n", 93 | " )\n", 94 | " forecasts = list(forecast_it)\n", 95 | " # if layer == layers[0]:\n", 96 | " tss = list(ts_it)\n", 97 | " evaluator = Evaluator()\n", 98 | " agg_metrics, _ = evaluator(iter(tss), iter(forecasts))\n", 99 | " return agg_metrics[self.metric_type]" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "id": "d6c1da22-755e-44d5-86b6-10465d8c25e8", 106 | "metadata": { 107 | "tags": [] 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "dataset = get_dataset(\n", 112 | " \"solar-energy\", regenerate=False\n", 113 | ") # dataset = get_dataset(\"electricity\")" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "9ded5c21-ea78-4f99-98e3-1573f5abfbf2", 120 | "metadata": { 121 | "tags": [], 122 | "pycharm": { 123 | "is_executing": true 124 | } 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "seed = 42\n", 129 | "random.seed(seed)\n", 130 | "torch.manual_seed(seed)\n", 131 | "start_time = time.time()\n", 132 | "sampler = TPESampler(seed=seed)\n", 133 | "study = optuna.create_study(sampler=sampler, direction=\"minimize\")\n", 134 | "study.optimize(LagTSTObjective(dataset), n_trials=10)\n", 135 | "\n", 136 | "print(\"Number of finished trials: {}\".format(len(study.trials)))\n", 137 | "\n", 138 | "print(\"Best trial:\")\n", 139 | "trial = study.best_trial\n", 140 | "\n", 141 | "print(\" Value: {}\".format(trial.value))\n", 142 | "print(\" Params: \")\n", 143 | "for key, value in trial.params.items():\n", 144 | " print(\" {}: {}\".format(key, value))\n", 145 | "print(time.time() - start_time)" 146 | ] 147 | } 148 | ], 149 | "metadata": { 150 | "kernelspec": { 151 | "display_name": "py38", 152 | "language": "python", 153 | 
"name": "py38" 154 | }, 155 | "language_info": { 156 | "codemirror_mode": { 157 | "name": "ipython", 158 | "version": 3 159 | }, 160 | "file_extension": ".py", 161 | "mimetype": "text/x-python", 162 | "name": "python", 163 | "nbconvert_exporter": "python", 164 | "pygments_lexer": "ipython3", 165 | "version": "3.8.10" 166 | } 167 | }, 168 | "nbformat": 4, 169 | "nbformat_minor": 5 170 | } 171 | -------------------------------------------------------------------------------- /LagTST/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from .module import LagTSTModel 15 | from .lightning_module import LagTSTLightningModule 16 | from .estimator import LagTSTEstimator 17 | 18 | __all__ = ["LagTSTModel", "LagTSTLightningModule", "LagTSTEstimator"] 19 | -------------------------------------------------------------------------------- /LagTST/estimator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. 
This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from typing import Optional, Iterable, Dict, Any, List 15 | 16 | import torch 17 | import pytorch_lightning as pl 18 | 19 | from gluonts.core.component import validated 20 | from gluonts.dataset.common import Dataset 21 | from gluonts.dataset.field_names import FieldName 22 | from gluonts.dataset.loader import as_stacked_batches 23 | from gluonts.itertools import Cyclic 24 | from gluonts.model.forecast_generator import DistributionForecastGenerator 25 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood 26 | from gluonts.transform import ( 27 | Transformation, 28 | AddObservedValuesIndicator, 29 | InstanceSampler, 30 | InstanceSplitter, 31 | ValidationSplitSampler, 32 | TestSplitSampler, 33 | ExpectedNumInstanceSampler, 34 | SelectFields, 35 | ) 36 | from gluonts.torch.model.estimator import PyTorchLightningEstimator 37 | from gluonts.torch.model.predictor import PyTorchPredictor 38 | from gluonts.torch.distributions import DistributionOutput, StudentTOutput 39 | 40 | from .lightning_module import LagTSTLightningModule 41 | 42 | PREDICTION_INPUT_NAMES = ["past_target", "past_observed_values"] 43 | 44 | TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ 45 | "future_target", 46 | "future_observed_values", 47 | ] 48 | 49 | 50 | class LagTSTEstimator(PyTorchLightningEstimator): 51 | """ 52 | An estimator training the LagTST model for forecasting. 53 | 54 | This class is uses the model defined in ``SimpleFeedForwardModel``, 55 | and wraps it into a ``LagTSTLightningModule`` for training 56 | purposes: training is performed using PyTorch Lightning's ``pl.Trainer`` 57 | class. 58 | 59 | Parameters 60 | ---------- 61 | freq 62 | Frequency of the data to train on and predict. 
63 | prediction_length 64 | Length of the prediction horizon. 65 | context_length 66 | Number of time steps prior to prediction time that the model 67 | takes as inputs (default: ``10 * prediction_length``). 68 | lags_seq 69 | Indices of the lagged target values to use as inputs of the RNN 70 | (default: None, in which case these are automatically determined 71 | based on freq). 72 | d_model 73 | Size of hidden layers in the Transformer encoder. 74 | nhead 75 | Number of attention heads in the Transformer encoder. 76 | dim_feedforward 77 | Size of hidden layers in the Transformer encoder. 78 | dropout 79 | Dropout probability in the Transformer encoder. 80 | activation 81 | Activation function in the Transformer encoder. 82 | norm_first 83 | Whether to apply normalization before or after the attention. 84 | num_encoder_layers 85 | Number of layers in the Transformer encoder. 86 | lr 87 | Learning rate (default: ``1e-3``). 88 | weight_decay 89 | Weight decay regularization parameter (default: ``1e-8``). 90 | scaling 91 | Scaling parameter can be "mean", "std" or None. 92 | distr_output 93 | Distribution to use to evaluate observations and sample predictions 94 | (default: StudentTOutput()). 95 | loss 96 | Loss to be optimized during training 97 | (default: ``NegativeLogLikelihood()``). 98 | batch_size 99 | The size of the batches to be used for training (default: 32). 100 | num_batches_per_epoch 101 | Number of batches to be processed in each training epoch 102 | (default: 50). 103 | trainer_kwargs 104 | Additional arguments to provide to ``pl.Trainer`` for construction. 105 | train_sampler 106 | Controls the sampling of windows during training. 107 | validation_sampler 108 | Controls the sampling of windows during validation. 
109 | """ 110 | 111 | @validated() 112 | def __init__( 113 | self, 114 | freq: str, 115 | prediction_length: int, 116 | context_length: Optional[int] = None, 117 | d_model: int = 32, 118 | nhead: int = 4, 119 | dim_feedforward: int = 128, 120 | lags_seq: Optional[List[int]] = None, 121 | dropout: float = 0.1, 122 | activation: str = "relu", 123 | norm_first: bool = False, 124 | num_encoder_layers: int = 2, 125 | lr: float = 1e-3, 126 | weight_decay: float = 1e-8, 127 | scaling: Optional[str] = "mean", 128 | distr_output: DistributionOutput = StudentTOutput(), 129 | loss: DistributionLoss = NegativeLogLikelihood(), 130 | batch_size: int = 32, 131 | num_batches_per_epoch: int = 50, 132 | trainer_kwargs: Optional[Dict[str, Any]] = None, 133 | train_sampler: Optional[InstanceSampler] = None, 134 | validation_sampler: Optional[InstanceSampler] = None, 135 | ) -> None: 136 | default_trainer_kwargs = {"max_epochs": 100} 137 | if trainer_kwargs is not None: 138 | default_trainer_kwargs.update(trainer_kwargs) 139 | super().__init__(trainer_kwargs=default_trainer_kwargs) 140 | 141 | self.prediction_length = prediction_length 142 | self.context_length = context_length or 10 * prediction_length 143 | self.freq = freq 144 | # TODO find way to enforce same defaults to network and estimator 145 | # somehow 146 | self.lr = lr 147 | self.weight_decay = weight_decay 148 | self.distr_output = distr_output 149 | self.loss = loss 150 | self.scaling = scaling 151 | self.lags_seq = lags_seq 152 | self.d_model = d_model 153 | self.nhead = nhead 154 | self.dim_feedforward = dim_feedforward 155 | self.dropout = dropout 156 | self.activation = activation 157 | self.norm_first = norm_first 158 | self.num_encoder_layers = num_encoder_layers 159 | self.batch_size = batch_size 160 | self.num_batches_per_epoch = num_batches_per_epoch 161 | 162 | self.train_sampler = train_sampler or ExpectedNumInstanceSampler( 163 | num_instances=1.0, min_future=prediction_length 164 | ) 165 | 
self.validation_sampler = validation_sampler or ValidationSplitSampler( 166 | min_future=prediction_length 167 | ) 168 | 169 | def create_transformation(self) -> Transformation: 170 | return SelectFields( 171 | [ 172 | FieldName.ITEM_ID, 173 | FieldName.INFO, 174 | FieldName.START, 175 | FieldName.TARGET, 176 | ], 177 | allow_missing=True, 178 | ) + AddObservedValuesIndicator( 179 | target_field=FieldName.TARGET, 180 | output_field=FieldName.OBSERVED_VALUES, 181 | ) 182 | 183 | def create_lightning_module(self) -> pl.LightningModule: 184 | return LagTSTLightningModule( 185 | loss=self.loss, 186 | lr=self.lr, 187 | weight_decay=self.weight_decay, 188 | model_kwargs={ 189 | "prediction_length": self.prediction_length, 190 | "context_length": self.context_length, 191 | "freq": self.freq, 192 | "lags_seq": self.lags_seq, 193 | "d_model": self.d_model, 194 | "nhead": self.nhead, 195 | "dim_feedforward": self.dim_feedforward, 196 | "dropout": self.dropout, 197 | "activation": self.activation, 198 | "norm_first": self.norm_first, 199 | "num_encoder_layers": self.num_encoder_layers, 200 | "distr_output": self.distr_output, 201 | "scaling": self.scaling, 202 | }, 203 | ) 204 | 205 | def _create_instance_splitter(self, module: LagTSTLightningModule, mode: str): 206 | assert mode in ["training", "validation", "test"] 207 | 208 | instance_sampler = { 209 | "training": self.train_sampler, 210 | "validation": self.validation_sampler, 211 | "test": TestSplitSampler(), 212 | }[mode] 213 | 214 | return InstanceSplitter( 215 | target_field=FieldName.TARGET, 216 | is_pad_field=FieldName.IS_PAD, 217 | start_field=FieldName.START, 218 | forecast_start_field=FieldName.FORECAST_START, 219 | instance_sampler=instance_sampler, 220 | past_length=module.model._past_length, 221 | future_length=self.prediction_length, 222 | time_series_fields=[FieldName.OBSERVED_VALUES], 223 | dummy_value=self.distr_output.value_in_support, 224 | ) 225 | 226 | def create_training_data_loader( 227 | self, 228 | 
data: Dataset, 229 | module: LagTSTLightningModule, 230 | shuffle_buffer_length: Optional[int] = None, 231 | **kwargs, 232 | ) -> Iterable: 233 | data = Cyclic(data).stream() 234 | instances = self._create_instance_splitter(module, "training").apply( 235 | data, is_train=True 236 | ) 237 | return as_stacked_batches( 238 | instances, 239 | batch_size=self.batch_size, 240 | shuffle_buffer_length=shuffle_buffer_length, 241 | field_names=TRAINING_INPUT_NAMES, 242 | output_type=torch.tensor, 243 | num_batches_per_epoch=self.num_batches_per_epoch, 244 | ) 245 | 246 | def create_validation_data_loader( 247 | self, data: Dataset, module: LagTSTLightningModule, **kwargs 248 | ) -> Iterable: 249 | instances = self._create_instance_splitter(module, "validation").apply( 250 | data, is_train=True 251 | ) 252 | return as_stacked_batches( 253 | instances, 254 | batch_size=self.batch_size, 255 | field_names=TRAINING_INPUT_NAMES, 256 | output_type=torch.tensor, 257 | ) 258 | 259 | def create_predictor( 260 | self, transformation: Transformation, module 261 | ) -> PyTorchPredictor: 262 | prediction_splitter = self._create_instance_splitter(module, "test") 263 | 264 | return PyTorchPredictor( 265 | input_transform=transformation + prediction_splitter, 266 | input_names=PREDICTION_INPUT_NAMES, 267 | prediction_net=module, 268 | forecast_generator=DistributionForecastGenerator(self.distr_output), 269 | batch_size=self.batch_size, 270 | prediction_length=self.prediction_length, 271 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 272 | ) 273 | -------------------------------------------------------------------------------- /LagTST/lightning_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import pytorch_lightning as pl 15 | import torch 16 | 17 | from gluonts.core.component import validated 18 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood 19 | 20 | from .module import LagTSTModel 21 | 22 | 23 | class LagTSTLightningModule(pl.LightningModule): 24 | """ 25 | A ``pl.LightningModule`` class that can be used to train a 26 | ``LagTSTModel`` with PyTorch Lightning. 27 | 28 | This is a thin layer around a (wrapped) ``LagTSTModel`` object, 29 | that exposes the methods to evaluate training and validation loss. 30 | 31 | Parameters 32 | ---------- 33 | model 34 | ``LagTSTModel`` to be trained. 35 | loss 36 | Loss function to be used for training, 37 | default: ``NegativeLogLikelihood()``. 38 | lr 39 | Learning rate, default: ``1e-3``. 40 | weight_decay 41 | Weight decay regularization parameter, default: ``1e-8``. 
42 | """ 43 | 44 | @validated() 45 | def __init__( 46 | self, 47 | model_kwargs: dict, 48 | loss: DistributionLoss = NegativeLogLikelihood(), 49 | lr: float = 1e-3, 50 | weight_decay: float = 1e-8, 51 | ): 52 | super().__init__() 53 | self.save_hyperparameters() 54 | self.model = LagTSTModel(**model_kwargs) 55 | self.loss = loss 56 | self.lr = lr 57 | self.weight_decay = weight_decay 58 | 59 | def forward(self, *args, **kwargs): 60 | return self.model.forward(*args, **kwargs) 61 | 62 | def _compute_loss(self, batch): 63 | context = batch["past_target"] 64 | past_observed_values = batch["past_observed_values"] 65 | target = batch["future_target"] 66 | observed_target = batch["future_observed_values"] 67 | 68 | assert context.shape[-1] == self.model._past_length 69 | assert target.shape[-1] == self.model.prediction_length 70 | 71 | distr_args, loc, scale = self.model(context, past_observed_values) 72 | distr = self.model.distr_output.distribution(distr_args, loc, scale) 73 | 74 | return (self.loss(distr, target) * observed_target).sum() / torch.maximum( 75 | torch.tensor(1.0), observed_target.sum() 76 | ) 77 | 78 | def training_step(self, batch, batch_idx: int): # type: ignore 79 | """ 80 | Execute training step. 81 | """ 82 | train_loss = self._compute_loss(batch) 83 | self.log( 84 | "train_loss", 85 | train_loss, 86 | on_epoch=True, 87 | on_step=False, 88 | prog_bar=True, 89 | ) 90 | return train_loss 91 | 92 | def validation_step(self, batch, batch_idx: int): # type: ignore 93 | """ 94 | Execute validation step. 95 | """ 96 | val_loss = self._compute_loss(batch) 97 | self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) 98 | return val_loss 99 | 100 | def configure_optimizers(self): 101 | """ 102 | Returns the optimizer to use. 
103 | """ 104 | return torch.optim.Adam( 105 | self.model.parameters(), 106 | lr=self.lr, 107 | weight_decay=self.weight_decay, 108 | ) 109 | -------------------------------------------------------------------------------- /LagTST/module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from typing import Tuple, Optional, List 15 | 16 | import numpy as np 17 | from pandas.tseries.frequencies import to_offset 18 | import torch 19 | from torch import nn 20 | 21 | from gluonts.core.component import validated 22 | from gluonts.model import Input, InputSpec 23 | from gluonts.torch.distributions import StudentTOutput 24 | from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler 25 | from gluonts.torch.util import unsqueeze_expand, lagged_sequence_values 26 | from gluonts.time_feature import norm_freq_str 27 | from gluonts.time_feature.lag import _make_lags 28 | 29 | 30 | def get_lags_for_frequency( 31 | freq_str: str, lag_ub: int = 1200, num_lags: Optional[int] = None 32 | ) -> List[int]: 33 | """ 34 | Generates a list of lags that that are appropriate for the given frequency 35 | string. 36 | 37 | By default all frequencies have the following lags: [1]. 38 | Remaining lags correspond to the same `season` (+/- `delta`) in previous 39 | `k` cycles. 
Here `delta` and `k` are chosen according to the existing code. 40 | 41 | Parameters 42 | ---------- 43 | 44 | freq_str 45 | Frequency string of the form [multiple][granularity] such as "12H", 46 | "5min", "1D" etc. 47 | 48 | lag_ub 49 | The maximum value for a lag. 50 | 51 | num_lags 52 | Maximum number of lags; by default all generated lags are returned 53 | """ 54 | 55 | # Lags are target values at the same `season` (+/- delta) but in the 56 | # previous cycle. 57 | def _make_lags_for_second(multiple, num_cycles=3): 58 | # We use previous ``num_cycles`` hours to generate lags 59 | return [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)] 60 | 61 | def _make_lags_for_minute(multiple, num_cycles=3): 62 | # We use previous ``num_cycles`` hours to generate lags 63 | return [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)] 64 | 65 | def _make_lags_for_hour(multiple, num_cycles=7): 66 | # We use previous ``num_cycles`` days to generate lags 67 | return [_make_lags(k * 24 // multiple, 1) for k in range(1, num_cycles + 1)] 68 | 69 | def _make_lags_for_day(multiple, num_cycles=4, days_in_week=7, days_in_month=30): 70 | # We use previous ``num_cycles`` weeks to generate lags 71 | # We use the last month (in addition to 4 weeks) to generate lag. 
72 | return [ 73 | _make_lags(k * days_in_week // multiple, 1) 74 | for k in range(1, num_cycles + 1) 75 | ] + [_make_lags(days_in_month // multiple, 1)] 76 | 77 | def _make_lags_for_week(multiple, num_cycles=3): 78 | # We use previous ``num_cycles`` years to generate lags 79 | # Additionally, we use previous 4, 8, 12 weeks 80 | return [_make_lags(k * 52 // multiple, 1) for k in range(1, num_cycles + 1)] + [ 81 | [4 // multiple, 8 // multiple, 12 // multiple] 82 | ] 83 | 84 | def _make_lags_for_month(multiple, num_cycles=3): 85 | # We use previous ``num_cycles`` years to generate lags 86 | return [_make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)] 87 | 88 | # multiple, granularity = get_granularity(freq_str) 89 | offset = to_offset(freq_str) 90 | # normalize offset name, so that both `W` and `W-SUN` refer to `W` 91 | offset_name = norm_freq_str(offset.name) 92 | 93 | if offset_name == "A": 94 | lags = [] 95 | elif offset_name == "Q": 96 | assert ( 97 | offset.n == 1 98 | ), "Only multiple 1 is supported for quarterly. Use x month instead." 
99 | lags = _make_lags_for_month(offset.n * 3.0) 100 | elif offset_name == "M": 101 | lags = _make_lags_for_month(offset.n) 102 | elif offset_name == "W": 103 | lags = _make_lags_for_week(offset.n) 104 | elif offset_name == "D": 105 | lags = _make_lags_for_day(offset.n) + _make_lags_for_week(offset.n / 7.0) 106 | elif offset_name == "B": 107 | lags = _make_lags_for_day( 108 | offset.n, days_in_week=5, days_in_month=22 109 | ) + _make_lags_for_week(offset.n / 5.0) 110 | elif offset_name == "H": 111 | lags = ( 112 | _make_lags_for_hour(offset.n) 113 | + _make_lags_for_day(offset.n / 24) 114 | + _make_lags_for_week(offset.n / (24 * 7)) 115 | ) 116 | # minutes 117 | elif offset_name == "T": 118 | lags = ( 119 | _make_lags_for_minute(offset.n) 120 | + _make_lags_for_hour(offset.n / 60) 121 | + _make_lags_for_day(offset.n / (60 * 24)) 122 | + _make_lags_for_week(offset.n / (60 * 24 * 7)) 123 | ) 124 | # second 125 | elif offset_name == "S": 126 | lags = ( 127 | _make_lags_for_second(offset.n) 128 | + _make_lags_for_minute(offset.n / 60) 129 | + _make_lags_for_hour(offset.n / (60 * 60)) 130 | ) 131 | else: 132 | raise Exception("invalid frequency") 133 | 134 | # flatten lags list and filter 135 | lags = [int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub] 136 | lags = [1] + sorted(list(set(lags))) 137 | 138 | return lags[:num_lags] 139 | 140 | 141 | def make_linear_layer(dim_in, dim_out): 142 | lin = nn.Linear(dim_in, dim_out) 143 | torch.nn.init.uniform_(lin.weight, -0.07, 0.07) 144 | torch.nn.init.zeros_(lin.bias) 145 | return lin 146 | 147 | 148 | class SinusoidalPositionalEmbedding(nn.Embedding): 149 | """This module produces sinusoidal positional embeddings of any length.""" 150 | 151 | def __init__(self, num_positions: int, embedding_dim: int) -> None: 152 | super().__init__(num_positions, embedding_dim) 153 | self.weight = self._init_weight(self.weight) 154 | 155 | @staticmethod 156 | def _init_weight(out: nn.Parameter) -> nn.Parameter: 157 | 
""" 158 | Features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] 159 | """ 160 | n_pos, dim = out.shape 161 | position_enc = np.array( 162 | [ 163 | [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] 164 | for pos in range(n_pos) 165 | ] 166 | ) 167 | # set early to avoid an error in pytorch-1.8+ 168 | out.requires_grad = False 169 | 170 | sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 171 | out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) 172 | out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) 173 | out.detach_() 174 | return out 175 | 176 | @torch.no_grad() 177 | def forward( 178 | self, input_ids_shape: torch.Size, past_key_values_length: int = 0 179 | ) -> torch.Tensor: 180 | """`input_ids_shape` is expected to be [bsz x seqlen x ...].""" 181 | _, seq_len = input_ids_shape[:2] 182 | positions = torch.arange( 183 | past_key_values_length, 184 | past_key_values_length + seq_len, 185 | dtype=torch.long, 186 | device=self.weight.device, 187 | ) 188 | return super().forward(positions) 189 | 190 | 191 | class LagTSTModel(nn.Module): 192 | """ 193 | Module implementing the LagTST model for forecasting. 194 | 195 | Parameters 196 | ---------- 197 | prediction_length 198 | Number of time points to predict. 199 | context_length 200 | Number of time steps prior to prediction time that the model. 201 | distr_output 202 | Distribution to use to evaluate observations and sample predictions. 203 | Default: ``StudentTOutput()``. 
204 | """ 205 | 206 | @validated() 207 | def __init__( 208 | self, 209 | prediction_length: int, 210 | context_length: int, 211 | freq: str, 212 | d_model: int, 213 | nhead: int, 214 | dim_feedforward: int, 215 | dropout: float, 216 | activation: str, 217 | norm_first: bool, 218 | num_encoder_layers: int, 219 | scaling: str, 220 | lags_seq: Optional[List[int]] = None, 221 | distr_output=StudentTOutput(), 222 | ) -> None: 223 | super().__init__() 224 | 225 | assert prediction_length > 0 226 | assert context_length > 0 227 | 228 | self.prediction_length = prediction_length 229 | self.context_length = context_length 230 | self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq) 231 | self.d_model = d_model 232 | self.distr_output = distr_output 233 | 234 | if scaling == "mean": 235 | self.scaler = MeanScaler(keepdim=True) 236 | elif scaling == "std": 237 | self.scaler = StdScaler(keepdim=True) 238 | else: 239 | self.scaler = NOPScaler(keepdim=True) 240 | 241 | # project from number of lags + 2 features (loc and scale) to d_model 242 | self.patch_proj = make_linear_layer(len(self.lags_seq) + 2, d_model) 243 | 244 | self.positional_encoding = SinusoidalPositionalEmbedding( 245 | self.context_length, d_model 246 | ) 247 | 248 | layer_norm_eps: float = 1e-5 249 | encoder_layer = nn.TransformerEncoderLayer( 250 | d_model=d_model, 251 | nhead=nhead, 252 | dim_feedforward=dim_feedforward, 253 | dropout=dropout, 254 | activation=activation, 255 | layer_norm_eps=layer_norm_eps, 256 | batch_first=True, 257 | norm_first=norm_first, 258 | ) 259 | encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) 260 | self.encoder = nn.TransformerEncoder( 261 | encoder_layer, num_encoder_layers, encoder_norm 262 | ) 263 | 264 | self.flatten = nn.Linear( 265 | d_model * self.context_length, prediction_length * d_model 266 | ) 267 | 268 | self.args_proj = self.distr_output.get_args_proj(d_model) 269 | 270 | @property 271 | def _past_length(self) -> int: 272 | return 
self.context_length + max(self.lags_seq) 273 | 274 | def describe_inputs(self, batch_size=1) -> InputSpec: 275 | return InputSpec( 276 | { 277 | "past_target": Input( 278 | shape=(batch_size, self._past_length), dtype=torch.float 279 | ), 280 | "past_observed_values": Input( 281 | shape=(batch_size, self._past_length), dtype=torch.float 282 | ), 283 | }, 284 | torch.zeros, 285 | ) 286 | 287 | def forward( 288 | self, 289 | past_target: torch.Tensor, 290 | past_observed_values: torch.Tensor, 291 | ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: 292 | # scale the input 293 | past_target_scaled, loc, scale = self.scaler(past_target, past_observed_values) 294 | 295 | lags = lagged_sequence_values( 296 | self.lags_seq, 297 | past_target_scaled[:, : -self.context_length, ...], 298 | past_target_scaled[:, -self.context_length :, ...], 299 | dim=-1, 300 | ) 301 | 302 | # add loc and scale to past_target_patches as additional features 303 | log_abs_loc = loc.abs().log1p() 304 | log_scale = scale.log() 305 | expanded_static_feat = unsqueeze_expand( 306 | torch.cat([log_abs_loc, log_scale], dim=-1), 307 | dim=1, 308 | size=lags.shape[1], 309 | ) 310 | inputs = torch.cat((lags, expanded_static_feat), dim=-1) 311 | 312 | # project patches 313 | enc_in = self.patch_proj(inputs) 314 | embed_pos = self.positional_encoding(enc_in.size()) 315 | 316 | # transformer encoder with positional encoding 317 | enc_out = self.encoder(enc_in + embed_pos) 318 | 319 | # flatten and project to prediction length * d_model 320 | flatten_out = self.flatten(enc_out.flatten(start_dim=1)) 321 | 322 | # project to distribution arguments 323 | distr_args = self.args_proj( 324 | flatten_out.reshape(-1, self.prediction_length, self.d_model) 325 | ) 326 | return distr_args, loc, scale 327 | -------------------------------------------------------------------------------- /Linear-hyperparameter_tuning.ipynb: -------------------------------------------------------------------------------- 1 
| { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "86ff733d-8ea7-4d77-8d2a-cd329ab8f385", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "import torch\n", 17 | "\n", 18 | "from gluonts.dataset.repository.datasets import get_dataset\n", 19 | "from gluonts.dataset.common import ListDataset\n", 20 | "from gluonts.evaluation import make_evaluation_predictions, Evaluator\n", 21 | "from gluonts.torch.distributions import NegativeBinomialOutput\n", 22 | "from gluonts.torch.modules.loss import NegativeLogLikelihood\n", 23 | "\n", 24 | "from Linear import LinearEstimator\n", 25 | "\n", 26 | "# from pts.modules import StudentTOutput\n", 27 | "\n", 28 | "import random\n", 29 | "import numpy as np\n", 30 | "import time\n", 31 | "import optuna" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "944676b0-2a9f-4301-92e6-f382f5693639", 38 | "metadata": { 39 | "tags": [] 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "class LinearObjective:\n", 44 | " def __init__(self, dataset, metric_type=\"mean_wQuantileLoss\"):\n", 45 | " self.dataset = dataset\n", 46 | " self.metric_type = metric_type\n", 47 | "\n", 48 | " def get_params(self, trial) -> dict:\n", 49 | " return {\n", 50 | " \"context_length\": trial.suggest_int(\n", 51 | " \"context_length\",\n", 52 | " dataset.metadata.prediction_length,\n", 53 | " dataset.metadata.prediction_length * 10,\n", 54 | " 4,\n", 55 | " ),\n", 56 | " \"batch_size\": trial.suggest_int(\"batch_size\", 128, 256, 64),\n", 57 | " }\n", 58 | "\n", 59 | " def __call__(self, trial):\n", 60 | " params = self.get_params(trial)\n", 61 | "\n", 62 | " estimator = LinearEstimator(\n", 63 | " prediction_length=self.dataset.metadata.prediction_length,\n", 64 | " context_length=dataset.metadata.prediction_length,\n", 65 | " 
freq=dataset.metadata.freq,\n", 66 | " scaling=\"std\",\n", 67 | " # distr_output=NegativeBinomialOutput(),\n", 68 | " # loss=NegativeLogLikelihood(beta=0.2),\n", 69 | " batch_size=params[\"batch_size\"],\n", 70 | " num_batches_per_epoch=100,\n", 71 | " trainer_kwargs=dict(accelerator=\"gpu\", max_epochs=30),\n", 72 | " )\n", 73 | " predictor = estimator.train(\n", 74 | " training_data=self.dataset.train,\n", 75 | " cache_data=True,\n", 76 | " shuffle_buffer_length=1024,\n", 77 | " validation_data=self.dataset.test,\n", 78 | " )\n", 79 | "\n", 80 | " forecast_it, ts_it = make_evaluation_predictions(\n", 81 | " dataset=dataset.test,\n", 82 | " predictor=predictor,\n", 83 | " )\n", 84 | " forecasts = list(forecast_it)\n", 85 | " # if layer == layers[0]:\n", 86 | " tss = list(ts_it)\n", 87 | " evaluator = Evaluator()\n", 88 | " agg_metrics, _ = evaluator(iter(tss), iter(forecasts))\n", 89 | " return agg_metrics[self.metric_type]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "d6c1da22-755e-44d5-86b6-10465d8c25e8", 96 | "metadata": { 97 | "tags": [] 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "dataset = get_dataset(\n", 102 | " \"solar-energy\", regenerate=False\n", 103 | ") # dataset = get_dataset(\"electricity\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "9ded5c21-ea78-4f99-98e3-1573f5abfbf2", 110 | "metadata": { 111 | "tags": [] 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "seed = 42\n", 116 | "random.seed(seed)\n", 117 | "torch.manual_seed(seed)\n", 118 | "start_time = time.time()\n", 119 | "study = optuna.create_study(direction=\"minimize\")\n", 120 | "study.optimize(LinearObjective(dataset), n_trials=10)\n", 121 | "\n", 122 | "print(\"Number of finished trials: {}\".format(len(study.trials)))\n", 123 | "\n", 124 | "print(\"Best trial:\")\n", 125 | "trial = study.best_trial\n", 126 | "\n", 127 | "print(\" Value: {}\".format(trial.value))\n", 128 | 
"print(\" Params: \")\n", 129 | "for key, value in trial.params.items():\n", 130 | " print(\" {}: {}\".format(key, value))\n", 131 | "print(time.time() - start_time)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "9819ac62-f056-4dcb-814f-075f8ce15bac", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "py38", 146 | "language": "python", 147 | "name": "py38" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.8.10" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 5 164 | } 165 | -------------------------------------------------------------------------------- /Linear/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | 14 | from .module import LinearModel 15 | from .lightning_module import LinearLightningModule 16 | from .estimator import LinearEstimator 17 | 18 | __all__ = ["LinearModel", "LinearLightningModule", "LinearEstimator"] 19 | -------------------------------------------------------------------------------- /Linear/estimator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
from typing import List, Optional, Iterable, Dict, Any

import torch
import pytorch_lightning as pl

from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.itertools import Cyclic
from gluonts.model.forecast_generator import DistributionForecastGenerator
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
    Transformation,
    AddObservedValuesIndicator,
    InstanceSampler,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    RemoveFields,
    SetField,
    AddTimeFeatures,
    AddAgeFeature,
    VstackFeatures,
)
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
    DistributionOutput,
    StudentTOutput,
)

from .lightning_module import LinearLightningModule

# Batch fields handed to the network at prediction time.
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Training/validation batches additionally carry the future window so that
# the loss can be computed against it.
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
    "future_target",
    "future_observed_values",
]


class LinearEstimator(PyTorchLightningEstimator):
    """
    An estimator training a Linear model for forecasting.

    This class uses the model defined in ``LinearModel``, and wraps it
    into a ``LinearLightningModule`` for training purposes: training is
    performed using PyTorch Lightning's ``pl.Trainer`` class.

    Parameters
    ----------
    freq
        Frequency string of the data; used to derive the default
        ``time_features``.
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs (default: ``10 * prediction_length``).
    hidden_dimensions
        Size of hidden layers in the feed-forward network
        (default: ``[20, 20]``).
    input_size
        Forwarded to the model as ``input_size`` (default: ``1``).
    scaling
        Scaling of the context window, forwarded to the model:
        ``"mean"``, ``"std"``, or anything else for no scaling
        (default: ``"mean"``).
    num_feat_dynamic_real
        Number of dynamic real features; when ``0`` the
        ``feat_dynamic_real`` field is removed (default: ``0``).
    num_feat_static_cat
        Number of static categorical features; when ``0`` a dummy
        ``feat_static_cat = [0]`` field is inserted (default: ``0``).
    num_feat_static_real
        Number of static real features; when ``0`` a dummy
        ``feat_static_real = [0.0]`` field is inserted (default: ``0``).
    cardinality
        Number of values of each static categorical feature; replaced by
        ``[1]`` unless ``num_feat_static_cat > 0``. Stored on the
        estimator, not forwarded to the model.
    embedding_dimension
        Stored on the estimator, not forwarded to the model.
    time_features
        Time features added to each series
        (default: derived from ``freq``).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: ``StudentTOutput()``).
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    batch_norm
        Whether to apply batch normalization (default: ``False``).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction;
        they override ``{"max_epochs": 100, "gradient_clip_val": 10.0}``.
    train_sampler
        Controls the sampling of windows during training.
    validation_sampler
        Controls the sampling of windows during validation.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        hidden_dimensions: Optional[List[int]] = None,
        input_size: int = 1,
        scaling: Optional[str] = "mean",
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        batch_norm: bool = False,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        default_trainer_kwargs = {"max_epochs": 100, "gradient_clip_val": 10.0}
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.scaling = scaling
        self.freq = freq
        self.input_size = input_size
        self.prediction_length = prediction_length
        self.context_length = context_length or 10 * prediction_length
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )
        # TODO find way to enforce same defaults to network and estimator
        # somehow
        self.hidden_dimensions = hidden_dimensions or [20, 20]
        self.lr = lr
        self.weight_decay = weight_decay
        self.distr_output = distr_output
        self.loss = loss
        self.batch_norm = batch_norm
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    @classmethod
    def derive_auto_fields(cls, train_iter):
        """Infer dynamic-real and static-categorical feature settings from data."""
        stats = calculate_dataset_statistics(train_iter)

        return {
            "num_feat_dynamic_real": stats.num_feat_dynamic_real,
            "num_feat_static_cat": len(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def create_transformation(self) -> Transformation:
        """
        Build the per-series preprocessing chain.

        Removes unused real-valued feature fields, inserts dummy static
        features when none are configured, adds time and (log-scaled) age
        features, stacks them (plus dynamic reals, if any) into
        ``FieldName.FEAT_TIME``, and finally adds the observed-values
        indicator.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        transformation = RemoveFields(field_names=remove_field_names)
        # Extend the chain incrementally so that only real Transformation
        # objects are ever chained; the previous ``x + (... else [])`` form
        # chained a bare list whenever static features were configured.
        if not self.num_feat_static_cat > 0:
            transformation += SetField(
                output_field=FieldName.FEAT_STATIC_CAT, value=[0]
            )
        if not self.num_feat_static_real > 0:
            transformation += SetField(
                output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
            )

        return (
            transformation
            + AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=self.time_features,
                pred_length=self.prediction_length,
            )
            + AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=self.prediction_length,
                log_scale=True,
            )
            + VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                + (
                    [FieldName.FEAT_DYNAMIC_REAL]
                    if self.num_feat_dynamic_real > 0
                    else []
                ),
            )
            + AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            )
        )

    def create_lightning_module(self) -> pl.LightningModule:
        """Construct the ``LinearLightningModule`` from the stored settings."""
        return LinearLightningModule(
            loss=self.loss,
            lr=self.lr,
            weight_decay=self.weight_decay,
            model_kwargs={
                "input_size": self.input_size,
                "prediction_length": self.prediction_length,
                "context_length": self.context_length,
                "hidden_dimensions": self.hidden_dimensions,
                "scaling": self.scaling,
                "distr_output": self.distr_output,
                "batch_norm": self.batch_norm,
            },
        )

    def _create_instance_splitter(self, module: LinearLightningModule, mode: str):
        """
        Return an ``InstanceSplitter`` cutting ``context_length`` past /
        ``prediction_length`` future windows, using the sampler for ``mode``
        (``"training"``, ``"validation"`` or ``"test"``).
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: LinearLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """Endless (cyclic) stream of training batches, ``num_batches_per_epoch`` per epoch."""
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: LinearLightningModule,
        **kwargs,
    ) -> Iterable:
        """Single pass of validation batches over ``data``."""
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
        """Wrap ``module`` in a predictor using a distribution forecast generator."""
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            forecast_generator=DistributionForecastGenerator(self.distr_output),
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )


# --------------------------------------------------------------------------------
# /Linear/lightning_module.py:
# --------------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file.
# This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytorch_lightning as pl
import torch

from gluonts.core.component import validated
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood

from .module import LinearModel


class LinearLightningModule(pl.LightningModule):
    """
    A ``pl.LightningModule`` class that can be used to train a
    ``LinearModel`` with PyTorch Lightning.

    This is a thin layer around a (wrapped) ``LinearModel`` object,
    that exposes the methods to evaluate training and validation loss.

    Parameters
    ----------
    model_kwargs
        Keyword arguments used to construct the wrapped ``LinearModel``.
    loss
        Loss function to be used for training,
        default: ``NegativeLogLikelihood()``.
    lr
        Learning rate, default: ``1e-3``.
    weight_decay
        Weight decay regularization parameter, default: ``1e-8``.
    """

    @validated()
    def __init__(
        self,
        model_kwargs: dict,
        loss: DistributionLoss = NegativeLogLikelihood(),
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = LinearModel(**model_kwargs)
        self.loss = loss
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, *args, **kwargs):
        """Delegate to ``LinearModel.forward``; returns ``(distr_args, loc, scale)``."""
        return self.model.forward(*args, **kwargs)

    def _compute_loss(self, batch):
        """
        Return the loss averaged over observed future values only.

        Entries whose ``future_observed_values`` is zero are masked out; the
        denominator is clamped to at least 1 so that a batch with no
        observed values does not divide by zero.
        """
        past_target = batch["past_target"]
        past_observed_values = batch["past_observed_values"]
        target = batch["future_target"]
        observed_target = batch["future_observed_values"]

        assert past_target.shape[-1] == self.model.context_length
        assert target.shape[-1] == self.model.prediction_length

        distr_args, loc, scale = self.model(
            past_target=past_target, past_observed_values=past_observed_values
        )
        distr = self.model.distr_output.distribution(distr_args, loc, scale)

        return (self.loss(distr, target) * observed_target).sum() / torch.maximum(
            torch.tensor(1.0), observed_target.sum()
        )

    def training_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute training step.
        """
        train_loss = self._compute_loss(batch)
        self.log(
            "train_loss",
            train_loss,
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        return train_loss

    def validation_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute validation step.
        """
        val_loss = self._compute_loss(batch)
        self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """
        Returns the optimizer to use.
        """
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )


# --------------------------------------------------------------------------------
# /Linear/module.py:
# --------------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Tuple, Optional

import torch
from torch import nn

from gluonts.core.component import validated
from gluonts.model import Input, InputSpec
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.distributions import StudentTOutput


def make_linear_layer(dim_in, dim_out):
    """Return an ``nn.Linear`` with weights drawn from U(-0.07, 0.07) and zero bias."""
    lin = nn.Linear(dim_in, dim_out)
    torch.nn.init.uniform_(lin.weight, -0.07, 0.07)
    torch.nn.init.zeros_(lin.bias)
    return lin


class LinearModel(nn.Module):
    """
    Module implementing Linear for forecasting.

    Parameters
    ----------
    prediction_length
        Number of time points to predict.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs.
    scaling
        ``"mean"`` uses ``MeanScaler``, ``"std"`` uses ``StdScaler``; any
        other value disables scaling (``NOPScaler``).
    input_size
        Accepted for interface compatibility; not used by this module.
    hidden_dimensions
        Size of hidden layers in the feed-forward network. The last entry
        is the per-step feature width fed to the distribution projection.
        Default: ``[20, 20]``.
    distr_output
        Distribution to use to evaluate observations and sample predictions.
        Default: ``StudentTOutput()``.
    batch_norm
        Whether to apply batch normalization. Default: ``False``.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        context_length: int,
        scaling: str,
        input_size: int,
        hidden_dimensions: Optional[List[int]] = None,
        distr_output=StudentTOutput(),
        batch_norm: bool = False,
    ) -> None:
        super().__init__()

        assert prediction_length > 0
        assert context_length > 0
        assert hidden_dimensions is None or len(hidden_dimensions) > 0

        self.prediction_length = prediction_length
        self.context_length = context_length
        self.hidden_dimensions = (
            hidden_dimensions if hidden_dimensions is not None else [20, 20]
        )
        if scaling == "mean":
            self.scaler = MeanScaler(keepdim=True)
        elif scaling == "std":
            self.scaler = StdScaler(keepdim=True)
        else:
            self.scaler = NOPScaler(keepdim=True)

        self.distr_output = distr_output
        self.batch_norm = batch_norm

        # Hidden stack: context_length -> hidden_dimensions[:-1] (ReLU, optional
        # BatchNorm), then a final layer emitting hidden_dimensions[-1]
        # features for each of the prediction_length steps.
        dimensions = [context_length] + self.hidden_dimensions[:-1]

        modules = []
        for in_size, out_size in zip(dimensions[:-1], dimensions[1:]):
            modules += [make_linear_layer(in_size, out_size), nn.ReLU()]
            if batch_norm:
                modules.append(nn.BatchNorm1d(out_size))
        modules.append(
            make_linear_layer(
                dimensions[-1], prediction_length * self.hidden_dimensions[-1]
            )
        )

        self.nn = nn.Sequential(*modules)
        self.args_proj = self.distr_output.get_args_proj(self.hidden_dimensions[-1])

    def describe_inputs(self, batch_size=1) -> InputSpec:
        """Input spec: ``past_target`` and ``past_observed_values`` of shape (batch, context_length)."""
        return InputSpec(
            {
                "past_target": Input(
                    shape=(batch_size, self.context_length), dtype=torch.float
                ),
                "past_observed_values": Input(
                    shape=(batch_size, self.context_length), dtype=torch.float
                ),
            },
            torch.zeros,
        )

    def forward(
        self,
        feat_static_cat: Optional[torch.Tensor] = None,
        feat_static_real: Optional[torch.Tensor] = None,
        past_time_feat: Optional[torch.Tensor] = None,
        past_target: Optional[torch.Tensor] = None,
        past_observed_values: Optional[torch.Tensor] = None,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
        future_observed_values: Optional[torch.Tensor] = None,
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
        """
        Return ``(distr_args, loc, scale)`` for the forecast window.

        Only ``past_target`` and ``past_observed_values`` are used; the other
        arguments are accepted so that a full batch can be passed through.
        """
        past_target_scaled, loc, scale = self.scaler(past_target, past_observed_values)
        nn_out = self.nn(past_target_scaled)
        # One hidden_dimensions[-1]-wide feature vector per prediction step
        # (leading dim presumably the batch).
        nn_out_reshaped = nn_out.reshape(
            -1, self.prediction_length, self.hidden_dimensions[-1]
        )
        distr_args = self.args_proj(nn_out_reshaped)
        return distr_args, loc, scale


# --------------------------------------------------------------------------------
# /MlpTSMixer-hyperparameter_tuning.ipynb:
# --------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ff733d-8ea7-4d77-8d2a-cd329ab8f385",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "\n",
    "from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n",
    "from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n",
    "from gluonts.evaluation.backtest import make_evaluation_predictions\n",
    "from gluonts.evaluation import MultivariateEvaluator\n",
    "\n",
    "from MlpTSMixer import MlpTSMixerEstimator\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import time\n",
    "import optuna\n",
    "from optuna.samplers import TPESampler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "944676b0-2a9f-4301-92e6-f382f5693639",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class MlpTSMixerObjective:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dataset,\n",
    "        train_grouper,\n",
    "        test_grouper,\n",
    "        metric_type=\"m_sum_mean_wQuantileLoss\",\n",
    "    ):\n",
    "        self.metric_type = metric_type\n",
    "        self.dataset = dataset\n",
    "        self.dataset_train = train_grouper(self.dataset.train)\n",
    "        self.dataset_test = test_grouper(self.dataset.test)\n",
    "\n",
    "    def get_params(self, trial) -> dict:\n",
    "        # Use the dataset this objective was built with, not a notebook global.\n",
    "        return {\n",
    "            \"context_length\": trial.suggest_int(\n",
    "                \"context_length\",\n",
    "                self.dataset.metadata.prediction_length,\n",
    "                self.dataset.metadata.prediction_length * 5,\n",
    "                1,\n",
    "            ),\n",
    "            \"batch_size\": trial.suggest_int(\"batch_size\", 32, 256, 32),\n",
    "            \"depth\": trial.suggest_int(\"depth\", 2, 16, 4),\n",
    "            \"dim\": trial.suggest_int(\"dim\", 16, 64, 16),\n",
    "            \"patch_size\": trial.suggest_int(\"patch_size\", 2, 16, 4),\n",
    "            \"expansion_factor\": trial.suggest_int(\"expansion_factor\", 2, 8, 2),\n",
    "            \"kernel_size\": trial.suggest_int(\"kernel_size\", 9, 18, 3),\n",
    "        }\n",
    "\n",
    "    def __call__(self, trial):\n",
    "        params = self.get_params(trial)\n",
    "        # NOTE: \"expansion_factor\" and \"kernel_size\" are sampled above but\n",
    "        # not passed to the estimator below.\n",
    "        estimator = MlpTSMixerEstimator(\n",
    "            # distr_output=StudentTOutput(dim=int(dataset.metadata.feat_static_cat[0].cardinality)),\n",
    "            input_size=int(self.dataset.metadata.feat_static_cat[0].cardinality),\n",
    "            prediction_length=self.dataset.metadata.prediction_length,\n",
    "            context_length=params[\"context_length\"],\n",
    "            freq=self.dataset.metadata.freq,\n",
    "            scaling=\"std\",\n",
    "            depth=params[\"depth\"],\n",
    "            patch_size=(params[\"patch_size\"], params[\"patch_size\"]),\n",
    "            dim=params[\"dim\"],\n",
    "            batch_size=params[\"batch_size\"],\n",
    "            num_batches_per_epoch=100,\n",
    "            patch_reverse_mapping_layer=\"mlp\",\n",
    "            trainer_kwargs=dict(accelerator=\"cuda\", max_epochs=30),\n",
    "        )\n",
    "        predictor = estimator.train(\n",
    "            training_data=self.dataset_train, num_workers=8, shuffle_buffer_length=1024\n",
    "        )\n",
    "\n",
    "        forecast_it, ts_it = make_evaluation_predictions(\n",
    "            dataset=self.dataset_test, predictor=predictor, num_samples=100\n",
    "        )\n",
    "        forecasts = list(forecast_it)\n",
    "        tss = list(ts_it)\n",
    "        evaluator = MultivariateEvaluator(\n",
    "            quantiles=(np.arange(20) / 20.0)[1:], target_agg_funcs={\"sum\": np.sum}\n",
    "        )\n",
    "        agg_metrics, _ = evaluator(iter(tss), iter(forecasts))\n",
    "        return agg_metrics[self.metric_type]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6c1da22-755e-44d5-86b6-10465d8c25e8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dataset = get_dataset(\"solar_nips\", regenerate=False)\n",
    "train_grouper = MultivariateGrouper(\n",
    "    max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)\n",
    ")\n",
    "\n",
    "test_grouper = MultivariateGrouper(\n",
    "    num_test_dates=int(len(dataset.test) / len(dataset.train)),\n",
    "    max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n",
    ")\n",
    "dataset_train = train_grouper(dataset.train)\n",
    "dataset_test = test_grouper(dataset.test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ded5c21-ea78-4f99-98e3-1573f5abfbf2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed = 42\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "start_time = time.time()\n",
    "sampler = TPESampler(seed=seed)\n",
    "study = optuna.create_study(sampler=sampler, direction=\"minimize\")\n",
    "study.optimize(MlpTSMixerObjective(dataset, train_grouper, test_grouper), n_trials=10)\n",
    "\n",
    "print(\"Number of finished trials: {}\".format(len(study.trials)))\n",
    "\n",
    "print(\"Best trial:\")\n",
    "trial = study.best_trial\n",
    "\n",
    "print(\" Value: {}\".format(trial.value))\n",
    "print(\" Params: \")\n",
    "for key, value in trial.params.items():\n",
    "    print(\" {}: {}\".format(key, value))\n",
    "print(time.time() - start_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9819ac62-f056-4dcb-814f-075f8ce15bac",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

# --------------------------------------------------------------------------------
# /MlpTSMixer/__init__.py:
# --------------------------------------------------------------------------------
from .module import MlpTSMixerModel
from .lightning_module import MlpTSMixerLightningModule
from .estimator import MlpTSMixerEstimator

__all__ = ["MlpTSMixerModel", "MlpTSMixerLightningModule", "MlpTSMixerEstimator"]

# --------------------------------------------------------------------------------
# /MlpTSMixer/lightning_module.py:
# --------------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import lightning.pytorch as pl
import torch

from gluonts.core.component import validated

from .module import MlpTSMixerModel


class MlpTSMixerLightningModule(pl.LightningModule):
    """
    A ``pl.LightningModule`` class that can be used to train a
    ``MlpTSMixerModel`` with PyTorch Lightning.

    This is a thin layer around a (wrapped) ``MlpTSMixerModel`` object,
    that exposes the methods to evaluate training and validation loss.
    The training objective is ``model.distr_output.loss``, averaged over
    the observed future values.

    Parameters
    ----------
    model_kwargs
        Keyword arguments used to construct the wrapped ``MlpTSMixerModel``.
    lr
        Learning rate, default: ``1e-3``.
    weight_decay
        Weight decay regularization parameter, default: ``1e-8``.
    """

    @validated()
    def __init__(
        self,
        model_kwargs: dict,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.model = MlpTSMixerModel(**model_kwargs)
        self.lr = lr
        self.weight_decay = weight_decay

    def forward(self, *args, **kwargs):
        """
        Draw ``num_parallel_samples`` sample paths from the predicted distribution.

        Returns a tensor of shape
        ``(batch, num_parallel_samples, prediction_length, input_size)``.
        """
        distr_args, loc, scale = self.model.forward(*args, **kwargs)
        distr = self.model.distr_output.distribution(distr_args, loc, scale)
        num_samples = self.model.num_parallel_samples
        # ``sample`` puts the sample axis first. Reshaping that tensor directly
        # to (-1, num_samples, ...) mixes samples of different series whenever
        # batch > 1, so keep the sample axis leading while reshaping and swap
        # it with the batch axis afterwards. Assumes each sample is laid out
        # batch-major, then prediction step, then input dimension (the same
        # trailing layout the reshape already relied on).
        samples = distr.sample((num_samples,))
        return samples.reshape(
            num_samples,
            -1,
            self.model.prediction_length,
            self.model.input_size,
        ).transpose(0, 1)

    def _compute_loss(self, batch):
        """
        Return the distribution loss averaged over observed future values.

        Entries whose ``future_observed_values`` is zero are masked out; the
        denominator is clamped to at least 1 so that a batch with no
        observed values does not divide by zero.
        """
        past_target = batch["past_target"]
        past_observed_values = batch["past_observed_values"]
        target = batch["future_target"]
        observed_target = batch["future_observed_values"]

        assert past_target.shape[1] == self.model.context_length
        assert target.shape[1] == self.model.prediction_length

        distr_args, loc, scale = self.model(
            past_target=past_target,
            past_observed_values=past_observed_values,
            past_time_feat=batch["past_time_feat"],
            future_time_feat=batch["future_time_feat"],
        )
        loss_values = self.model.distr_output.loss(
            target=target, distr_args=distr_args, loc=loc, scale=scale
        )
        return (loss_values * observed_target).sum() / torch.maximum(
            torch.tensor(1.0), observed_target.sum()
        )

    def training_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute training step.
        """
        train_loss = self._compute_loss(batch)
        self.log(
            "train_loss",
            train_loss,
            on_epoch=True,
            on_step=False,
            prog_bar=True,
        )
        return train_loss

    def validation_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute validation step.
        """
        val_loss = self._compute_loss(batch)
        self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """
        Returns the optimizer to use.
        """
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )


# --------------------------------------------------------------------------------
# /PatchTST/__init__.py:
# --------------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .module import PatchTSTModel
from .lightning_module import PatchTSTLightningModule
from .estimator import PatchTSTEstimator

__all__ = [
    "PatchTSTModel",
    "PatchTSTLightningModule",
    "PatchTSTEstimator",
]

# --------------------------------------------------------------------------------
# /PatchTST/estimator.py:
# --------------------------------------------------------------------------------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from typing import Optional, Iterable, Dict, Any 15 | 16 | import torch 17 | import pytorch_lightning as pl 18 | 19 | from gluonts.core.component import validated 20 | from gluonts.dataset.common import Dataset 21 | from gluonts.dataset.field_names import FieldName 22 | from gluonts.dataset.loader import as_stacked_batches 23 | from gluonts.itertools import Cyclic 24 | from gluonts.model.forecast_generator import DistributionForecastGenerator 25 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood 26 | from gluonts.transform import ( 27 | Transformation, 28 | AddObservedValuesIndicator, 29 | InstanceSampler, 30 | InstanceSplitter, 31 | ValidationSplitSampler, 32 | TestSplitSampler, 33 | ExpectedNumInstanceSampler, 34 | SelectFields, 35 | ) 36 | from gluonts.torch.model.estimator import PyTorchLightningEstimator 37 | from gluonts.torch.model.predictor import PyTorchPredictor 38 | from gluonts.torch.distributions import DistributionOutput, StudentTOutput 39 | 40 | from .lightning_module import PatchTSTLightningModule 41 | 42 | PREDICTION_INPUT_NAMES = ["past_target", "past_observed_values"] 43 | 44 | TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ 45 | "future_target", 46 | "future_observed_values", 47 | ] 48 | 49 | 50 | class PatchTSTEstimator(PyTorchLightningEstimator): 51 | """ 52 | An estimator training the PatchTST model for forecasting. 
53 | 54 | This class is uses the model defined in ``SimpleFeedForwardModel``, 55 | and wraps it into a ``PatchTSTLightningModule`` for training 56 | purposes: training is performed using PyTorch Lightning's ``pl.Trainer`` 57 | class. 58 | 59 | Parameters 60 | ---------- 61 | prediction_length 62 | Length of the prediction horizon. 63 | context_length 64 | Number of time steps prior to prediction time that the model 65 | takes as inputs (default: ``10 * prediction_length``). 66 | patch_len 67 | Length of the patch. 68 | stride 69 | Stride of the patch. 70 | padding_patch 71 | Padding of the patch. 72 | d_model 73 | Size of hidden layers in the Transformer encoder. 74 | nhead 75 | Number of attention heads in the Transformer encoder. 76 | dim_feedforward 77 | Size of hidden layers in the Transformer encoder. 78 | dropout 79 | Dropout probability in the Transformer encoder. 80 | activation 81 | Activation function in the Transformer encoder. 82 | norm_first 83 | Whether to apply normalization before or after the attention. 84 | num_encoder_layers 85 | Number of layers in the Transformer encoder. 86 | lr 87 | Learning rate (default: ``1e-3``). 88 | weight_decay 89 | Weight decay regularization parameter (default: ``1e-8``). 90 | scaling 91 | Scaling parameter can be "mean", "std" or None. 92 | distr_output 93 | Distribution to use to evaluate observations and sample predictions 94 | (default: StudentTOutput()). 95 | loss 96 | Loss to be optimized during training 97 | (default: ``NegativeLogLikelihood()``). 98 | batch_size 99 | The size of the batches to be used for training (default: 32). 100 | num_batches_per_epoch 101 | Number of batches to be processed in each training epoch 102 | (default: 50). 103 | trainer_kwargs 104 | Additional arguments to provide to ``pl.Trainer`` for construction. 105 | train_sampler 106 | Controls the sampling of windows during training. 107 | validation_sampler 108 | Controls the sampling of windows during validation. 
109 | 110 | """ 111 | 112 | @validated() 113 | def __init__( 114 | self, 115 | prediction_length: int, 116 | patch_len: int, 117 | context_length: Optional[int] = None, 118 | stride: int = 8, 119 | padding_patch: str = "end", 120 | d_model: int = 32, 121 | nhead: int = 4, 122 | dim_feedforward: int = 128, 123 | dropout: float = 0.1, 124 | activation: str = "relu", 125 | norm_first: bool = False, 126 | num_encoder_layers: int = 2, 127 | lr: float = 1e-3, 128 | weight_decay: float = 1e-8, 129 | scaling: Optional[str] = "mean", 130 | distr_output: DistributionOutput = StudentTOutput(), 131 | loss: DistributionLoss = NegativeLogLikelihood(), 132 | batch_size: int = 32, 133 | num_batches_per_epoch: int = 50, 134 | trainer_kwargs: Optional[Dict[str, Any]] = None, 135 | train_sampler: Optional[InstanceSampler] = None, 136 | validation_sampler: Optional[InstanceSampler] = None, 137 | ) -> None: 138 | default_trainer_kwargs = {"max_epochs": 100} 139 | if trainer_kwargs is not None: 140 | default_trainer_kwargs.update(trainer_kwargs) 141 | super().__init__(trainer_kwargs=default_trainer_kwargs) 142 | 143 | self.prediction_length = prediction_length 144 | self.context_length = context_length or 10 * prediction_length 145 | # TODO find way to enforce same defaults to network and estimator 146 | # somehow 147 | self.lr = lr 148 | self.weight_decay = weight_decay 149 | self.distr_output = distr_output 150 | self.loss = loss 151 | self.scaling = scaling 152 | self.patch_len = patch_len 153 | self.stride = stride 154 | self.padding_patch = padding_patch 155 | self.d_model = d_model 156 | self.nhead = nhead 157 | self.dim_feedforward = dim_feedforward 158 | self.dropout = dropout 159 | self.activation = activation 160 | self.norm_first = norm_first 161 | self.num_encoder_layers = num_encoder_layers 162 | self.batch_size = batch_size 163 | self.num_batches_per_epoch = num_batches_per_epoch 164 | 165 | self.train_sampler = train_sampler or ExpectedNumInstanceSampler( 166 | 
num_instances=1.0, min_future=prediction_length 167 | ) 168 | self.validation_sampler = validation_sampler or ValidationSplitSampler( 169 | min_future=prediction_length 170 | ) 171 | 172 | def create_transformation(self) -> Transformation: 173 | return SelectFields( 174 | [ 175 | FieldName.ITEM_ID, 176 | FieldName.INFO, 177 | FieldName.START, 178 | FieldName.TARGET, 179 | ], 180 | allow_missing=True, 181 | ) + AddObservedValuesIndicator( 182 | target_field=FieldName.TARGET, 183 | output_field=FieldName.OBSERVED_VALUES, 184 | ) 185 | 186 | def create_lightning_module(self) -> pl.LightningModule: 187 | return PatchTSTLightningModule( 188 | loss=self.loss, 189 | lr=self.lr, 190 | weight_decay=self.weight_decay, 191 | model_kwargs={ 192 | "prediction_length": self.prediction_length, 193 | "context_length": self.context_length, 194 | "patch_len": self.patch_len, 195 | "stride": self.stride, 196 | "padding_patch": self.padding_patch, 197 | "d_model": self.d_model, 198 | "nhead": self.nhead, 199 | "dim_feedforward": self.dim_feedforward, 200 | "dropout": self.dropout, 201 | "activation": self.activation, 202 | "norm_first": self.norm_first, 203 | "num_encoder_layers": self.num_encoder_layers, 204 | "distr_output": self.distr_output, 205 | "scaling": self.scaling, 206 | }, 207 | ) 208 | 209 | def _create_instance_splitter(self, module: PatchTSTLightningModule, mode: str): 210 | assert mode in ["training", "validation", "test"] 211 | 212 | instance_sampler = { 213 | "training": self.train_sampler, 214 | "validation": self.validation_sampler, 215 | "test": TestSplitSampler(), 216 | }[mode] 217 | 218 | return InstanceSplitter( 219 | target_field=FieldName.TARGET, 220 | is_pad_field=FieldName.IS_PAD, 221 | start_field=FieldName.START, 222 | forecast_start_field=FieldName.FORECAST_START, 223 | instance_sampler=instance_sampler, 224 | past_length=self.context_length, 225 | future_length=self.prediction_length, 226 | time_series_fields=[FieldName.OBSERVED_VALUES], 227 | 
dummy_value=self.distr_output.value_in_support, 228 | ) 229 | 230 | def create_training_data_loader( 231 | self, 232 | data: Dataset, 233 | module: PatchTSTLightningModule, 234 | shuffle_buffer_length: Optional[int] = None, 235 | **kwargs, 236 | ) -> Iterable: 237 | data = Cyclic(data).stream() 238 | instances = self._create_instance_splitter(module, "training").apply( 239 | data, is_train=True 240 | ) 241 | return as_stacked_batches( 242 | instances, 243 | batch_size=self.batch_size, 244 | shuffle_buffer_length=shuffle_buffer_length, 245 | field_names=TRAINING_INPUT_NAMES, 246 | output_type=torch.tensor, 247 | num_batches_per_epoch=self.num_batches_per_epoch, 248 | ) 249 | 250 | def create_validation_data_loader( 251 | self, data: Dataset, module: PatchTSTLightningModule, **kwargs 252 | ) -> Iterable: 253 | instances = self._create_instance_splitter(module, "validation").apply( 254 | data, is_train=True 255 | ) 256 | return as_stacked_batches( 257 | instances, 258 | batch_size=self.batch_size, 259 | field_names=TRAINING_INPUT_NAMES, 260 | output_type=torch.tensor, 261 | ) 262 | 263 | def create_predictor( 264 | self, transformation: Transformation, module 265 | ) -> PyTorchPredictor: 266 | prediction_splitter = self._create_instance_splitter(module, "test") 267 | 268 | return PyTorchPredictor( 269 | input_transform=transformation + prediction_splitter, 270 | input_names=PREDICTION_INPUT_NAMES, 271 | prediction_net=module, 272 | forecast_generator=DistributionForecastGenerator(self.distr_output), 273 | batch_size=self.batch_size, 274 | prediction_length=self.prediction_length, 275 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 276 | ) 277 | -------------------------------------------------------------------------------- /PatchTST/lightning_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import pytorch_lightning as pl 15 | import torch 16 | 17 | from gluonts.core.component import validated 18 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood 19 | 20 | from .module import PatchTSTModel 21 | 22 | 23 | class PatchTSTLightningModule(pl.LightningModule): 24 | """ 25 | A ``pl.LightningModule`` class that can be used to train a 26 | ``PatchTSTModel`` with PyTorch Lightning. 27 | 28 | This is a thin layer around a (wrapped) ``PatchTSTModel`` object, 29 | that exposes the methods to evaluate training and validation loss. 30 | 31 | Parameters 32 | ---------- 33 | model 34 | ``PatchTSTModel`` to be trained. 35 | loss 36 | Loss function to be used for training, 37 | default: ``NegativeLogLikelihood()``. 38 | lr 39 | Learning rate, default: ``1e-3``. 40 | weight_decay 41 | Weight decay regularization parameter, default: ``1e-8``. 
42 | """ 43 | 44 | @validated() 45 | def __init__( 46 | self, 47 | model_kwargs: dict, 48 | loss: DistributionLoss = NegativeLogLikelihood(), 49 | lr: float = 1e-3, 50 | weight_decay: float = 1e-8, 51 | ): 52 | super().__init__() 53 | self.save_hyperparameters() 54 | self.model = PatchTSTModel(**model_kwargs) 55 | self.loss = loss 56 | self.lr = lr 57 | self.weight_decay = weight_decay 58 | 59 | def forward(self, *args, **kwargs): 60 | return self.model.forward(*args, **kwargs) 61 | 62 | def _compute_loss(self, batch): 63 | context = batch["past_target"] 64 | past_observed_values = batch["past_observed_values"] 65 | target = batch["future_target"] 66 | observed_target = batch["future_observed_values"] 67 | 68 | assert context.shape[-1] == self.model.context_length 69 | assert target.shape[-1] == self.model.prediction_length 70 | 71 | distr_args, loc, scale = self.model(context, past_observed_values) 72 | distr = self.model.distr_output.distribution(distr_args, loc, scale) 73 | 74 | return (self.loss(distr, target) * observed_target).sum() / torch.maximum( 75 | torch.tensor(1.0), observed_target.sum() 76 | ) 77 | 78 | def training_step(self, batch, batch_idx: int): # type: ignore 79 | """ 80 | Execute training step. 81 | """ 82 | train_loss = self._compute_loss(batch) 83 | self.log( 84 | "train_loss", 85 | train_loss, 86 | on_epoch=True, 87 | on_step=False, 88 | prog_bar=True, 89 | ) 90 | return train_loss 91 | 92 | def validation_step(self, batch, batch_idx: int): # type: ignore 93 | """ 94 | Execute validation step. 95 | """ 96 | val_loss = self._compute_loss(batch) 97 | self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) 98 | return val_loss 99 | 100 | def configure_optimizers(self): 101 | """ 102 | Returns the optimizer to use. 
103 | """ 104 | return torch.optim.Adam( 105 | self.model.parameters(), 106 | lr=self.lr, 107 | weight_decay=self.weight_decay, 108 | ) 109 | -------------------------------------------------------------------------------- /PatchTST/module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from typing import Tuple 15 | 16 | import numpy as np 17 | import torch 18 | from torch import nn 19 | 20 | from gluonts.core.component import validated 21 | from gluonts.model import Input, InputSpec 22 | from gluonts.torch.distributions import StudentTOutput 23 | from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler 24 | from gluonts.torch.util import unsqueeze_expand 25 | 26 | 27 | def make_linear_layer(dim_in, dim_out): 28 | lin = nn.Linear(dim_in, dim_out) 29 | torch.nn.init.uniform_(lin.weight, -0.07, 0.07) 30 | torch.nn.init.zeros_(lin.bias) 31 | return lin 32 | 33 | 34 | class SinusoidalPositionalEmbedding(nn.Embedding): 35 | """This module produces sinusoidal positional embeddings of any length.""" 36 | 37 | def __init__(self, num_positions: int, embedding_dim: int) -> None: 38 | super().__init__(num_positions, embedding_dim) 39 | self.weight = self._init_weight(self.weight) 40 | 41 | @staticmethod 42 | def _init_weight(out: nn.Parameter) -> nn.Parameter: 43 | """ 44 | Features are not interleaved. 
The cos features are in the 2nd half of the vector. [dim // 2:] 45 | """ 46 | n_pos, dim = out.shape 47 | position_enc = np.array( 48 | [ 49 | [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] 50 | for pos in range(n_pos) 51 | ] 52 | ) 53 | # set early to avoid an error in pytorch-1.8+ 54 | out.requires_grad = False 55 | 56 | sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 57 | out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) 58 | out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) 59 | out.detach_() 60 | return out 61 | 62 | @torch.no_grad() 63 | def forward( 64 | self, input_ids_shape: torch.Size, past_key_values_length: int = 0 65 | ) -> torch.Tensor: 66 | """`input_ids_shape` is expected to be [bsz x seqlen x ...].""" 67 | _, seq_len = input_ids_shape[:2] 68 | positions = torch.arange( 69 | past_key_values_length, 70 | past_key_values_length + seq_len, 71 | dtype=torch.long, 72 | device=self.weight.device, 73 | ) 74 | return super().forward(positions) 75 | 76 | 77 | class PatchTSTModel(nn.Module): 78 | """ 79 | Module implementing the PatchTST model for forecasting. 80 | 81 | Parameters 82 | ---------- 83 | prediction_length 84 | Number of time points to predict. 85 | context_length 86 | Number of time steps prior to prediction time that the model. 87 | distr_output 88 | Distribution to use to evaluate observations and sample predictions. 89 | Default: ``StudentTOutput()``. 
90 | """ 91 | 92 | @validated() 93 | def __init__( 94 | self, 95 | prediction_length: int, 96 | context_length: int, 97 | patch_len: int, 98 | stride: int, 99 | padding_patch: str, 100 | d_model: int, 101 | nhead: int, 102 | dim_feedforward: int, 103 | dropout: float, 104 | activation: str, 105 | norm_first: bool, 106 | num_encoder_layers: int, 107 | scaling: str, 108 | distr_output=StudentTOutput(), 109 | ) -> None: 110 | super().__init__() 111 | 112 | assert prediction_length > 0 113 | assert context_length > 0 114 | 115 | self.prediction_length = prediction_length 116 | self.context_length = context_length 117 | self.patch_len = patch_len 118 | self.stride = stride 119 | self.d_model = d_model 120 | self.padding_patch = padding_patch 121 | self.distr_output = distr_output 122 | 123 | if scaling == "mean": 124 | self.scaler = MeanScaler(keepdim=True) 125 | elif scaling == "std": 126 | self.scaler = StdScaler(keepdim=True) 127 | else: 128 | self.scaler = NOPScaler(keepdim=True) 129 | 130 | self.patch_num = int((context_length - patch_len) / stride + 1) 131 | if padding_patch == "end": # can be modified to general case 132 | self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride)) 133 | self.patch_num += 1 134 | 135 | # project from patch_len + 2 features (loc and scale) to d_model 136 | self.patch_proj = make_linear_layer(patch_len + 2, d_model) 137 | 138 | self.positional_encoding = SinusoidalPositionalEmbedding( 139 | self.patch_num, d_model 140 | ) 141 | 142 | layer_norm_eps: float = 1e-5 143 | encoder_layer = nn.TransformerEncoderLayer( 144 | d_model=d_model, 145 | nhead=nhead, 146 | dim_feedforward=dim_feedforward, 147 | dropout=dropout, 148 | activation=activation, 149 | layer_norm_eps=layer_norm_eps, 150 | batch_first=True, 151 | norm_first=norm_first, 152 | ) 153 | encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) 154 | self.encoder = nn.TransformerEncoder( 155 | encoder_layer, num_encoder_layers, encoder_norm 156 | ) 157 | 158 | 
self.flatten = nn.Linear(d_model * self.patch_num, prediction_length * d_model) 159 | 160 | self.args_proj = self.distr_output.get_args_proj(d_model) 161 | 162 | def describe_inputs(self, batch_size=1) -> InputSpec: 163 | return InputSpec( 164 | { 165 | "past_target": Input( 166 | shape=(batch_size, self.context_length), dtype=torch.float 167 | ), 168 | "past_observed_values": Input( 169 | shape=(batch_size, self.context_length), dtype=torch.float 170 | ), 171 | }, 172 | torch.zeros, 173 | ) 174 | 175 | def forward( 176 | self, 177 | past_target: torch.Tensor, 178 | past_observed_values: torch.Tensor, 179 | ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: 180 | # scale the input 181 | past_target_scaled, loc, scale = self.scaler(past_target, past_observed_values) 182 | 183 | # do patching 184 | if self.padding_patch == "end": 185 | past_target_scaled = self.padding_patch_layer(past_target_scaled) 186 | past_target_patches = past_target_scaled.unfold( 187 | dimension=1, size=self.patch_len, step=self.stride 188 | ) 189 | 190 | # add loc and scale to past_target_patches as additional features 191 | log_abs_loc = loc.abs().log1p() 192 | log_scale = scale.log() 193 | expanded_static_feat = unsqueeze_expand( 194 | torch.cat([log_abs_loc, log_scale], dim=-1), 195 | dim=1, 196 | size=past_target_patches.shape[1], 197 | ) 198 | inputs = torch.cat((past_target_patches, expanded_static_feat), dim=-1) 199 | 200 | # project patches 201 | enc_in = self.patch_proj(inputs) 202 | embed_pos = self.positional_encoding(enc_in.size()) 203 | 204 | # transformer encoder with positional encoding 205 | enc_out = self.encoder(enc_in + embed_pos) 206 | 207 | # flatten and project to prediction length * d_model 208 | flatten_out = self.flatten(enc_out.flatten(start_dim=1)) 209 | 210 | # project to distribution arguments 211 | distr_args = self.args_proj( 212 | flatten_out.reshape(-1, self.prediction_length, self.d_model) 213 | ) 214 | return distr_args, loc, scale 215 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvTSMixer, MlpTSMixer and TsT 2 | 3 | Time Series Extension of "Patches Are All You Need" 4 | 5 | Implementations of some prototypical time series mixers based on Conv, MLP, and ViT architectures, modified for the probabilistic multivariate forecasting use case, where the emission head is currently an "independent same-family" distribution, e.g., a diagonal Student-T. 6 | 7 | In everything that follows, the inputs are typically 4-Tensors of shape `[Batch, Variate-dim, Context-length, 1+Features]`, and during training, the subsequent prediction-window values are given with shape `[B, Variate-dim, Pred-length]`. The inputs are embedded via 2d-conv to obtain patch embeddings: 8 | 9 | ![Screenshot 2024-10-25 at 09 29 21](https://github.com/user-attachments/assets/5ecc92c5-a115-44a1-95de-971d2d34ed58) 10 | 11 | ## ConvTSMixer 12 | 13 | ![Screenshot 2024-10-25 at 09 29 55](https://github.com/user-attachments/assets/2033b75b-6105-4c5d-bfc6-aaf9c8330d5d) 14 | 15 | 16 | ## MlpTSMixer 17 | 18 | ![Screenshot 2024-10-25 at 09 30 09](https://github.com/user-attachments/assets/14db4edd-2004-4152-acc7-3c926c97c306) 19 | 20 | ## TsT (ViT style) 21 | 22 | ![Screenshot 2024-10-25 at 09 30 40](https://github.com/user-attachments/assets/6a5f88d4-8ded-41fe-9cc5-3ce1ab9df5b9) 23 | 24 | ## Output head 25 | 26 | ![Screenshot 2024-10-25 at 09 31 12](https://github.com/user-attachments/assets/324bec4c-d4bf-40ea-9084-413d1c647870) 27 | -------------------------------------------------------------------------------- /TSM-hyperparameter_tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "86ff733d-8ea7-4d77-8d2a-cd329ab8f385", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | 
"import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "import torch\n", 17 | "\n", 18 | "from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n", 19 | "from gluonts.dataset.repository.datasets import get_dataset\n", 20 | "from gluonts.evaluation.backtest import make_evaluation_predictions\n", 21 | "from gluonts.evaluation import MultivariateEvaluator\n", 22 | "\n", 23 | "from TSMixer import TSMixerEstimator\n", 24 | "import random\n", 25 | "import numpy as np\n", 26 | "import time\n", 27 | "import optuna\n", 28 | "from optuna.samplers import TPESampler" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 4, 34 | "id": "944676b0-2a9f-4301-92e6-f382f5693639", 35 | "metadata": { 36 | "tags": [] 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "class TSMixerObjective:\n", 41 | " def __init__(\n", 42 | " self,\n", 43 | " dataset,\n", 44 | " train_grouper,\n", 45 | " test_grouper,\n", 46 | " metric_type=\"m_sum_mean_wQuantileLoss\",\n", 47 | " ):\n", 48 | " self.metric_type = metric_type\n", 49 | " self.dataset = dataset\n", 50 | " self.dataset_train = train_grouper(self.dataset.train)\n", 51 | " self.dataset_test = test_grouper(self.dataset.test)\n", 52 | "\n", 53 | " def get_params(self, trial) -> dict:\n", 54 | " return {\n", 55 | " \"context_length\": trial.suggest_int(\n", 56 | " \"context_length\",\n", 57 | " dataset.metadata.prediction_length,\n", 58 | " dataset.metadata.prediction_length * 5,\n", 59 | " 1,\n", 60 | " ),\n", 61 | " \"batch_size\": trial.suggest_int(\"batch_size\", 32, 256, 32),\n", 62 | " \"depth\": trial.suggest_int(\"depth\", 2, 16, 4),\n", 63 | " \"expansion_factor\": trial.suggest_int(\"expansion_factor\", 2, 8, 2),\n", 64 | " \"dim\": trial.suggest_int(\"dim\", 16, 64, 16),\n", 65 | " }\n", 66 | "\n", 67 | " def __call__(self, trial):\n", 68 | " params = self.get_params(trial)\n", 69 | " estimator = TSMixerEstimator(\n", 70 | " # 
distr_output=StudentTOutput(dim=int(dataset.metadata.feat_static_cat[0].cardinality)),\n", 71 | " input_size=int(self.dataset.metadata.feat_static_cat[0].cardinality),\n", 72 | " prediction_length=self.dataset.metadata.prediction_length,\n", 73 | " context_length=params[\"context_length\"],\n", 74 | " freq=self.dataset.metadata.freq,\n", 75 | " scaling=\"std\",\n", 76 | " depth=params[\"depth\"],\n", 77 | " dim=params[\"dim\"],\n", 78 | " expansion_factor=params[\"expansion_factor\"],\n", 79 | " batch_size=params[\"batch_size\"],\n", 80 | " num_batches_per_epoch=100,\n", 81 | " trainer_kwargs=dict(accelerator=\"cuda\", max_epochs=30),\n", 82 | " )\n", 83 | " predictor = estimator.train(\n", 84 | " training_data=self.dataset_train, num_workers=8, shuffle_buffer_length=1024\n", 85 | " )\n", 86 | "\n", 87 | " forecast_it, ts_it = make_evaluation_predictions(\n", 88 | " dataset=self.dataset_test, predictor=predictor, num_samples=100\n", 89 | " )\n", 90 | " forecasts = list(forecast_it)\n", 91 | " tss = list(ts_it)\n", 92 | " evaluator = MultivariateEvaluator(\n", 93 | " quantiles=(np.arange(20) / 20.0)[1:], target_agg_funcs={\"sum\": np.sum}\n", 94 | " )\n", 95 | " agg_metrics, _ = evaluator(iter(tss), iter(forecasts))\n", 96 | " return agg_metrics[self.metric_type]" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "d6c1da22-755e-44d5-86b6-10465d8c25e8", 103 | "metadata": { 104 | "tags": [] 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "dataset = get_dataset(\"solar_nips\", regenerate=False)\n", 109 | "train_grouper = MultivariateGrouper(\n", 110 | " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)\n", 111 | ")\n", 112 | "\n", 113 | "test_grouper = MultivariateGrouper(\n", 114 | " num_test_dates=int(len(dataset.test) / len(dataset.train)),\n", 115 | " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n", 116 | ")\n", 117 | "dataset_train = train_grouper(dataset.train)\n", 118 | 
"dataset_test = test_grouper(dataset.test)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "9ded5c21-ea78-4f99-98e3-1573f5abfbf2", 125 | "metadata": { 126 | "tags": [] 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "seed = 42\n", 131 | "random.seed(seed)\n", 132 | "torch.manual_seed(seed)\n", 133 | "start_time = time.time()\n", 134 | "sampler = TPESampler(seed=seed)\n", 135 | "study = optuna.create_study(sampler=sampler, direction=\"minimize\")\n", 136 | "study.optimize(TSMixerObjective(dataset, train_grouper, test_grouper), n_trials=10)\n", 137 | "\n", 138 | "print(\"Number of finished trials: {}\".format(len(study.trials)))\n", 139 | "\n", 140 | "print(\"Best trial:\")\n", 141 | "trial = study.best_trial\n", 142 | "\n", 143 | "print(\" Value: {}\".format(trial.value))\n", 144 | "print(\" Params: \")\n", 145 | "for key, value in trial.params.items():\n", 146 | " print(\" {}: {}\".format(key, value))\n", 147 | "print(time.time() - start_time)" 148 | ] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "py38", 154 | "language": "python", 155 | "name": "py38" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.8.10" 168 | } 169 | }, 170 | "nbformat": 4, 171 | "nbformat_minor": 5 172 | } 173 | -------------------------------------------------------------------------------- /TSMixer/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import TSMixerModel 2 | from .lightning_module import TSMixerLightningModule 3 | from .estimator import TSMixerEstimator 4 | 5 | __all__ = ["TSMixerModel", "TSMixerLightningModule", "TSMixerEstimator"] 6 | 
-------------------------------------------------------------------------------- /TSMixer/lightning_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import pytorch_lightning as pl 15 | import torch 16 | 17 | from gluonts.core.component import validated 18 | 19 | from .module import TSMixerModel 20 | 21 | 22 | class TSMixerLightningModule(pl.LightningModule): 23 | """ 24 | A ``pl.LightningModule`` class that can be used to train a 25 | ``TSMixerModel`` with PyTorch Lightning. 26 | 27 | This is a thin layer around a (wrapped) ``TSMixerModel`` object, 28 | that exposes the methods to evaluate training and validation loss. 29 | 30 | Parameters 31 | ---------- 32 | model 33 | ``TSMixerModel`` to be trained. 34 | lr 35 | Learning rate, default: ``1e-3``. 36 | weight_decay 37 | Weight decay regularization parameter, default: ``1e-8``. 
38 | """ 39 | 40 | @validated() 41 | def __init__( 42 | self, 43 | model_kwargs: dict, 44 | lr: float = 1e-3, 45 | weight_decay: float = 1e-8, 46 | ): 47 | super().__init__() 48 | self.save_hyperparameters() 49 | self.model = TSMixerModel(**model_kwargs) 50 | self.lr = lr 51 | self.weight_decay = weight_decay 52 | 53 | def forward(self, *args, **kwargs): 54 | distr_args, loc, scale = self.model.forward(*args, **kwargs) 55 | distr = self.model.distr_output.distribution(distr_args, loc, scale) 56 | return distr.sample((self.model.num_parallel_samples,)).reshape( 57 | -1, 58 | self.model.num_parallel_samples, 59 | self.model.prediction_length, 60 | self.model.input_size, 61 | ) 62 | 63 | def _compute_loss(self, batch): 64 | past_target = batch["past_target"] 65 | past_observed_values = batch["past_observed_values"] 66 | target = batch["future_target"] 67 | observed_target = batch["future_observed_values"] 68 | 69 | assert past_target.shape[1] == self.model.context_length 70 | assert target.shape[1] == self.model.prediction_length 71 | 72 | distr_args, loc, scale = self.model( 73 | past_target=past_target, 74 | past_observed_values=past_observed_values, 75 | past_time_feat=batch["past_time_feat"], 76 | future_time_feat=batch["future_time_feat"], 77 | ) 78 | loss_values = self.model.distr_output.loss( 79 | target=target, distr_args=distr_args, loc=loc, scale=scale 80 | ) 81 | return (loss_values * observed_target).sum() / torch.maximum( 82 | torch.tensor(1.0), observed_target.sum() 83 | ) 84 | 85 | def training_step(self, batch, batch_idx: int): # type: ignore 86 | """ 87 | Execute training step. 88 | """ 89 | train_loss = self._compute_loss(batch) 90 | self.log( 91 | "train_loss", 92 | train_loss, 93 | on_epoch=True, 94 | on_step=False, 95 | prog_bar=True, 96 | ) 97 | return train_loss 98 | 99 | def validation_step(self, batch, batch_idx: int): # type: ignore 100 | """ 101 | Execute validation step. 
102 | """ 103 | val_loss = self._compute_loss(batch) 104 | self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) 105 | return val_loss 106 | 107 | def configure_optimizers(self): 108 | """ 109 | Returns the optimizer to use. 110 | """ 111 | return torch.optim.Adam( 112 | self.model.parameters(), 113 | lr=self.lr, 114 | weight_decay=self.weight_decay, 115 | ) 116 | -------------------------------------------------------------------------------- /TSMixer/module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from typing import Tuple, Optional 15 | 16 | import torch 17 | from torch import nn 18 | from einops.layers.torch import Rearrange 19 | from einops import rearrange 20 | 21 | from gluonts.core.component import validated 22 | from gluonts.model import Input, InputSpec 23 | from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler 24 | from gluonts.torch.distributions import StudentTOutput 25 | 26 | 27 | class PreNormResidual(nn.Module): 28 | """ 29 | Pre-Normalization Residual Block. Applies Layer-Normalization over the features and prediction_length dimensions. 
30 | 31 | :argument 32 | - dim (int): input dimension 33 | - prediction_length (int): prediction length 34 | - fn (function): function to be applied 35 | 36 | :return 37 | - x (tensor): output tensor 38 | """ 39 | 40 | def __init__(self, dim: int, prediction_length: int, fn): 41 | super().__init__() 42 | self.fn = fn 43 | self.norm = nn.LayerNorm([dim, prediction_length]) 44 | 45 | def forward(self, x): 46 | return self.fn(self.norm(x)) + x 47 | 48 | 49 | class CtxMap(nn.Module): 50 | """ 51 | Module implementing the mapping from the context-length to the forecast length for TSMixer. 52 | 53 | :argument 54 | - context_length (int): context length 55 | - prediction_length (int): prediction length 56 | 57 | :return 58 | - x (tensor): output tensor 59 | """ 60 | 61 | def __init__(self, context_length: int, prediction_length: int): 62 | super().__init__() 63 | self.context_length = context_length 64 | self.prediction_length = prediction_length 65 | 66 | self.fc = nn.Sequential( 67 | Rearrange("b nf h ns -> b nf ns h"), 68 | nn.Linear(self.context_length, self.prediction_length), 69 | Rearrange("b nf ns h -> b nf h ns"), 70 | ) 71 | 72 | def forward(self, x): 73 | out = self.fc(x) 74 | return out 75 | 76 | 77 | class MLPTimeBlock(nn.Module): 78 | """MLP for time embedding. 79 | 80 | :argument 81 | - prediction_length (int): prediction length 82 | - dropout (float): dropout rate 83 | 84 | :return 85 | - x (tensor): output tensor 86 | """ 87 | 88 | def __init__(self, prediction_length: int, dropout: float = 0.1): 89 | super().__init__() 90 | 91 | self.time_mlp = nn.Sequential( 92 | nn.Linear(prediction_length, prediction_length), 93 | nn.ReLU(), 94 | nn.Dropout(dropout), 95 | ) 96 | 97 | def forward(self, x): 98 | out = self.time_mlp(x) 99 | return out 100 | 101 | 102 | class MLPFeatBlock(nn.Module): 103 | """MLPs for feature embedding. 
104 | 105 | :argument 106 | - in_channels (int): input channels 107 | - hidden_size (int): hidden size 108 | - dropout (float): dropout rate, default 0.1 109 | 110 | :return 111 | - x (tensor): output tensor 112 | """ 113 | 114 | def __init__(self, in_channels: int, hidden_size: int, dropout: float = 0.1): 115 | super().__init__() 116 | 117 | self.feat_mlp = nn.Sequential( 118 | Rearrange("b ns nf h -> b ns h nf"), 119 | nn.Linear(in_channels, hidden_size), 120 | nn.ReLU(), 121 | nn.Dropout(dropout), 122 | nn.Linear(hidden_size, in_channels), 123 | nn.Dropout(dropout), 124 | Rearrange("b ns h nf -> b ns nf h"), 125 | ) 126 | 127 | def forward(self, x): 128 | out = self.feat_mlp(x) 129 | return out 130 | 131 | 132 | class MLPFeatMap(nn.Module): 133 | """MLP on feature domain. 134 | 135 | :argument 136 | - in_channels (int): input channels 137 | - hidden_size (int): hidden size 138 | - dropout (float): dropout rate 139 | 140 | :return 141 | - x (tensor): output tensor 142 | """ 143 | 144 | def __init__(self, in_channels: int, hidden_size: int, dropout: float = 0.1): 145 | super().__init__() 146 | self.fc = nn.Sequential( 147 | Rearrange("b nf h ns -> b h ns nf"), 148 | nn.Linear(in_channels, hidden_size), 149 | nn.ReLU(), 150 | nn.Dropout(dropout), 151 | Rearrange("b h ns nf -> b nf h ns"), 152 | ) 153 | 154 | def forward(self, x): 155 | out = self.fc(x) 156 | return out 157 | 158 | 159 | class TSMixerModel(nn.Module): 160 | """ 161 | Module implementingTSMixer for forecasting. 162 | 163 | Parameters 164 | ---------- 165 | prediction_length 166 | Number of time points to predict. 167 | context_length 168 | Number of time steps prior to prediction time that the model. 169 | distr_output 170 | Distribution to use to evaluate observations and sample predictions. 171 | Default: ``StudentTOutput()``. 172 | batch_norm 173 | Whether to apply batch normalization. Default: ``False``. 
174 | """ 175 | 176 | @validated() 177 | def __init__( 178 | self, 179 | prediction_length: int, 180 | context_length: int, 181 | scaling: str, 182 | input_size: int, 183 | depth: int, 184 | dim: int, 185 | expansion_factor: int = 4, 186 | dropout: float = 0.1, 187 | num_feat_dynamic_real: int = 0, 188 | num_feat_static_real: int = 0, 189 | num_feat_static_cat: int = 0, 190 | distr_output=StudentTOutput(), 191 | num_parallel_samples: int = 100, 192 | batch_norm: bool = False, 193 | ) -> None: 194 | super().__init__() 195 | 196 | assert prediction_length > 0 197 | assert context_length > 0 198 | assert depth > 0 199 | 200 | self.distr_output = distr_output 201 | self.prediction_length = prediction_length 202 | self.context_length = context_length 203 | self.input_size = input_size 204 | self.num_feat_static_real = num_feat_static_real 205 | self.num_feat_dynamic_real = num_feat_dynamic_real 206 | self.num_parallel_samples = num_parallel_samples 207 | 208 | if scaling == "mean": 209 | self.scaler = MeanScaler(keepdim=True, dim=1) 210 | elif scaling == "std": 211 | self.scaler = StdScaler(keepdim=True, dim=1) 212 | else: 213 | self.scaler = NOPScaler(keepdim=True, dim=1) 214 | 215 | self.linear_map = CtxMap(self.context_length, self.prediction_length) 216 | self.mlp_x = MLPFeatMap(self._number_of_features, dim, dropout) 217 | self.mlp_z = MLPFeatMap(self.num_feat_dynamic_real, dim, dropout) 218 | 219 | dim_xz = dim * 2 # since x and z are concatenated along the feature dimension 220 | 221 | self.mlp_mixer_block = nn.Sequential( 222 | Rearrange("b nf h ns -> b ns nf h"), 223 | *[ 224 | nn.Sequential( 225 | PreNormResidual( 226 | dim_xz, 227 | self.prediction_length, 228 | MLPTimeBlock(self.prediction_length, dropout), 229 | ), 230 | PreNormResidual( 231 | dim_xz, 232 | self.prediction_length, 233 | MLPFeatBlock(dim_xz, dim_xz * expansion_factor, dropout), 234 | ), 235 | ) 236 | for _ in range(depth) 237 | ], 238 | Rearrange("b ns nf h -> b h ns nf"), 239 | ) 240 | 241 
| self.args_proj = self.distr_output.get_args_proj(dim_xz) 242 | 243 | @property 244 | def _number_of_features(self) -> int: 245 | return ( 246 | self.num_feat_dynamic_real 247 | + self.num_feat_static_real 248 | + 3 # 1 + the log(loc) + log1p(scale) 249 | ) 250 | 251 | def describe_inputs(self, batch_size=1) -> InputSpec: 252 | return InputSpec( 253 | { 254 | "past_target": Input( 255 | shape=(batch_size, self.context_length, self.input_size), 256 | dtype=torch.float, 257 | ), 258 | "past_observed_values": Input( 259 | shape=(batch_size, self.context_length, self.input_size), 260 | dtype=torch.float, 261 | ), 262 | }, 263 | torch.zeros, 264 | ) 265 | 266 | def forward( 267 | self, 268 | feat_static_cat: Optional[torch.Tensor] = None, 269 | feat_static_real: Optional[torch.Tensor] = None, 270 | past_time_feat: Optional[torch.Tensor] = None, 271 | past_target: Optional[torch.Tensor] = None, 272 | past_observed_values: Optional[torch.Tensor] = None, 273 | future_time_feat: Optional[torch.Tensor] = None, 274 | future_target: Optional[torch.Tensor] = None, 275 | future_observed_values: Optional[torch.Tensor] = None, 276 | ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: 277 | past_target_scaled, loc, scale = self.scaler(past_target, past_observed_values) 278 | 279 | past_target_scaled = past_target_scaled.unsqueeze(1) # channel dim 280 | 281 | log_abs_loc = loc.sign().unsqueeze(1).expand_as(past_target_scaled) * loc.abs().log1p().unsqueeze(1).expand_as(past_target_scaled) 282 | log_scale = scale.log().unsqueeze(1).expand_as(past_target_scaled) 283 | 284 | past_time_feat = ( 285 | past_time_feat.transpose(2, 1) 286 | .unsqueeze(-1) 287 | .repeat_interleave(dim=-1, repeats=self.input_size) 288 | ) 289 | 290 | # x: historical data of shape (batch_size, Cx, context_length, n_series) 291 | # z: future time-varying features of shape (batch_size, Cz, prediction_length, n_series) 292 | # s: static features of shape (batch_size, Cs, prediction_length, n_series) 
293 | 294 | # b: batch 295 | # h: fcst_h 296 | # ns: n_series 297 | # nf: n_features 298 | 299 | x = torch.cat( 300 | ( 301 | past_target_scaled, 302 | log_abs_loc, 303 | log_scale, 304 | past_time_feat, 305 | ), 306 | dim=1, 307 | ) 308 | 309 | future_time_feat_repeat = future_time_feat.unsqueeze(2).repeat_interleave( 310 | dim=2, repeats=self.input_size 311 | ) 312 | 313 | z = rearrange(future_time_feat_repeat, "b h ns nf -> b nf h ns") 314 | 315 | x = self.linear_map(x) 316 | x_prime = self.mlp_x(x) 317 | z_prime = self.mlp_z(z) 318 | y_prime = torch.cat([x_prime, z_prime], dim=1) 319 | nn_out = self.mlp_mixer_block(y_prime) # self.mixer_blocks(y_prime, s) 320 | distr_args = self.args_proj(nn_out) 321 | 322 | return distr_args, loc, scale 323 | -------------------------------------------------------------------------------- /TSMixer/version_old/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | 14 | from .module import TSMixerModel 15 | from .lightning_module import TSMixerLightningModule 16 | from .estimator import TSMixerEstimator 17 | 18 | __all__ = ["TSMixerModel", "TSMixerLightningModule", "TSMixerEstimator"] 19 | -------------------------------------------------------------------------------- /TSMixer/version_old/lightning_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import pytorch_lightning as pl 15 | import torch 16 | 17 | from gluonts.core.component import validated 18 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood 19 | 20 | from .module import TSMixerModel 21 | 22 | 23 | class TSMixerLightningModule(pl.LightningModule): 24 | """ 25 | A ``pl.LightningModule`` class that can be used to train a 26 | ``TSMixerModel`` with PyTorch Lightning. 27 | 28 | This is a thin layer around a (wrapped) ``TSMixerModel`` object, 29 | that exposes the methods to evaluate training and validation loss. 30 | 31 | Parameters 32 | ---------- 33 | model 34 | ``TSMixerModel`` to be trained. 35 | loss 36 | Loss function to be used for training, 37 | default: ``NegativeLogLikelihood()``. 38 | lr 39 | Learning rate, default: ``1e-3``. 40 | weight_decay 41 | Weight decay regularization parameter, default: ``1e-8``. 
42 | """ 43 | 44 | @validated() 45 | def __init__( 46 | self, 47 | model_kwargs: dict, 48 | loss: DistributionLoss = NegativeLogLikelihood(), 49 | lr: float = 1e-3, 50 | weight_decay: float = 1e-8, 51 | ): 52 | super().__init__() 53 | self.save_hyperparameters() 54 | self.model = TSMixerModel(**model_kwargs) 55 | self.loss = loss 56 | self.lr = lr 57 | self.weight_decay = weight_decay 58 | 59 | def forward(self, *args, **kwargs): 60 | return self.model.forward(*args, **kwargs) 61 | 62 | def _compute_loss(self, batch): 63 | past_target = batch["past_target"] 64 | past_observed_values = batch["past_observed_values"] 65 | target = batch["future_target"] 66 | observed_target = batch["future_observed_values"] 67 | 68 | assert past_target.shape[-1] == self.model.context_length 69 | assert target.shape[-1] == self.model.prediction_length 70 | 71 | distr_args, loc, scale = self.model( 72 | past_target=past_target, past_observed_values=past_observed_values 73 | ) 74 | distr = self.model.distr_output.distribution(distr_args, loc, scale) 75 | 76 | return (self.loss(distr, target) * observed_target).sum() / torch.maximum( 77 | torch.tensor(1.0), observed_target.sum() 78 | ) 79 | 80 | def training_step(self, batch, batch_idx: int): # type: ignore 81 | """ 82 | Execute training step. 83 | """ 84 | train_loss = self._compute_loss(batch) 85 | self.log( 86 | "train_loss", 87 | train_loss, 88 | on_epoch=True, 89 | on_step=False, 90 | prog_bar=True, 91 | ) 92 | return train_loss 93 | 94 | def validation_step(self, batch, batch_idx: int): # type: ignore 95 | """ 96 | Execute validation step. 97 | """ 98 | val_loss = self._compute_loss(batch) 99 | self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) 100 | return val_loss 101 | 102 | def configure_optimizers(self): 103 | """ 104 | Returns the optimizer to use. 
105 | """ 106 | return torch.optim.Adam( 107 | self.model.parameters(), 108 | lr=self.lr, 109 | weight_decay=self.weight_decay, 110 | ) 111 | -------------------------------------------------------------------------------- /TSMixer/version_old/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MLP_Time(nn.Module): 5 | """MLP for time embedding. According to the paper, the authors employ a single layer perceptron. 6 | 7 | :argument 8 | - in_channels (int): input channels 9 | - ts_length (int): time series length 10 | - dropout (float): dropout rate 11 | - batch_norm (bool): whether to apply batch normalization 12 | 13 | :return 14 | - x (tensor): output tensor of shape (batch_size, ts_length, in_channels) 15 | """ 16 | 17 | def __init__( 18 | self, 19 | in_channels: int, 20 | ts_length: int, 21 | dropout: float = 0.1, 22 | batch_norm: bool = True, 23 | ): 24 | super().__init__() 25 | 26 | # BatchNorm2d is applied to the time dimension 27 | self.batch_norm2d = nn.BatchNorm2d(ts_length) if batch_norm is True else None 28 | self.in_channels = in_channels 29 | 30 | # MLP for time embedding 31 | self.time_mlp = nn.Sequential( 32 | nn.Linear(ts_length, ts_length), nn.ReLU(), nn.Dropout(dropout) 33 | ) 34 | 35 | def forward(self, x): 36 | if self.batch_norm2d is not None: 37 | x_norm = x.unsqueeze(-1) if self.in_channels == 1 else x 38 | x_norm = self.batch_norm2d(x_norm) 39 | x_norm = x_norm.squeeze(-1) if self.in_channels == 1 else x_norm 40 | else: 41 | x_norm = x 42 | x_time = self.time_mlp(x_norm.transpose(1, 2)).transpose(1, 2) 43 | return ( 44 | x + x_time 45 | ) # not sure if we need a residual connection here, the paper doesn't mention it. 46 | 47 | 48 | class MLP_Feat(nn.Module): 49 | """MLPs for feature embedding. 
50 | 51 | :argument 52 | - in_channels (int): input channels 53 | - embed_dim (int): embedding dimension 54 | - dropout (float): dropout rate, default 0.1 55 | - batch_norm (bool): whether to apply batch normalization 56 | 57 | :return 58 | - x (tensor): output tensor of shape (batch_size, ts_length, in_channels) 59 | """ 60 | 61 | def __init__( 62 | self, 63 | in_channels: int, 64 | embed_dim: int, 65 | dropout: float = 0.1, 66 | batch_norm: bool = True, 67 | ): 68 | super().__init__() 69 | 70 | # BatchNorm2d is applied to the feature dimension 71 | self.batch_norm2d = nn.BatchNorm2d(in_channels) if batch_norm is True else None 72 | self.in_channels = in_channels 73 | 74 | # MLPs for feature embedding 75 | self.feat_mlp1 = nn.Sequential( 76 | nn.Linear(in_channels, embed_dim), nn.ReLU(), nn.Dropout(dropout) 77 | ) 78 | 79 | self.feat_mlp2 = nn.Sequential( 80 | nn.Linear(embed_dim, in_channels), nn.Dropout(dropout) 81 | ) 82 | 83 | def forward(self, x): 84 | if self.batch_norm2d is not None: 85 | x_norm = ( 86 | x.transpose(1, 2).unsqueeze(-1) 87 | if self.in_channels == 1 88 | else x.transpose(1, 2) 89 | ) 90 | x_norm = self.batch_norm2d(x_norm) 91 | x_norm = ( 92 | x_norm.transpose(1, 2).squeeze(-1) 93 | if self.in_channels == 1 94 | else x_norm.transpose(1, 2) 95 | ) 96 | else: 97 | x_norm = x 98 | x_feat = self.feat_mlp1(x_norm) 99 | return x + self.feat_mlp2(x_feat) 100 | 101 | 102 | class Mixer_Block(nn.Module): 103 | """Mixer block. 
104 | 105 | :argument 106 | - in_channels (int): input channels 107 | - ts_length (int): time series length 108 | - embed_dim (int): embedding dimension 109 | - dropout (float): dropout rate, default 0.1 110 | - batch_norm (bool): whether to apply batch normalization 111 | 112 | :return 113 | - x (tensor): output tensor of shape (batch_size, ts_length, in_channels) 114 | """ 115 | 116 | def __init__( 117 | self, 118 | in_channels: int, 119 | ts_length: int, 120 | embed_dim: int, 121 | dropout: float = 0.1, 122 | batch_norm: bool = True, 123 | ): 124 | super().__init__() 125 | self.mlp_time = MLP_Time(in_channels, ts_length, dropout, batch_norm) 126 | self.mlp_feat = MLP_Feat(in_channels, embed_dim, dropout, batch_norm) 127 | 128 | def forward(self, x): 129 | x = self.mlp_time(x) 130 | x = self.mlp_feat(x) 131 | return x 132 | 133 | 134 | class TS_Mixer(nn.Module): 135 | """Time Series Mixer. 136 | 137 | :argument 138 | - in_channels (int): input channels 139 | - ts_length (int): time series length 140 | - embed_dim (int): embedding dimension 141 | - num_blocks (int): number of mixer blocks 142 | - fcst_h (int): forecast horizon 143 | - dropout (float): dropout rate, default 0.1 144 | 145 | :return 146 | - x (tensor): output tensor of shape (batch_size, fcst_h, in_channels) 147 | 148 | : source 149 | - Algorithm 1 in [TSMixer: An all-MLP Architecture for Time Series Forecasting] (https://arxiv.org/pdf/2303.06053.pdf) 150 | """ 151 | 152 | def __init__( 153 | self, 154 | in_channels: int, 155 | ts_length: int, 156 | embed_dim: int, 157 | num_blocks: int, 158 | fcst_h: int, 159 | dropout: float = 0.1, 160 | ): 161 | super().__init__() 162 | self.mixer_blocks = nn.Sequential( 163 | *[ 164 | Mixer_Block(in_channels, ts_length, embed_dim, dropout) 165 | for _ in range(num_blocks) 166 | ] 167 | ) 168 | self.fc = nn.Linear(ts_length, fcst_h) 169 | 170 | def forward(self, x): 171 | x = self.mixer_blocks(x) 172 | x = self.fc(x.transpose(1, 2)) 173 | return x.transpose(1, 2) 
174 | -------------------------------------------------------------------------------- /TSMixer/version_old/model_auxiliary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .model import MLP_Time as MLP_Time_Block 4 | from .model import MLP_Feat as MLP_Feat_Block 5 | 6 | 7 | class MLP_Feat(nn.Module): 8 | """MLP on feature domain. 9 | 10 | :argument 11 | - n_feat (int): number of input features 12 | - embed_dim (int): embedding dimension 13 | - dropout (float): dropout rate, 14 | 15 | :return 16 | - x (tensor): output tensor of shape (batch_size, fcst_h, embed_dim) 17 | """ 18 | 19 | def __init__(self, n_feat: int, embed_dim: int, dropout: float = 0.1): 20 | super().__init__() 21 | self.mlp1 = nn.Sequential( 22 | nn.Linear(n_feat, embed_dim), nn.ReLU(), nn.Dropout(dropout) 23 | ) 24 | 25 | self.mlp2 = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.Dropout(dropout)) 26 | 27 | # For cases where the input and the output dimensions are different, we apply an additional 28 | # linear transformation on the residual connection. 29 | self.projector = nn.Linear(n_feat, embed_dim) if n_feat != embed_dim else None 30 | 31 | # Batch normalization 32 | self.bn = nn.BatchNorm1d(embed_dim) 33 | 34 | def forward(self, x): 35 | v = self.mlp1(x) 36 | u = self.mlp2(v) 37 | h = x if self.projector is None else self.projector(x) 38 | out = self.bn((h + u).transpose(1, 2)).transpose(1, 2) 39 | return out 40 | 41 | 42 | class Mixer_Block(nn.Module): 43 | """Mixer block. 
44 | 45 | :argument 46 | - n_feat (int): number of input features 47 | - n_static_feat (int): number of static features 48 | - fcst_h (int): forecast horizon 49 | - embed_dim (int): embedding dimension 50 | - dropout (float): dropout rate, default 0.1 51 | 52 | :return 53 | - x (tensor): output tensor of shape (batch_size, fcst_h, embed_dim*2) 54 | """ 55 | 56 | def __init__( 57 | self, 58 | n_feat: int, 59 | n_static_feat: int, 60 | fcst_h: int, 61 | embed_dim: int, 62 | dropout: float = 0.1, 63 | ): 64 | super().__init__() 65 | 66 | self.mlp_time = MLP_Time_Block(fcst_h, dropout) 67 | self.mlp_s = MLP_Feat(n_static_feat, embed_dim, dropout) 68 | self.mlp_feat = MLP_Feat_Block(n_feat, embed_dim, dropout) 69 | 70 | # We apply an additional linear transformation on the output of MLP_Feat_Block. Otherwise, each block increase 71 | # the output dimension by embed_dim 72 | self.projector = nn.Linear( 73 | n_feat, embed_dim * 2 74 | ) # check again if this is necessary. 75 | 76 | def forward(self, x, s): 77 | x = self.mlp_time(x) 78 | out = self.mlp_feat(torch.cat([x, self.mlp_s(s)], dim=2)) 79 | out = self.projector(out) 80 | return out 81 | 82 | 83 | class Mixer(nn.Module): 84 | """Mixer. 
85 | 86 | :argument 87 | - n_feat (int): number of input features 88 | - n_static_feat (int): number of static features 89 | - fcst_h (int): forecast horizon 90 | - embed_dim (int): embedding dimension 91 | - num_blocks (int): number of mixer blocks 92 | - dropout (float): dropout rate, default 0.1 93 | 94 | :return 95 | - x (tensor): output tensor of shape (batch_size, fcst_h, embed_dim*2) 96 | """ 97 | 98 | def __init__( 99 | self, 100 | n_feat: int, 101 | n_static_feat: int, 102 | fcst_h: int, 103 | embed_dim: int, 104 | num_blocks: int, 105 | dropout: float = 0.1, 106 | ): 107 | super(Mixer, self).__init__() 108 | self.mixer_blocks = nn.ModuleList( 109 | [ 110 | Mixer_Block(n_feat, n_static_feat, fcst_h, embed_dim, dropout) 111 | for _ in range(num_blocks) 112 | ] 113 | ) 114 | 115 | def forward(self, x, s): 116 | for mixer_block in self.mixer_blocks: 117 | x = mixer_block(x, s) 118 | return x 119 | 120 | 121 | class TS_Mixer_auxiliary(nn.Module): 122 | """Time Series Mixer with auxiliary static and dynamic features. 
123 | 124 | :argument 125 | - n_ts (int): number of input time series 126 | - n_static_feat (int): number of static features 127 | - n_dynamic_feat (int): number of dynamic features 128 | - ts_length (int): time series length 129 | - embed_dim (int): embedding dimension 130 | - num_blocks (int): number of mixer blocks 131 | - fcst_h (int): forecast horizon 132 | - out_dim (int): output dimension 133 | - dropout (float): dropout rate, default 0.1 134 | 135 | :return 136 | - x (tensor): output tensor of shape (batch_size, fcst_h, out_dim) 137 | 138 | source: 139 | - Algorithm 2 in [TSMixer: An all-MLP Architecture for Time Series Forecasting] (https://arxiv.org/pdf/2303.06053.pdf) 140 | """ 141 | 142 | def __init__( 143 | self, 144 | n_ts: int, 145 | n_static_feat: int, 146 | n_dynamic_feat: int, 147 | ts_length: int, 148 | embed_dim: int, 149 | num_blocks: int, 150 | fcst_h: int, 151 | out_dim: int, 152 | dropout: float = 0.1, 153 | ): 154 | super().__init__() 155 | 156 | # Number of features for sx and sz 157 | n_feat_sx = embed_dim + n_ts 158 | n_feat_sz = embed_dim + n_dynamic_feat 159 | 160 | # MLP that maps the length of the input time series to fcst_h 161 | self.fc_map = nn.Linear(ts_length, fcst_h) 162 | 163 | # MLPs, conditioned on static features, that map X and Z to embedding space 164 | self.mlp_sx = MLP_Feat(n_static_feat, embed_dim, dropout) 165 | self.mlp_sz = MLP_Feat(n_static_feat, embed_dim, dropout) 166 | self.mlp_x = MLP_Feat(n_feat_sx, embed_dim, dropout) 167 | self.mlp_z = MLP_Feat(n_feat_sz, embed_dim, dropout) 168 | 169 | # Mixer blocks 170 | self.mixer_blocks = Mixer( 171 | embed_dim * 3, n_static_feat, fcst_h, embed_dim, num_blocks, dropout 172 | ) 173 | 174 | # MLP that maps the output of the mixer blocks to the output dimension 175 | self.mlp_out = nn.Linear(embed_dim * 2, out_dim) 176 | 177 | # Layer normalization 178 | self.layer_norm = nn.LayerNorm(out_dim) 179 | 180 | def forward(self, x, z, s): 181 | # X: historical data of shape 
(batch_size, ts_length, Cx) 182 | # Z: future time-varying features of shape (batch_size, fcst_h, Cz) 183 | # S: static features of shape (batch_size, fcst_h, Cs) 184 | 185 | x = self.fc_map(x.transpose(1, 2)).transpose(1, 2) 186 | x_prime = self.mlp_x(torch.cat([x, self.mlp_sx(s)], dim=2)) 187 | z_prime = self.mlp_z(torch.cat([z, self.mlp_sz(s)], dim=2)) 188 | y_prime = torch.cat([x_prime, z_prime], dim=2) 189 | y_prime_block = self.mixer_blocks(y_prime, s) 190 | out = self.layer_norm(self.mlp_out(y_prime_block)) 191 | return out 192 | -------------------------------------------------------------------------------- /TSMixer/version_old/module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | from typing import Tuple, Optional 15 | 16 | import torch 17 | from torch import nn 18 | 19 | from gluonts.core.component import validated 20 | from gluonts.model import Input, InputSpec 21 | from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler 22 | from gluonts.torch.distributions import StudentTOutput 23 | 24 | 25 | class MLP_Time(nn.Module): 26 | """MLP for time embedding. According to the paper, the authors employ a single layer perceptron. 
27 | 28 | :argument 29 | - in_channels (int): input channels 30 | - ts_length (int): time series length 31 | - dropout (float): dropout rate 32 | - batch_norm (bool): whether to apply batch normalization 33 | 34 | :return 35 | - x (tensor): output tensor of shape (batch_size, ts_length, in_channels) 36 | """ 37 | 38 | def __init__( 39 | self, 40 | in_channels: int, 41 | ts_length: int, 42 | dropout: float = 0.1, 43 | batch_norm: bool = True, 44 | ): 45 | super().__init__() 46 | 47 | # BatchNorm2d is applied to the time dimension 48 | self.batch_norm2d = nn.BatchNorm2d(ts_length) if batch_norm is True else None 49 | self.in_channels = in_channels 50 | 51 | # MLP for time embedding 52 | self.time_mlp = nn.Sequential( 53 | nn.Linear(ts_length, ts_length), nn.ReLU(), nn.Dropout(dropout) 54 | ) 55 | 56 | def forward(self, x): 57 | if self.batch_norm2d is not None: 58 | x_norm = x.unsqueeze(-1) if self.in_channels == 1 else x 59 | x_norm = self.batch_norm2d(x_norm) 60 | x_norm = x_norm.squeeze(-1) if self.in_channels == 1 else x_norm 61 | else: 62 | x_norm = x 63 | x_time = self.time_mlp(x_norm.transpose(1, 2)).transpose(1, 2) 64 | return ( 65 | x + x_time 66 | ) # not sure if we need a residual connection here, the paper doesn't mention it. 67 | 68 | 69 | class MLP_Feat(nn.Module): 70 | """MLPs for feature embedding. 
71 | 72 | :argument 73 | - in_channels (int): input channels 74 | - embed_dim (int): embedding dimension 75 | - dropout (float): dropout rate, default 0.1 76 | - batch_norm (bool): whether to apply batch normalization 77 | 78 | :return 79 | - x (tensor): output tensor of shape (batch_size, ts_length, in_channels) 80 | """ 81 | 82 | def __init__( 83 | self, 84 | in_channels: int, 85 | embed_dim: int, 86 | dropout: float = 0.1, 87 | batch_norm: bool = True, 88 | ): 89 | super().__init__() 90 | 91 | # BatchNorm2d is applied to the feature dimension 92 | self.batch_norm2d = nn.BatchNorm2d(in_channels) if batch_norm is True else None 93 | self.in_channels = in_channels 94 | 95 | # MLPs for feature embedding 96 | self.feat_mlp1 = nn.Sequential( 97 | nn.Linear(in_channels, embed_dim), nn.ReLU(), nn.Dropout(dropout) 98 | ) 99 | 100 | self.feat_mlp2 = nn.Sequential( 101 | nn.Linear(embed_dim, in_channels), nn.Dropout(dropout) 102 | ) 103 | 104 | def forward(self, x): 105 | if self.batch_norm2d is not None: 106 | x_norm = ( 107 | x.transpose(1, 2).unsqueeze(-1) 108 | if self.in_channels == 1 109 | else x.transpose(1, 2) 110 | ) 111 | x_norm = self.batch_norm2d(x_norm) 112 | x_norm = ( 113 | x_norm.transpose(1, 2).squeeze(-1) 114 | if self.in_channels == 1 115 | else x_norm.transpose(1, 2) 116 | ) 117 | else: 118 | x_norm = x 119 | x_feat = self.feat_mlp1(x_norm) 120 | return x + self.feat_mlp2(x_feat) 121 | 122 | 123 | class Mixer_Block(nn.Module): 124 | """Mixer block. 
125 | 126 | :argument 127 | - in_channels (int): input channels 128 | - ts_length (int): time series length 129 | - embed_dim (int): embedding dimension 130 | - dropout (float): dropout rate, default 0.1 131 | - batch_norm (bool): whether to apply batch normalization 132 | 133 | :return 134 | - x (tensor): output tensor of shape (batch_size, ts_length, in_channels) 135 | """ 136 | 137 | def __init__( 138 | self, 139 | in_channels: int, 140 | ts_length: int, 141 | embed_dim: int, 142 | dropout: float = 0.1, 143 | batch_norm: bool = True, 144 | ): 145 | super().__init__() 146 | self.mlp_time = MLP_Time(in_channels, ts_length, dropout, batch_norm) 147 | self.mlp_feat = MLP_Feat(in_channels, embed_dim, dropout, batch_norm) 148 | 149 | def forward(self, x): 150 | x = self.mlp_time(x) 151 | x = self.mlp_feat(x) 152 | return x 153 | 154 | 155 | class TSMixerModel(nn.Module): 156 | """ 157 | Module implementing TSMixer for forecasting. 158 | 159 | Parameters 160 | ---------- 161 | prediction_length 162 | Number of time points to predict. 163 | context_length 164 | Number of time steps prior to prediction time that the model. 165 | scaling 166 | Whether to scale the target values. If "mean", the target values are scaled by the mean of the training set. 167 | If "std", the target values are scaled by the standard deviation of the training set. 168 | If "none", the target values are not scaled. 169 | input_size 170 | Number of input channels. 171 | n_blocks 172 | Number of mixer blocks 173 | hidden_size 174 | Size of hidden layers in the feed-forward network. 175 | dropout 176 | Dropout rate. Default: ``0.1``. 177 | batch_norm 178 | Whether to apply batch normalization. 179 | distr_output 180 | Distribution to use to evaluate observations and sample predictions. 181 | Default: ``StudentTOutput()``. 
182 | 183 | : References: 184 | - Algorithm 1 in [TSMixer: An all-MLP Architecture for Time Series Forecasting] (https://arxiv.org/pdf/2303.06053.pdf) 185 | """ 186 | 187 | @validated() 188 | def __init__( 189 | self, 190 | prediction_length: int, 191 | context_length: int, 192 | scaling: str, 193 | input_size: int, 194 | n_blocks: int, 195 | hidden_size: int, 196 | dropout: float, 197 | batch_norm: bool = True, 198 | distr_output=StudentTOutput(), 199 | ) -> None: 200 | super().__init__() 201 | 202 | assert prediction_length > 0 203 | assert context_length > 0 204 | assert n_blocks > 0 205 | 206 | self.prediction_length = prediction_length 207 | self.context_length = context_length 208 | self.input_size = input_size 209 | 210 | if scaling == "mean": 211 | self.scaler = MeanScaler(keepdim=True) 212 | elif scaling == "std": 213 | self.scaler = StdScaler(keepdim=True) 214 | else: 215 | self.scaler = NOPScaler(keepdim=True) 216 | 217 | self.distr_output = distr_output 218 | 219 | self.mixer_blocks = nn.Sequential( 220 | *[ 221 | Mixer_Block( 222 | input_size, context_length, hidden_size, dropout, batch_norm 223 | ) 224 | for _ in range(n_blocks) 225 | ] 226 | ) 227 | 228 | # MLP that maps the output of the mixer blocks (=context_length) to the prediction length 229 | self.ts_map = nn.Linear(context_length, prediction_length) 230 | 231 | # MLP that maps the input_size to a higher hidden size, needed for the distr_output (only works for input_size = 1)? 
232 | # self.hidden_map = nn.Linear(input_size, hidden_size) 233 | 234 | # MLP that maps the hidden size from self.hidden_map to the distribution output 235 | self.args_proj = self.distr_output.get_args_proj(input_size) 236 | 237 | def describe_inputs(self, batch_size=1) -> InputSpec: 238 | return InputSpec( 239 | { 240 | "past_target": Input( 241 | shape=(batch_size, self.context_length), dtype=torch.float 242 | ), 243 | "past_observed_values": Input( 244 | shape=(batch_size, self.context_length), dtype=torch.float 245 | ), 246 | }, 247 | torch.zeros, 248 | ) 249 | 250 | def forward( 251 | self, 252 | feat_static_cat: Optional[torch.Tensor] = None, 253 | feat_static_real: Optional[torch.Tensor] = None, 254 | past_time_feat: Optional[torch.Tensor] = None, 255 | past_target: Optional[torch.Tensor] = None, 256 | past_observed_values: Optional[torch.Tensor] = None, 257 | future_time_feat: Optional[torch.Tensor] = None, 258 | future_target: Optional[torch.Tensor] = None, 259 | future_observed_values: Optional[torch.Tensor] = None, 260 | ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: 261 | past_target_scaled, loc, scale = self.scaler(past_target, past_observed_values) 262 | past_target_scaled = ( 263 | past_target_scaled.unsqueeze(-1) 264 | if self.input_size == 1 265 | else past_target_scaled 266 | ) 267 | nn_out = self.mixer_blocks(past_target_scaled) 268 | nn_out = self.ts_map(nn_out.transpose(1, 2)).transpose(1, 2) 269 | # nn_out = self.hidden_map(nn_out) 270 | distr_args = self.args_proj(nn_out) 271 | return distr_args, loc, scale 272 | -------------------------------------------------------------------------------- /TsT-hyperparameter_tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "86ff733d-8ea7-4d77-8d2a-cd329ab8f385", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import 
matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "import torch\n", 17 | "\n", 18 | "from gluonts.dataset.multivariate_grouper import MultivariateGrouper\n", 19 | "from gluonts.dataset.repository.datasets import dataset_recipes, get_dataset\n", 20 | "from gluonts.evaluation.backtest import make_evaluation_predictions\n", 21 | "from gluonts.evaluation import MultivariateEvaluator\n", 22 | "\n", 23 | "from pts.modules import StudentTOutput\n", 24 | "\n", 25 | "from TsT import TsTEstimator\n", 26 | "import random\n", 27 | "import numpy as np\n", 28 | "import time\n", 29 | "import optuna\n", 30 | "from optuna.samplers import TPESampler" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 4, 36 | "id": "944676b0-2a9f-4301-92e6-f382f5693639", 37 | "metadata": { 38 | "tags": [] 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "class TsTObjective:\n", 43 | " def __init__(\n", 44 | " self,\n", 45 | " dataset,\n", 46 | " train_grouper,\n", 47 | " test_grouper,\n", 48 | " metric_type=\"m_sum_mean_wQuantileLoss\",\n", 49 | " ):\n", 50 | " self.metric_type = metric_type\n", 51 | " self.dataset = dataset\n", 52 | " self.dataset_train = train_grouper(self.dataset.train)\n", 53 | " self.dataset_test = test_grouper(self.dataset.test)\n", 54 | "\n", 55 | " def get_params(self, trial) -> dict:\n", 56 | " return {\n", 57 | " \"context_length\": trial.suggest_int(\n", 58 | " \"context_length\",\n", 59 | " self.dataset.metadata.prediction_length,\n", 60 | " self.dataset.metadata.prediction_length * 5,\n", 61 | " 1,\n", 62 | " ),\n", 63 | " \"batch_size\": trial.suggest_int(\"batch_size\", 32, 256, 32),\n", 64 | " \"depth\": trial.suggest_int(\"depth\", 2, 16, 4),\n", 65 | " \"dim\": trial.suggest_int(\"dim\", 16, 64, 16),\n", 66 | " \"patch_size\": trial.suggest_int(\"patch_size\", 2, 16, 4),\n", 67 | " \"kernel_size\": trial.suggest_int(\"kernel_size\", 9, 18, 3), # NOTE(review): sampled but never passed to TsTEstimator\n", 68 | " }\n", 69 | "\n", 70 | " def __call__(self,
trial):\n", 71 | " params = self.get_params(trial)\n", 72 | " estimator = TsTEstimator(\n", 73 | " # distr_output=StudentTOutput(dim=int(dataset.metadata.feat_static_cat[0].cardinality)),\n", 74 | " input_size=int(self.dataset.metadata.feat_static_cat[0].cardinality),\n", 75 | " prediction_length=self.dataset.metadata.prediction_length,\n", 76 | " context_length=params[\"context_length\"],\n", 77 | " freq=self.dataset.metadata.freq,\n", 78 | " scaling=\"std\",\n", 79 | " depth=params[\"depth\"],\n", 80 | " dim=params[\"dim\"],\n", 81 | " patch_size=(params[\"patch_size\"], params[\"patch_size\"]),\n", 82 | " batch_size=params[\"batch_size\"],\n", 83 | " num_batches_per_epoch=100,\n", 84 | " patch_reverse_mapping_layer=\"mlp\",\n", 85 | " trainer_kwargs=dict(accelerator=\"cuda\", max_epochs=30),\n", 86 | " )\n", 87 | " predictor = estimator.train(\n", 88 | " training_data=self.dataset_train, num_workers=8, shuffle_buffer_length=1024\n", 89 | " )\n", 90 | "\n", 91 | " forecast_it, ts_it = make_evaluation_predictions(\n", 92 | " dataset=self.dataset_test, predictor=predictor, num_samples=100\n", 93 | " )\n", 94 | " forecasts = list(forecast_it)\n", 95 | " tss = list(ts_it)\n", 96 | " evaluator = MultivariateEvaluator(\n", 97 | " quantiles=(np.arange(20) / 20.0)[1:], target_agg_funcs={\"sum\": np.sum}\n", 98 | " )\n", 99 | " agg_metrics, _ = evaluator(iter(tss), iter(forecasts))\n", 100 | " return agg_metrics[self.metric_type]" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 5, 106 | "id": "d6c1da22-755e-44d5-86b6-10465d8c25e8", 107 | "metadata": { 108 | "tags": [] 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "dataset = get_dataset(\"solar_nips\", regenerate=False)\n", 113 | "train_grouper = MultivariateGrouper(\n", 114 | " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality)\n", 115 | ")\n", 116 | "\n", 117 | "test_grouper = MultivariateGrouper(\n", 118 | " num_test_dates=int(len(dataset.test) /
len(dataset.train)),\n", 119 | " max_target_dim=int(dataset.metadata.feat_static_cat[0].cardinality),\n", 120 | ")\n", 121 | "dataset_train = train_grouper(dataset.train)\n", 122 | "dataset_test = test_grouper(dataset.test)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "id": "9ded5c21-ea78-4f99-98e3-1573f5abfbf2", 129 | "metadata": { 130 | "tags": [], 131 | "pycharm": { 132 | "is_executing": true 133 | } 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "seed = 42\n", 138 | "random.seed(seed)\n", 139 | "torch.manual_seed(seed)\n", 140 | "start_time = time.time()\n", 141 | "sampler = TPESampler(seed=seed)\n", 142 | "study = optuna.create_study(sampler=sampler, direction=\"minimize\")\n", 143 | "study.optimize(TsTObjective(dataset, train_grouper, test_grouper), n_trials=10)\n", 144 | "\n", 145 | "print(\"Number of finished trials: {}\".format(len(study.trials)))\n", 146 | "\n", 147 | "print(\"Best trial:\")\n", 148 | "trial = study.best_trial\n", 149 | "\n", 150 | "print(\" Value: {}\".format(trial.value))\n", 151 | "print(\" Params: \")\n", 152 | "for key, value in trial.params.items():\n", 153 | " print(\" {}: {}\".format(key, value))\n", 154 | "print(time.time() - start_time)" 155 | ] 156 | } 157 | ], 158 | "metadata": { 159 | "kernelspec": { 160 | "display_name": "py38", 161 | "language": "python", 162 | "name": "py38" 163 | }, 164 | "language_info": { 165 | "codemirror_mode": { 166 | "name": "ipython", 167 | "version": 3 168 | }, 169 | "file_extension": ".py", 170 | "mimetype": "text/x-python", 171 | "name": "python", 172 | "nbconvert_exporter": "python", 173 | "pygments_lexer": "ipython3", 174 | "version": "3.8.10" 175 | } 176 | }, 177 | "nbformat": 4, 178 | "nbformat_minor": 5 179 | } 180 | -------------------------------------------------------------------------------- /TsT/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import TsTModel 2 | from 
.lightning_module import TsTLightningModule 3 | from .estimator import TsTEstimator 4 | 5 | __all__ = ["TsTModel", "TsTLightningModule", "TsTEstimator"] 6 | -------------------------------------------------------------------------------- /TsT/lightning_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 13 | 14 | import lightning.pytorch as pl 15 | import torch 16 | 17 | from gluonts.core.component import validated 18 | 19 | from .module import TsTModel 20 | 21 | 22 | class TsTLightningModule(pl.LightningModule): 23 | """ 24 | A ``pl.LightningModule`` class that can be used to train a 25 | ``TsTModel`` with PyTorch Lightning. 26 | 27 | This is a thin layer around a (wrapped) ``TsTModel`` object, 28 | that exposes the methods to evaluate training and validation loss. 29 | 30 | Parameters 31 | ---------- 32 | model 33 | ``TsTModel`` to be trained. 34 | loss 35 | Loss function to be used for training, 36 | default: ``NegativeLogLikelihood()``. 37 | lr 38 | Learning rate, default: ``1e-3``. 39 | weight_decay 40 | Weight decay regularization parameter, default: ``1e-8``. 
41 | """ 42 | 43 | @validated() 44 | def __init__( 45 | self, 46 | model_kwargs: dict, 47 | lr: float = 1e-3, 48 | weight_decay: float = 1e-8, 49 | ): 50 | super().__init__() 51 | self.save_hyperparameters() 52 | self.model = TsTModel(**model_kwargs) 53 | self.lr = lr 54 | self.weight_decay = weight_decay 55 | 56 | def forward(self, *args, **kwargs): 57 | distr_args, loc, scale = self.model.forward(*args, **kwargs) 58 | distr = self.model.distr_output.distribution(distr_args, loc, scale) 59 | return distr.sample((self.model.num_parallel_samples,)).reshape( 60 | -1, 61 | self.model.num_parallel_samples, 62 | self.model.prediction_length, 63 | self.model.input_size, 64 | ) 65 | 66 | def _compute_loss(self, batch): 67 | past_target = batch["past_target"] 68 | past_observed_values = batch["past_observed_values"] 69 | target = batch["future_target"] 70 | observed_target = batch["future_observed_values"] 71 | 72 | assert past_target.shape[1] == self.model.context_length 73 | assert target.shape[1] == self.model.prediction_length 74 | 75 | distr_args, loc, scale = self.model( 76 | past_target=past_target, 77 | past_observed_values=past_observed_values, 78 | past_time_feat=batch["past_time_feat"], 79 | future_time_feat=batch["future_time_feat"], 80 | ) 81 | 82 | loss_values = self.model.distr_output.loss( 83 | target=target, distr_args=distr_args, loc=loc, scale=scale 84 | ) 85 | return (loss_values * observed_target).sum() / torch.maximum( 86 | torch.tensor(1.0), observed_target.sum() 87 | ) 88 | 89 | # distr = self.model.distr_output.distribution(distr_args, loc, scale) 90 | 91 | # return (self.loss(distr, target) * observed_target).sum() / torch.maximum( 92 | # torch.tensor(1.0), observed_target.sum() 93 | # ) 94 | 95 | def training_step(self, batch, batch_idx: int): # type: ignore 96 | """ 97 | Execute training step. 
98 | """ 99 | train_loss = self._compute_loss(batch) 100 | self.log( 101 | "train_loss", 102 | train_loss, 103 | on_epoch=True, 104 | on_step=False, 105 | prog_bar=True, 106 | ) 107 | return train_loss 108 | 109 | def validation_step(self, batch, batch_idx: int): # type: ignore 110 | """ 111 | Execute validation step. 112 | """ 113 | val_loss = self._compute_loss(batch) 114 | self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True) 115 | return val_loss 116 | 117 | def configure_optimizers(self): 118 | """ 119 | Returns the optimizer to use. 120 | """ 121 | return torch.optim.Adam( 122 | self.model.parameters(), 123 | lr=self.lr, 124 | weight_decay=self.weight_decay, 125 | ) 126 | -------------------------------------------------------------------------------- /TsT/module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # A copy of the License is located at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # or in the "license" file accompanying this file. This file is distributed 10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 11 | # express or implied. See the License for the specific language governing 12 | # permissions and limitations under the License. 
13 | 14 | from typing import Tuple, Optional 15 | 16 | import numpy as np 17 | import torch 18 | from torch import nn 19 | from einops.layers.torch import Rearrange 20 | 21 | from gluonts.core.component import validated 22 | from gluonts.model import Input, InputSpec 23 | from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler 24 | from gluonts.torch.distributions import StudentTOutput 25 | 26 | 27 | class Residual(nn.Module): 28 | def __init__(self, fn): 29 | super().__init__() 30 | self.fn = fn 31 | 32 | def forward(self, x): 33 | return self.fn(x) + x 34 | 35 | 36 | class MLPPatchMap(nn.Module): 37 | """ 38 | Module implementing MLPMap for the reverse mapping of the patch-tensor. 39 | 40 | Parameters 41 | ---------- 42 | patch_size : Tuple[int, int] 43 | Patch size. 44 | context_length : int 45 | Context length. 46 | prediction_length : int 47 | Number of time points to predict. 48 | input_size : int 49 | Input size. 50 | 51 | Returns 52 | ------- 53 | x : torch.Tensor 54 | """ 55 | 56 | def __init__( 57 | self, 58 | patch_size: Tuple[int, int], 59 | context_length: int, 60 | prediction_length: int, 61 | input_size: int, 62 | ): 63 | super().__init__() 64 | p1 = int(context_length / patch_size[0]) 65 | p2 = int(input_size / patch_size[1]) 66 | self.fc = nn.Sequential( 67 | Rearrange("b c w h -> b c h w"), 68 | nn.Linear(p1, prediction_length), 69 | Rearrange("b c h w -> b c w h"), 70 | nn.Linear(p2, input_size), 71 | ) 72 | 73 | def forward(self, x): 74 | x = self.fc(x) 75 | return x 76 | 77 | 78 | def RevMapLayer( 79 | layer_type: str, 80 | pooling_type: str, 81 | dim: int, 82 | patch_size: int, 83 | context_length: int, 84 | prediction_length: int, 85 | input_size: int, 86 | ): 87 | """ 88 | Returns the mapping layer for the reverse mapping of the patch-tensor to [b nf h ns]. 
89 | 90 | :argument 91 | layer_type: str = "pooling" or "mlp" or "conv1d" 92 | pooling_type: str = "max" or "mean" 93 | dim: int = dimension of the embeddings 94 | patch_size: Tuple[int, int] = patch size 95 | prediction_length: int = prediction length 96 | context_length: int = context length 97 | input_size: int = input size 98 | 99 | :returns 100 | nn.Module = mapping layer 101 | 102 | """ 103 | if layer_type == "pooling": 104 | if pooling_type == "max": 105 | return nn.AdaptiveMaxPool2d((prediction_length, input_size)) 106 | elif pooling_type == "mean": 107 | return nn.AdaptiveAvgPool2d((prediction_length, input_size)) 108 | elif layer_type == "mlp": 109 | return MLPPatchMap(patch_size, context_length, prediction_length, input_size) 110 | else: 111 | raise ValueError("Invalid layer type: {}".format(layer_type)) 112 | 113 | 114 | class TsTModel(nn.Module): 115 | """ 116 | Module implementing TsT for forecasting. 117 | 118 | Parameters 119 | ---------- 120 | prediction_length 121 | Number of time points to predict. 122 | context_length 123 | Number of time steps prior to prediction time that the model. 124 | hidden_dimensions 125 | Size of hidden layers in the feed-forward network. 126 | distr_output 127 | Distribution to use to evaluate observations and sample predictions. 128 | Default: ``StudentTOutput()``. 129 | batch_norm 130 | Whether to apply batch normalization. Default: ``False``. 
131 | """ 132 | 133 | @validated() 134 | def __init__( 135 | self, 136 | prediction_length: int, 137 | context_length: int, 138 | scaling: str, 139 | input_size: int, 140 | depth: int, 141 | dim: int, 142 | nhead: int, 143 | patch_size: Tuple[int, int], 144 | dim_feedforward: int, 145 | dropout: float, 146 | activation: str, 147 | norm_first: bool, 148 | patch_reverse_mapping_layer: str = "mlp", 149 | pooling_type: str = "max", 150 | num_feat_dynamic_real: int = 0, 151 | num_feat_static_real: int = 0, 152 | num_feat_static_cat: int = 0, 153 | distr_output=StudentTOutput(), 154 | num_parallel_samples: int = 100, 155 | ) -> None: 156 | super().__init__() 157 | 158 | assert prediction_length > 0 159 | assert context_length > 0 160 | assert depth > 0 161 | 162 | self.prediction_length = prediction_length 163 | self.context_length = context_length 164 | self.input_size = input_size 165 | self.dim = dim 166 | self.num_feat_static_real = num_feat_static_real 167 | self.num_feat_dynamic_real = num_feat_dynamic_real 168 | self.num_parallel_samples = num_parallel_samples 169 | 170 | if scaling == "mean": 171 | self.scaler = MeanScaler(keepdim=True, dim=1) 172 | elif scaling == "std": 173 | self.scaler = StdScaler(keepdim=True, dim=1) 174 | else: 175 | self.scaler = NOPScaler(keepdim=True, dim=1) 176 | 177 | self.distr_output = distr_output 178 | 179 | self.conv_proj = nn.Conv2d( 180 | self._number_of_features, dim, kernel_size=patch_size, stride=patch_size 181 | ) 182 | 183 | self.patch_num = (self.context_length // patch_size[0]) * ( 184 | self.input_size // patch_size[1] 185 | ) 186 | 187 | self.positional_encoding = SinusoidalPositionalEmbedding(self.patch_num, dim) 188 | 189 | layer_norm_eps: float = 1e-5 190 | encoder_layer = nn.TransformerEncoderLayer( 191 | d_model=dim, 192 | nhead=nhead, 193 | dim_feedforward=dim_feedforward, 194 | dropout=dropout, 195 | activation=activation, 196 | layer_norm_eps=layer_norm_eps, 197 | batch_first=True, 198 | norm_first=norm_first, 
199 | ) 200 | encoder_norm = nn.LayerNorm(dim, eps=layer_norm_eps) 201 | self.encoder = nn.TransformerEncoder(encoder_layer, depth, encoder_norm) 202 | 203 | self.rev_map_layer = RevMapLayer( 204 | layer_type=patch_reverse_mapping_layer, 205 | pooling_type=pooling_type, 206 | dim=dim, 207 | patch_size=patch_size, 208 | prediction_length=self.prediction_length, 209 | context_length=self.context_length, 210 | input_size=self.input_size, 211 | ) 212 | 213 | self.args_proj = self.distr_output.get_args_proj( 214 | dim + self.num_feat_dynamic_real 215 | ) 216 | 217 | @property 218 | def _number_of_features(self) -> int: 219 | return ( 220 | self.num_feat_dynamic_real 221 | + self.num_feat_static_real 222 | + 3 # 1 + the log(loc) + log1p(scale) 223 | ) 224 | 225 | def describe_inputs(self, batch_size=1) -> InputSpec: 226 | return InputSpec( 227 | { 228 | "past_target": Input( 229 | shape=(batch_size, self.context_length, self.input_size), 230 | dtype=torch.float, 231 | ), 232 | "past_observed_values": Input( 233 | shape=(batch_size, self.context_length, self.input_size), 234 | dtype=torch.float, 235 | ), 236 | }, 237 | torch.zeros, 238 | ) 239 | 240 | def forward( 241 | self, 242 | feat_static_cat: Optional[torch.Tensor] = None, 243 | feat_static_real: Optional[torch.Tensor] = None, 244 | past_time_feat: Optional[torch.Tensor] = None, 245 | past_target: Optional[torch.Tensor] = None, 246 | past_observed_values: Optional[torch.Tensor] = None, 247 | future_time_feat: Optional[torch.Tensor] = None, 248 | future_target: Optional[torch.Tensor] = None, 249 | future_observed_values: Optional[torch.Tensor] = None, 250 | ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: 251 | past_target_scaled, loc, scale = self.scaler(past_target, past_observed_values) 252 | # [B, C, D], [B, D], [B, D] 253 | 254 | # [B, 1, C, D] 255 | past_target_scaled = past_target_scaled.unsqueeze(1) # channel dim 256 | 257 | log_abs_loc = loc.sign().unsqueeze(1).expand_as(past_target_scaled) 
* loc.abs().log1p().unsqueeze(1).expand_as(past_target_scaled) 258 | log_scale = scale.log().unsqueeze(1).expand_as(past_target_scaled) 259 | 260 | # [B, C, F] -> [B, F, C, 1] -> [B, F, C, D] 261 | past_time_feat = ( 262 | past_time_feat.transpose(2, 1) 263 | .unsqueeze(-1) 264 | .repeat_interleave(dim=-1, repeats=self.input_size) 265 | ) 266 | 267 | proj_input = torch.cat( 268 | ( 269 | past_target_scaled, 270 | log_abs_loc, 271 | log_scale, 272 | past_time_feat, 273 | ), 274 | dim=1, 275 | ) 276 | 277 | x = self.conv_proj(proj_input) 278 | B, C, H, W = x.shape 279 | 280 | x = x.reshape(B, self.dim, -1) 281 | x = x.permute(0, 2, 1) # [B, P, D] 282 | embed_pos = self.positional_encoding(x.size()) 283 | enc_out = self.encoder(x + embed_pos) 284 | 285 | nn_out = self.rev_map_layer(enc_out.permute(0, 2, 1).reshape(B, C, H, W)) 286 | 287 | # [B, F, C, D] -> [B, F, P, D] 288 | 289 | nn_out_reshaped = nn_out.transpose(1, -1).transpose(1, 2) 290 | future_time_feat_repeat = future_time_feat.unsqueeze(2).repeat_interleave( 291 | dim=2, repeats=self.input_size 292 | ) 293 | distr_args = self.args_proj( 294 | torch.cat((nn_out_reshaped, future_time_feat_repeat), dim=-1) 295 | ) 296 | 297 | return distr_args, loc, scale 298 | 299 | 300 | class SinusoidalPositionalEmbedding(nn.Embedding): 301 | """This module produces sinusoidal positional embeddings of any length.""" 302 | 303 | def __init__(self, num_positions: int, embedding_dim: int) -> None: 304 | super().__init__(num_positions, embedding_dim) 305 | self.weight = self._init_weight(self.weight) 306 | 307 | @staticmethod 308 | def _init_weight(out: nn.Parameter) -> nn.Parameter: 309 | """ 310 | Features are not interleaved. The cos features are in the 2nd half of the vector. 
[dim // 2:] 311 | """ 312 | n_pos, dim = out.shape 313 | position_enc = np.array( 314 | [ 315 | [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] 316 | for pos in range(n_pos) 317 | ] 318 | ) 319 | # set early to avoid an error in pytorch-1.8+ 320 | out.requires_grad = False 321 | 322 | sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 323 | out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) 324 | out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) 325 | out.detach_() 326 | return out 327 | 328 | @torch.no_grad() 329 | def forward( 330 | self, input_ids_shape: torch.Size, past_key_values_length: int = 0 331 | ) -> torch.Tensor: 332 | """`input_ids_shape` is expected to be [bsz x seqlen x ...].""" 333 | _, seq_len = input_ids_shape[:2] 334 | positions = torch.arange( 335 | past_key_values_length, 336 | past_key_values_length + seq_len, 337 | dtype=torch.long, 338 | device=self.weight.device, 339 | ) 340 | return super().forward(positions) 341 | -------------------------------------------------------------------------------- /examples/TS-Mixer Auxiliary Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Description\n", 7 | "\n", 8 | "This notebook only illustrates the workings of the auxiliary TS_Mixer model including static and dynamic features." 
9 | ], 10 | "metadata": { 11 | "collapsed": false 12 | } 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "source": [ 17 | "# Imports" 18 | ], 19 | "metadata": { 20 | "collapsed": false 21 | } 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "outputs": [], 27 | "source": [ 28 | "from TSMixer.model_auxiliary import *" 29 | ], 30 | "metadata": { 31 | "collapsed": false 32 | } 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "source": [ 37 | "# Sample Data" 38 | ], 39 | "metadata": { 40 | "collapsed": false 41 | } 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "outputs": [], 47 | "source": [ 48 | "torch.manual_seed(123)\n", 49 | "fcst_h = 20\n", 50 | "out_dim = 2\n", 51 | "x = torch.randn(100, 3)[None, :]\n", 52 | "z = torch.randn(fcst_h, 2)[None, :]\n", 53 | "s = torch.ones_like(z)\n", 54 | "s[:, :, 1] = 2" 55 | ], 56 | "metadata": { 57 | "collapsed": false 58 | } 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "outputs": [], 64 | "source": [ 65 | "n_ts = x.shape[2]\n", 66 | "n_static_feat = s.shape[2]\n", 67 | "n_dynamic_feat = z.shape[2]\n", 68 | "ts_length = x.shape[1]\n", 69 | "embed_dim = 64\n", 70 | "\n", 71 | "n_feat_sx = embed_dim + n_ts\n", 72 | "n_feat_sz = embed_dim + n_dynamic_feat" 73 | ], 74 | "metadata": { 75 | "collapsed": false 76 | } 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": "(torch.Size([1, 20, 3]),\n torch.Size([1, 20, 64]),\n torch.Size([1, 20, 64]),\n torch.Size([1, 20, 64]),\n torch.Size([1, 20, 64]),\n torch.Size([1, 20, 128]),\n torch.Size([1, 20, 128]),\n torch.Size([1, 20, 2]))" 85 | }, 86 | "execution_count": 4, 87 | "metadata": {}, 88 | "output_type": "execute_result" 89 | } 90 | ], 91 | "source": [ 92 | "mlp_x_map = nn.Linear(ts_length, fcst_h)\n", 93 | "x_map = mlp_x_map(x.transpose(1, 2)).transpose(1, 2)\n", 94 | "\n", 95 | "mlp_sx = MLP_Feat(n_static_feat, embed_dim, 
dropout=0.1)\n", 96 | "sx_out = mlp_sx(s)\n", 97 | "\n", 98 | "mlp_sz = MLP_Feat(n_dynamic_feat, embed_dim, dropout=0.1)\n", 99 | "sz_out = mlp_sz(s)\n", 100 | "\n", 101 | "mlp_x = MLP_Feat(n_feat_sx, embed_dim, dropout=0.1)\n", 102 | "x_prime = mlp_x(torch.cat([x_map, sx_out], dim=2))\n", 103 | "\n", 104 | "mlp_z = MLP_Feat(n_feat_sz, embed_dim, dropout=0.1)\n", 105 | "z_prime = mlp_z(torch.cat([z, sz_out], dim=2))\n", 106 | "\n", 107 | "y_prime = torch.cat([x_prime, z_prime], dim=2)\n", 108 | "\n", 109 | "n_feat = embed_dim * 3\n", 110 | "\n", 111 | "mixer_block = Mixer_Block(n_feat, n_static_feat, fcst_h, embed_dim, dropout=0.1)\n", 112 | "y_prime_mixer = mixer_block(y_prime, s)\n", 113 | "\n", 114 | "mlp_out = nn.Linear(embed_dim * 2, out_dim)\n", 115 | "x_out = mlp_out(y_prime_mixer)\n", 116 | "\n", 117 | "layer_norm = nn.LayerNorm(out_dim)\n", 118 | "x_out = layer_norm(x_out)\n", 119 | "\n", 120 | "x_map.shape, sx_out.shape, sz_out.shape, x_prime.shape, z_prime.shape, y_prime.shape, y_prime_mixer.shape, x_out.shape" 121 | ], 122 | "metadata": { 123 | "collapsed": false 124 | } 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "source": [ 129 | "# TS-Mixer" 130 | ], 131 | "metadata": { 132 | "collapsed": false 133 | } 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 5, 138 | "outputs": [ 139 | { 140 | "ename": "TypeError", 141 | "evalue": "forward() takes 2 positional arguments but 3 were given", 142 | "output_type": "error", 143 | "traceback": [ 144 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 145 | "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", 146 | "Cell \u001b[1;32mIn[5], line 4\u001b[0m\n\u001b[0;32m 2\u001b[0m out_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m\n\u001b[0;32m 3\u001b[0m ts_mixer_aux \u001b[38;5;241m=\u001b[39m TS_Mixer_auxiliary(n_ts, n_static_feat, n_dynamic_feat, ts_length, embed_dim, num_blocks, fcst_h, out_dim, 
dropout\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.1\u001b[39m)\n\u001b[1;32m----> 4\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mts_mixer_aux\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(ts_mixer_aux)\n", 147 | "File \u001b[1;32m~\\.virtualenvs\\ConvTS_Mixer-3Rl3B8jo\\lib\\site-packages\\torch\\nn\\modules\\module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1503\u001b[0m full_backward_hooks, 
non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", 148 | "File \u001b[1;32m~\\OneDrive - adidas\\ConvTS Mixer\\TSMixer\\model_auxiliary.py:200\u001b[0m, in \u001b[0;36mTS_Mixer_auxiliary.forward\u001b[1;34m(self, x, z, s)\u001b[0m\n\u001b[0;32m 198\u001b[0m z_prime \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmlp_z(torch\u001b[38;5;241m.\u001b[39mcat([z, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmlp_sz(s)], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m))\n\u001b[0;32m 199\u001b[0m y_prime \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([x_prime, z_prime], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m--> 200\u001b[0m y_prime_block \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmixer_blocks\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_prime\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 201\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_norm(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmlp_out(y_prime_block))\n\u001b[0;32m 202\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", 149 | "File \u001b[1;32m~\\.virtualenvs\\ConvTS_Mixer-3Rl3B8jo\\lib\\site-packages\\torch\\nn\\modules\\module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", 150 | "File \u001b[1;32m~\\OneDrive - adidas\\ConvTS Mixer\\TSMixer\\model_auxiliary.py:118\u001b[0m, in \u001b[0;36mMixer.forward\u001b[1;34m(self, x, s)\u001b[0m\n\u001b[0;32m 117\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, s):\n\u001b[1;32m--> 118\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmixer_blocks\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", 151 | "File \u001b[1;32m~\\.virtualenvs\\ConvTS_Mixer-3Rl3B8jo\\lib\\site-packages\\torch\\nn\\modules\\module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1497\u001b[0m \u001b[38;5;66;03m# this 
function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", 152 | "\u001b[1;31mTypeError\u001b[0m: forward() takes 2 positional arguments but 3 were given" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "num_blocks = 2\n", 158 | "out_dim = 2\n", 159 | "ts_mixer_aux = TS_Mixer_auxiliary(\n", 160 | " n_ts,\n", 161 | " n_static_feat,\n", 162 | " n_dynamic_feat,\n", 163 | " ts_length,\n", 164 | " embed_dim,\n", 165 | " num_blocks,\n", 166 | " fcst_h,\n", 167 | " out_dim,\n", 168 | " dropout=0.1,\n", 169 | ")\n", 170 | "out = ts_mixer_aux(x, z, s)\n", 171 | "print(ts_mixer_aux)" 172 | ], 173 | "metadata": { 174 | "collapsed": false 175 | } 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | 
"codemirror_mode": { 186 | "name": "ipython", 187 | "version": 2 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython2", 194 | "version": "2.7.6" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 0 199 | } 200 | -------------------------------------------------------------------------------- /examples/TS-Mixer Base Example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Description\n", 7 | "\n", 8 | "This notebook only illustrates the workings of the base TS_Mixer, without additional static or dynamic features." 9 | ], 10 | "metadata": { 11 | "collapsed": false 12 | } 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "source": [ 17 | "# Imports" 18 | ], 19 | "metadata": { 20 | "collapsed": false 21 | } 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": { 27 | "collapsed": true 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import torch\n", 32 | "from TSMixer.model import TS_Mixer" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "source": [ 38 | "# Sample Data\n", 39 | "\n", 40 | "The following simulates a set of 3 time series, each with a length of 100. Hence, we treat the series as multiple univariate time series and set $C_x = 3$. For simplicity, we assume a batch size of 1. This notebook only illustrates the workings of the base TS_Mixer, without additional static or dynamic features." 
41 | ], 42 | "metadata": { 43 | "collapsed": false 44 | } 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": "torch.Size([1, 100, 3])" 53 | }, 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "output_type": "execute_result" 57 | } 58 | ], 59 | "source": [ 60 | "torch.manual_seed(123)\n", 61 | "ts_sample = torch.randn(100, 3)[None, :]\n", 62 | "ts_sample.shape" 63 | ], 64 | "metadata": { 65 | "collapsed": false 66 | } 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "source": [ 71 | "# TS-Mixer" 72 | ], 73 | "metadata": { 74 | "collapsed": false 75 | } 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/plain": "torch.Size([1, 20, 3])" 84 | }, 85 | "execution_count": 3, 86 | "metadata": {}, 87 | "output_type": "execute_result" 88 | } 89 | ], 90 | "source": [ 91 | "ts_mixer = TS_Mixer(\n", 92 | " in_channels=ts_sample.shape[2],\n", 93 | " ts_length=ts_sample.shape[1],\n", 94 | " embed_dim=64,\n", 95 | " num_blocks=1,\n", 96 | " fcst_h=20,\n", 97 | " dropout=0.1,\n", 98 | ")\n", 99 | "x_ts_mixer = ts_mixer(ts_sample)\n", 100 | "x_ts_mixer.shape" 101 | ], 102 | "metadata": { 103 | "collapsed": false 104 | } 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "source": [ 109 | "# BatchNorm2d\n" 110 | ], 111 | "metadata": { 112 | "collapsed": false 113 | } 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 4, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": "torch.Size([1, 100, 3])" 122 | }, 123 | "execution_count": 4, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "x = ts_sample\n", 130 | "ts_length = x.shape[1]\n", 131 | "\n", 132 | "# Reshape the data to [batch_size, ts_length, n_time_series, 1] for batch normalization\n", 133 | "x = x.unsqueeze(-1)\n", 134 | "\n", 135 | "# Compute batch normalization along the second dimension (i.e., 
time dimension)\n", 136 | "bn_time = torch.nn.BatchNorm2d(ts_length)\n", 137 | "x_bn = bn_time(x)\n", 138 | "\n", 139 | "# Reshape the data back to the original shape\n", 140 | "x_bn = x_bn.squeeze(-1)\n", 141 | "x_bn.shape" 142 | ], 143 | "metadata": { 144 | "collapsed": false 145 | } 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 5, 150 | "outputs": [ 151 | { 152 | "data": { 153 | "text/plain": "torch.Size([1, 100, 3])" 154 | }, 155 | "execution_count": 5, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "x = ts_sample\n", 162 | "n_time_series = x.shape[2]\n", 163 | "\n", 164 | "# Reshape the data to [batch_size, n_time_series, ts_length, 1] for batch normalization\n", 165 | "x = x.transpose(1, 2).unsqueeze(-1)\n", 166 | "\n", 167 | "# Compute batch normalization along the second dimension (i.e., feature dimension)\n", 168 | "bn_feat = torch.nn.BatchNorm2d(n_time_series)\n", 169 | "x_bn = bn_feat(x)\n", 170 | "\n", 171 | "# Reshape the data back to the original shape\n", 172 | "x_bn = x_bn.transpose(1, 2).squeeze(-1)\n", 173 | "x_bn.shape" 174 | ], 175 | "metadata": { 176 | "collapsed": false 177 | } 178 | } 179 | ], 180 | "metadata": { 181 | "kernelspec": { 182 | "display_name": "Python 3", 183 | "language": "python", 184 | "name": "python3" 185 | }, 186 | "language_info": { 187 | "codemirror_mode": { 188 | "name": "ipython", 189 | "version": 2 190 | }, 191 | "file_extension": ".py", 192 | "mimetype": "text/x-python", 193 | "name": "python", 194 | "nbconvert_exporter": "python", 195 | "pygments_lexer": "ipython2", 196 | "version": "2.7.6" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 0 201 | } 202 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StatMixedML/ConvTS-Mixer/34162b04cc3560b9bdbe251454065baf2f9bd2db/examples/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/zalandoresearch/pytorch-ts.git@version-0.7.0 2 | git+https://github.com/awslabs/gluonts.git 3 | einops 4 | matplotlib 5 | numpy 6 | pandas 7 | lightning 8 | optuna 9 | scipy 10 | --------------------------------------------------------------------------------