├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── VERSION ├── configs └── lag_llama.json ├── data ├── augmentations │ ├── __init__.py │ ├── augmentations.py │ ├── freq_mask.py │ └── freq_mix.py ├── data_utils.py ├── dataset_list.py └── read_new_dataset.py ├── gluon_utils ├── gluon_ts_distributions │ └── implicit_quantile_network.py └── scalers │ └── robust_scaler.py ├── images └── lagllama.webp ├── lag_llama ├── gluon │ ├── __init__.py │ ├── estimator.py │ └── lightning_module.py └── model │ ├── __init__.py │ └── module.py ├── pyproject.toml ├── requirements.txt ├── run.py ├── scripts ├── finetune.sh └── pretrain.sh └── utils └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Custom files 163 | 164 | experiments 165 | scratch/* 166 | wandb/* 167 | stats 168 | hyperparameters 169 | data/datasets 170 | .DS_Store 171 | hyperparameters_lag_transformer 172 | notebooks 173 | datasets 174 | *ckpt 175 | *tar.gz -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include VERSION 2 | include requirements.txt 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting 2 | 3 | ![lag-llama-architecture](images/lagllama.webp) 4 | 5 | Lag-Llama is the first open-source foundation model for time series forecasting! 
6 | 7 | [[Tweet Thread](https://twitter.com/arjunashok37/status/1755261111233114165)] 8 | 9 | [[Model Weights](https://huggingface.co/time-series-foundation-models/Lag-Llama)] [[Colab Demo 1: Zero-Shot Forecasting](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?usp=sharing)] [[Colab Demo 2: (Preliminary Finetuning)](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing)] 10 | 11 | [[Paper](https://arxiv.org/abs/2310.08278)] 12 | 13 | [[Video](https://www.youtube.com/watch?v=Mf2FOzDPxck)] 14 | ____ 15 | 16 | Updates: 17 | * **27-June-2024**: Fixed critical issues in the kv_cache implementation, improving forecast accuracy. The fixes include: resetting the self.y_cache flag globally, using causal attention correctly during kv_cache initialization, and adjusting rotary embeddings post-concatenation. Contribution by [@KelianM](https://github.com/KelianM). 18 | * **16-Apr-2024**: Released pretraining and finetuning scripts to replicate the experiments in the paper. See [Reproducing Experiments in the Paper](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#reproducing-experiments-in-the-paper) for details. 19 | * **9-Apr-2024**: We have released a 15-minute video 🎥 on Lag-Llama on [YouTube](https://www.youtube.com/watch?v=Mf2FOzDPxck). 20 | * **5-Apr-2024**: Added a [section](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?authuser=1#scrollTo=Mj9LXMpJ01d7&line=6&uniqifier=1) in Colab Demo 1 on the importance of tuning the context length for zero-shot forecasting. Added a [best practices section](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#best-practices) in the README; added recommendations for finetuning. These recommendations will be demonstrated with an example in [Colab Demo 2](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing) soon. 
21 | * **4-Apr-2024**: We have updated our requirements file with new versions of certain packages. Please update/recreate your environments if you have previously used the code locally. 22 | * **7-Mar-2024**: We have released a preliminary [Colab Demo 2](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing) for finetuning. Please note this is a preliminary tutorial. We recommend taking a look at the best practices if you are finetuning the model or using it for benchmarking. 23 | * **17-Feb-2024**: We have released a new updated [Colab Demo 1](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?usp=sharing) for zero-shot forecasting that shows how one can load time series of different formats. 24 | * **7-Feb-2024**: We released Lag-Llama, with open-source model checkpoints and a Colab Demo for zero-shot forecasting. 25 | 26 | ____ 27 | 28 | **Current Features**: 29 | 30 | 💫 Zero-shot forecasting on a dataset of any frequency for any prediction length, using Colab Demo 1.
31 | 32 | 💫 Finetuning on a dataset using [Colab Demo 2](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing). 33 | 34 | 💫 Reproducing experiments in the paper using the released scripts. See [Reproducing Experiments in the Paper](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#reproducing-experiments-in-the-paper) for details. 35 | 36 | **Note**: Please see the [best practices section](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#best-practices) when using the model for zero-shot prediction and finetuning. 37 | 38 | ____ 39 | 40 | ## Reproducing Experiments in the Paper 41 | 42 | To replicate the pretraining setup used in the paper, please see [the pretraining script](scripts/pretrain.sh). Once a model is pretrained, instructions to finetune it with the setup in the paper can be found in [the finetuning script](scripts/finetune.sh). 43 | 44 | 45 | ## Best Practices 46 | 47 | Here are some general tips on using Lag-Llama. 48 | 49 | 50 | ### General Information 51 | 52 | * Lag-Llama is a **probabilistic** forecasting model trained to output a probability distribution for each timestep to be predicted. For your own specific use-case, we would recommend benchmarking the zero-shot performance of the model on your data first, and then finetuning if necessary. As we show in our paper, Lag-Llama has strong zero-shot capabilities, but performs best when finetuned. The more data you finetune on, the better. For specific tips on applying the model zero-shot or on finetuning, please refer to the sections below. 53 | 54 | #### Zero-Shot Forecasting 55 | 56 | * Importantly, we recommend trying different **context lengths** (starting from $32$, which it was trained on) and identifying what works best for your data. 
As we show in [this section of the zero-shot forecasting demo](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?authuser=1#scrollTo=Mj9LXMpJ01d7&line=6&uniqifier=1), the model's zero-shot performance improves as the context length is increased, until a certain context length which may be specific to your data. Further, we recommend enabling RoPE scaling for the model to work well with context lengths larger than what it was trained on. 57 | 58 | #### Fine-Tuning 59 | 60 | If you are trying to **benchmark** the performance of the model under finetuning, or trying to obtain maximum performance from the model: 61 | 62 | * We recommend tuning two important hyperparameters for each dataset that you finetune on: the **context length** (suggested values: $32$, $64$, $128$, $256$, $512$, $1024$) and the **learning rate** (suggested values: $10^{-2}$, $5 * 10^{-3}$, $10^{-3}$, $1 * 10^{-4}$, $5 * 10^{-4}$). 63 | * We also highly recommend using a validation split of your dataset to early stop your model, with an early stopping patience of 50 epochs. 64 | 65 | ## Contact 66 | 67 | We are dedicated to ensuring the reproducibility of our results, and would be happy to help clarify questions about benchmarking our model or about the experiments in the paper. 68 | The quickest way to reach us would be by email. Please email **both**: 69 | 1. [Arjun Ashok](https://ashok-arjun.github.io/) - arjun [dot] ashok [at] servicenow [dot] com 70 | 2. [Kashif Rasul](https://scholar.google.de/citations?user=cfIrwmAAAAAJ&hl=en) - kashif [dot] rasul [at] gmail [dot] com 71 | 72 | If you have questions about the model usage (or) code (or) have specific errors (e.g. using it with your own dataset), it would be best to create an issue in the GitHub repository. 73 | 74 | ## Citing this work 75 | 76 | Please use the following BibTeX entry to cite Lag-Llama. 
77 | 78 | ``` 79 | @misc{rasul2024lagllama, 80 | title={Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting}, 81 | author={Kashif Rasul and Arjun Ashok and Andrew Robert Williams and Hena Ghonia and Rishika Bhagwatkar and Arian Khorasani and Mohammad Javad Darvishi Bayazi and George Adamopoulos and Roland Riachi and Nadhir Hassen and Marin Biloš and Sahil Garg and Anderson Schneider and Nicolas Chapados and Alexandre Drouin and Valentina Zantedeschi and Yuriy Nevmyvaka and Irina Rish}, 82 | year={2024}, 83 | eprint={2310.08278}, 84 | archivePrefix={arXiv}, 85 | primaryClass={cs.LG} 86 | } 87 | ``` 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.1.0 -------------------------------------------------------------------------------- /configs/lag_llama.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_single_instance_sampler": true, 3 | "stratified_sampling": "series", 4 | "data_normalization": "robust", 5 | "n_layer": 8, 6 | "n_head": 9, 7 | "n_embd_per_head": 16, 8 | "time_feat": true, 9 | "context_length": 32, 10 | "aug_prob": 0.5, 11 | "freq_mask_rate": 0.5, 12 | "freq_mixing_rate": 0.25, 13 | "weight_decay": 0.0, 14 | "dropout": 0.0 15 | } -------------------------------------------------------------------------------- /data/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. -------------------------------------------------------------------------------- /data/augmentations/augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Adapted from https://github.com/vafl/gluon-ts/blob/ts_embeddings/src/gluonts/nursery/ts_embeddings/pt_augmentation.py 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | from scipy.interpolate import CubicSpline 21 | 22 | 23 | class ApplyAugmentations(nn.Module): 24 | def __init__(self, transforms): 25 | super().__init__() 26 | self.transformation = RandomApply(transforms) 27 | 28 | def forward(self, ip1, ip2): 29 | ip_concat = torch.concat((ip1, ip2), dim=1) 30 | ip_aug = self.transformation(ip_concat) 31 | op1, op2 = torch.split(ip_aug, [ip1.size(1), ip2.size(1)], dim=1) 32 | assert len(op1.shape) == 2 and len(op2.shape) == 2 33 | return op1, op2 34 | 35 | 36 | class RandomApply(nn.Module): 37 | def __init__(self, transforms, p=0.5): 38 | super().__init__() 39 | # 'transforms' is a list of transformation modules to be applied to the input data. 40 | # These could be any transformations like normalization, augmentation, etc. 41 | self.transforms = nn.ModuleList(transforms) 42 | 43 | # 'p' is the probability with which the transformations will be applied. 44 | # It's a floating-point number between 0 and 1. 45 | self.p = p 46 | 47 | def forward(self, x): 48 | # Randomly decide whether to apply the transformations or not, based on probability 'p'. 49 | if self.p < torch.rand(1): 50 | # If the random number is greater than 'p', return the input as it is. 51 | return x 52 | 53 | # Assert shape is (bsz, seq_len) 54 | assert len(x.shape) == 2 55 | 56 | # Unsqueeze last dimension 57 | x = x.unsqueeze(-1) 58 | 59 | # Apply each transformation in the list to the transposed input. 60 | for t in self.transforms: 61 | x = t(x) 62 | 63 | # Squeeze last dimension 64 | x = x.squeeze(-1) 65 | 66 | # Finally, transpose the tensor back to its original dimension order. 
67 | # return x_time_first.transpose(1, 2) 68 | return x 69 | 70 | 71 | class Jitter(nn.Module): 72 | """ 73 | The Jitter class implements a jittering transformation as described in the paper: 74 | 'Data Augmentation for Machine Learning Algorithms' (https://arxiv.org/pdf/1706.00527.pdf). 75 | It adds random noise to the input data, which is a common technique for data augmentation. 76 | """ 77 | 78 | def __init__(self, p, sigma=0.03): 79 | super().__init__() 80 | # 'p' is the probability with which the jitter (noise) will be applied to the input data. 81 | self.p = p 82 | 83 | # 'sigma' defines the standard deviation of the normal distribution used for generating the jitter. 84 | # It controls the magnitude of the noise added to the data. 85 | self.sigma = sigma 86 | 87 | def forward(self, x): 88 | # Randomly decide whether to apply jitter (noise) or not, based on the probability 'p'. 89 | if self.p < torch.rand(1): 90 | # If the random number is greater than 'p', return the input as it is, without adding noise. 91 | return x 92 | 93 | # Generate random noise from a normal distribution with mean 0 and standard deviation 'sigma'. 94 | # The size and device of the noise tensor are the same as the input tensor 'x'. 95 | noise = torch.normal(mean=0.0, std=self.sigma, size=x.shape, device=x.device) 96 | 97 | # Add the generated noise to the input data and return the result. 98 | return x + noise 99 | 100 | 101 | class Scaling(nn.Module): 102 | """ 103 | The Scaling class implements a scaling transformation as described in the paper: 104 | 'Data Augmentation for Machine Learning Algorithms' (https://arxiv.org/pdf/1706.00527.pdf). 105 | This transformation scales the input data by a random factor, which can be useful for data augmentation. 106 | """ 107 | 108 | def __init__(self, p, sigma=0.1): 109 | super().__init__() 110 | # 'p' is the probability with which the scaling will be applied to the input data. 
111 | self.p = p 112 | 113 | # 'sigma' defines the standard deviation of the normal distribution used for generating the scaling factor. 114 | # It controls the variability of the scaling factor. 115 | self.sigma = sigma 116 | 117 | def forward(self, x): 118 | # Randomly decide whether to apply scaling or not, based on the probability 'p'. 119 | if self.p < torch.rand(1): 120 | # If the random number is greater than 'p', return the input as it is, without scaling. 121 | return x 122 | 123 | # Generate random scaling factors from a normal distribution with mean 1 and standard deviation 'sigma'. 124 | # The size of the scaling factor tensor is tailored to match the batch and time dimensions of the input tensor 'x', 125 | # but it has a single channel so that the same factor is applied across all channels. 126 | factor = torch.normal( 127 | mean=1.0, std=self.sigma, size=(x.shape[0], 1, x.shape[2]), device=x.device 128 | ) 129 | 130 | # Multiply the input data by the scaling factors and return the result. 131 | return x * factor 132 | 133 | 134 | class Rotation(nn.Module): 135 | """ 136 | This Rotation class is designed to randomly rotate the input data. 137 | It's a form of data augmentation that can be particularly useful in scenarios where 138 | the orientation of the data is not a defining characteristic. 139 | """ 140 | 141 | def __init__(self, p): 142 | super().__init__() 143 | # 'p' is the probability of applying the rotation to the input data. 144 | self.p = p 145 | 146 | def forward(self, x): 147 | # Randomly decide whether to rotate the data or not, based on the probability 'p'. 148 | if self.p < torch.rand(1): 149 | # If the random number is greater than 'p', return the input as it is, without rotation. 150 | return x 151 | 152 | # Create an index for flipping, where each element has a 50% chance of being 0 or 1. 
153 | flip_index = torch.multinomial( 154 | torch.tensor([0.5, 0.5], dtype=x.dtype, device=x.device), 155 | num_samples=x.shape[0] * x.shape[2], 156 | replacement=True, 157 | ) 158 | 159 | # Create a tensor of ones, which will be used to flip the sign of the data based on the flip_index. 160 | ones = torch.ones((x.shape[0] * x.shape[2]), device=x.device) 161 | flip = torch.where(flip_index == 0, -ones, ones) 162 | 163 | # Randomly shuffle the axes along which the data will be rotated. 164 | rotate_axis = np.arange(x.shape[2]) 165 | np.random.shuffle(rotate_axis) 166 | 167 | # Apply the flipping and rotation to the data and return the result. 168 | return flip.reshape(x.shape[0], 1, x.shape[2]) * x[:, :, rotate_axis] 169 | 170 | 171 | class Permutation(nn.Module): 172 | """ 173 | The Permutation class implements a data augmentation technique where the data is divided into segments, 174 | and these segments are then randomly permuted. This can be useful for tasks where the order of data points 175 | is not crucial and can help in improving the robustness of models. 176 | """ 177 | 178 | def __init__(self, p, max_segments=5, seg_mode="equal"): 179 | super().__init__() 180 | # 'p' is the probability of applying the permutation to the input data. 181 | self.p = p 182 | 183 | # 'max_segments' defines the maximum number of segments into which the data can be split for permutation. 184 | self.max_segments = max_segments 185 | 186 | # 'seg_mode' determines how the segments are created: 'equal' for equal-sized segments, 'random' for random splits. 187 | self.seg_mode = seg_mode 188 | 189 | def forward(self, x): 190 | # Randomly decide whether to permute the data or not, based on the probability 'p'. 191 | if self.p < torch.rand(1): 192 | # If the random number is greater than 'p', return the input as it is, without permutation. 193 | return x 194 | 195 | # Create an array representing the original order of data points. 
196 | orig_steps = np.arange(x.shape[1]) 197 | 198 | # Randomly decide the number of segments for each batch in the data. 199 | num_segs = np.random.randint(1, self.max_segments, size=(x.shape[0])) 200 | 201 | # Initialize a tensor to hold the permuted data. 202 | ret = torch.zeros_like(x) 203 | for i, pat in enumerate(x): 204 | if num_segs[i] > 1: 205 | if self.seg_mode == "random": 206 | # In 'random' mode, choose random split points. 207 | split_points = np.random.choice( 208 | x.shape[1] - 2, num_segs[i] - 1, replace=False 209 | ) 210 | split_points.sort() 211 | splits = np.split(orig_steps, split_points) 212 | else: 213 | # In 'equal' mode, split the data into roughly equal segments. 214 | splits = np.array_split(orig_steps, num_segs[i]) 215 | 216 | # Permute the segments and recombine them. 217 | warp = np.concatenate(np.random.permutation(splits)).ravel() 218 | ret[i] = pat[warp] 219 | else: 220 | # If there's only one segment, keep the data as it is. 221 | ret[i] = pat 222 | 223 | return ret 224 | 225 | 226 | class MagnitudeWarp(nn.Module): 227 | """ 228 | The MagnitudeWarp class applies a non-linear warping to the magnitude of the input data. 229 | This is achieved by using cubic splines to create smooth, random warp functions that are 230 | then applied to the input. It's a form of data augmentation useful in scenarios where the 231 | model needs to be robust to variations in the magnitude of the input data. 232 | """ 233 | 234 | def __init__(self, p, sigma=0.2, knot=4): 235 | super().__init__() 236 | # 'p' is the probability with which the magnitude warp will be applied. 237 | self.p = p 238 | 239 | # 'sigma' controls the variability of the warp. Higher values lead to more pronounced warping. 240 | self.sigma = sigma 241 | 242 | # 'knot' is the number of points in the cubic spline used for warping. 243 | self.knot = knot 244 | 245 | def forward(self, x): 246 | # Decide whether to apply the warp based on the probability 'p'. 
247 | if self.p < torch.rand(1): 248 | return x 249 | 250 | # Generate an array representing the original order of data points. 251 | orig_steps = np.arange(x.shape[1]) 252 | 253 | # Generate random warps using a normal distribution centered at 1.0. 254 | random_warps = np.random.normal( 255 | loc=1.0, 256 | scale=self.sigma, 257 | size=(x.shape[0], self.knot + 2, x.shape[2]), 258 | ) 259 | 260 | # Create warp steps evenly distributed across the data length. 261 | warp_steps = ( 262 | np.ones((x.shape[2], 1)) 263 | * (np.linspace(0, x.shape[1] - 1.0, num=self.knot + 2)) 264 | ).T 265 | 266 | # Initialize a tensor to hold the warped data. 267 | ret = torch.zeros_like(x) 268 | for i, pat in enumerate(x): 269 | # For each dimension, create a cubic spline based on the warp steps and random warps, 270 | # and apply it to the original steps to get the warper. 271 | warper = np.array( 272 | [ 273 | CubicSpline(warp_steps[:, dim], random_warps[i, :, dim])(orig_steps) 274 | for dim in range(x.shape[2]) 275 | ] 276 | ).T 277 | 278 | # Apply the warper to the pattern and store it in the result tensor. 279 | ret[i] = pat * torch.from_numpy(warper).float().to(x.device) 280 | 281 | return ret 282 | 283 | 284 | class TimeWarp(nn.Module): 285 | """ 286 | The TimeWrap class applies a non-linear warping to the time axis of the input data. 287 | This is achieved by using cubic splines to create smooth, random warp functions that 288 | distort the time dimension of the input. It's a form of data augmentation useful for 289 | tasks where the model needs to be robust to variations in the timing of the input data. 290 | """ 291 | 292 | def __init__(self, p, sigma=0.2, knot=4): 293 | super().__init__() 294 | # 'p' is the probability with which the time warp will be applied. 295 | self.p = p 296 | 297 | # 'sigma' controls the variability of the warp. Higher values lead to more pronounced warping. 
298 | self.sigma = sigma 299 | 300 | # 'knot' is the number of points in the cubic spline used for warping. 301 | self.knot = knot 302 | 303 | def forward(self, x): 304 | # Decide whether to apply the warp based on the probability 'p'. 305 | if self.p < torch.rand(1): 306 | return x 307 | 308 | # Generate an array representing the original time steps of the data. 309 | orig_steps = np.arange(x.shape[1]) 310 | 311 | # Generate random warps using a normal distribution centered at 1.0. 312 | random_warps = np.random.normal( 313 | loc=1.0, 314 | scale=self.sigma, 315 | size=(x.shape[0], self.knot + 2, x.shape[2]), 316 | ) 317 | 318 | # Create warp steps evenly distributed across the data length. 319 | warp_steps = ( 320 | np.ones((x.shape[2], 1)) 321 | * (np.linspace(0, x.shape[1] - 1.0, num=self.knot + 2)) 322 | ).T 323 | 324 | # Initialize a tensor to hold the time-warped data. 325 | ret = torch.zeros_like(x) 326 | for i, pat in enumerate(x): 327 | for dim in range(x.shape[2]): 328 | # Create a cubic spline based on the warp steps and random warps to generate the time warp. 329 | time_warp = CubicSpline( 330 | warp_steps[:, dim], 331 | warp_steps[:, dim] * random_warps[i, :, dim], 332 | )(orig_steps) 333 | # Scale the time warp to fit the original data length. 334 | scale = (x.shape[1] - 1) / time_warp[-1] 335 | wrap = np.interp( 336 | orig_steps, 337 | np.clip(scale * time_warp, 0, x.shape[1] - 1), 338 | pat[:, dim].cpu().numpy(), 339 | ).T 340 | # Apply the time warp to the corresponding dimension of the data. 341 | ret[i, :, dim] = torch.from_numpy(wrap).float().to(x.device) 342 | 343 | return ret 344 | 345 | 346 | class WindowSlice(nn.Module): 347 | """ 348 | The WindowSlice class implements a data augmentation technique where a slice of the input data 349 | is stretched to fill the entire length of the input. 
This technique is useful for training models 350 | to focus on local features of the data and can be found in literature such as: 351 | 'Time Series Data Augmentation for Deep Learning: A Survey' (https://halshs.archives-ouvertes.fr/halshs-01357973/document). 352 | """ 353 | 354 | def __init__(self, p, reduce_ratio=0.9): 355 | super().__init__() 356 | # 'p' is the probability of applying the window slicing to the input data. 357 | self.p = p 358 | 359 | # 'reduce_ratio' determines the size of the slice relative to the original data. 360 | self.reduce_ratio = reduce_ratio 361 | 362 | def forward(self, x): 363 | # Decide whether to apply the slice based on the probability 'p'. 364 | if self.p < torch.rand(1): 365 | return x 366 | 367 | # Calculate the target length of the slice. 368 | target_len = np.ceil(self.reduce_ratio * x.shape[1]).astype(int) 369 | if target_len >= x.shape[1]: 370 | return x 371 | 372 | # Randomly select start points for the slice in each batch. 373 | starts = np.random.randint( 374 | low=0, high=x.shape[1] - target_len, size=(x.shape[0]) 375 | ).astype(int) 376 | ends = (target_len + starts).astype(int) 377 | 378 | # Initialize a tensor to hold the sliced and stretched data. 379 | ret = torch.zeros_like(x) 380 | for i, pat in enumerate(x): 381 | for dim in range(x.shape[2]): 382 | # Interpolate the slice to stretch it across the original length of the data. 383 | warp = np.interp( 384 | np.linspace(0, target_len, num=x.shape[1]), 385 | np.arange(target_len), 386 | pat[starts[i] : ends[i], dim].cpu().numpy(), 387 | ).T 388 | # Apply the stretched slice to the corresponding dimension of the data. 389 | ret[i, :, dim] = torch.from_numpy(warp).float().to(x.device) 390 | return ret 391 | 392 | 393 | class WindowWarp(nn.Module): 394 | """ 395 | The WindowWarp class implements a data augmentation technique where a segment (window) of the input data 396 | is selected and warped in size. 
This technique is useful for simulating variations in the speed or rate 397 | of the data within a certain window, as discussed in: 398 | 'Time Series Data Augmentation for Deep Learning: A Survey' (https://halshs.archives-ouvertes.fr/halshs-01357973/document). 399 | """ 400 | 401 | def __init__(self, p, window_ratio=0.1, scales=[0.5, 2.0]): 402 | super().__init__() 403 | # 'p' is the probability of applying the window warp to the input data. 404 | self.p = p 405 | 406 | # 'window_ratio' determines the size of the window relative to the original data. 407 | self.window_ratio = window_ratio 408 | 409 | # 'scales' are the possible scaling factors to be applied to the window. 410 | self.scales = scales 411 | 412 | def forward(self, x): 413 | # Decide whether to apply the warp based on the probability 'p'. 414 | if self.p < torch.rand(1): 415 | return x 416 | 417 | # Randomly choose a scaling factor for each batch in the data. 418 | warp_scales = np.random.choice(self.scales, x.shape[0]) 419 | 420 | # Calculate the size of the warp window. 421 | warp_size = np.ceil(self.window_ratio * x.shape[1]).astype(int) 422 | window_steps = np.arange(warp_size) 423 | 424 | # Randomly select start points for the window in each batch. 425 | window_starts = np.random.randint( 426 | low=1, high=x.shape[1] - warp_size - 1, size=(x.shape[0]) 427 | ).astype(int) 428 | window_ends = (window_starts + warp_size).astype(int) 429 | 430 | # Initialize a tensor to hold the window-warped data. 431 | ret = torch.zeros_like(x) 432 | for i, pat in enumerate(x): 433 | for dim in range(x.shape[2]): 434 | # Isolate the segments before, within, and after the window. 
435 | start_seg = pat[: window_starts[i], dim].cpu().numpy() 436 | window_seg = np.interp( 437 | np.linspace( 438 | 0, 439 | warp_size - 1, 440 | num=int(warp_size * warp_scales[i]), 441 | ), 442 | window_steps, 443 | pat[window_starts[i] : window_ends[i], dim].cpu().numpy(), 444 | ) 445 | end_seg = pat[window_ends[i] :, dim].cpu().numpy() 446 | 447 | # Concatenate the segments and stretch them to fit the original data length. 448 | warped = np.concatenate((start_seg, window_seg, end_seg)) 449 | warp = np.interp( 450 | np.arange(x.shape[1]), 451 | np.linspace(0, x.shape[1] - 1.0, num=warped.size), 452 | warped, 453 | ).T 454 | 455 | # Apply the window warp to the corresponding dimension of the data. 456 | ret[i, :, dim] = torch.from_numpy(warp).float().to(x.device) 457 | return ret 458 | -------------------------------------------------------------------------------- /data/augmentations/freq_mask.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | 18 | @torch.no_grad() 19 | def freq_mask(x, y, rate=0.1, dim=1): 20 | # Get lengths of the input tensors along the specified dimension. 21 | x_len = x.shape[dim] 22 | y_len = y.shape[dim] 23 | 24 | # Concatenate x and y along the specified dimension. 25 | # x and y represent past and future targets respectively. 
26 | xy = torch.cat([x, y], dim=dim) 27 | 28 | # Perform a real-valued fast Fourier transform (RFFT) on the concatenated tensor. 29 | # This transforms the time series data into the frequency domain. 30 | xy_f = torch.fft.rfft(xy, dim=dim) 31 | 32 | # Create a random mask with a probability defined by 'rate'. 33 | # This mask will be used to randomly select frequencies to be zeroed out. 34 | m = torch.rand_like(xy_f, dtype=xy.dtype) < rate 35 | 36 | # Apply the mask to the real and imaginary parts of the frequency data, 37 | # setting the selected frequencies to zero. This 'masks' those frequencies. 38 | freal = xy_f.real.masked_fill(m, 0) 39 | fimag = xy_f.imag.masked_fill(m, 0) 40 | 41 | # Combine the masked real and imaginary parts back into complex frequency data. 42 | xy_f = torch.complex(freal, fimag) 43 | 44 | # Perform an inverse RFFT to transform the data back to the time domain. 45 | # The masked frequencies will affect the reconstructed time series. 46 | xy = torch.fft.irfft(xy_f, dim=dim) 47 | 48 | # If the reconstructed data length differs from the original concatenated length, 49 | # adjust it to maintain consistency. This step ensures the output shape matches the input. 50 | if x_len + y_len != xy.shape[dim]: 51 | xy = torch.cat([x[:, 0:1, ...], xy], 1) 52 | 53 | # Split the reconstructed data back into two parts corresponding to the original x and y. 54 | return torch.split(xy, [x_len, y_len], dim=dim) 55 | -------------------------------------------------------------------------------- /data/augmentations/freq_mix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | 18 | 19 | @torch.no_grad() 20 | def freq_mix(x, y, rate=0.1, dim=1): 21 | # Get lengths of the input tensors along the specified dimension. 22 | x_len = x.shape[dim] 23 | y_len = y.shape[dim] 24 | 25 | # Concatenate x and y along the specified dimension. 26 | # x and y represent past and future targets respectively. 27 | xy = torch.cat([x, y], dim=dim) 28 | 29 | # Perform a real-valued fast Fourier transform (RFFT) on the concatenated tensor. 30 | xy_f = torch.fft.rfft(xy, dim=dim) 31 | 32 | # Create a random mask with a probability defined by 'rate'. 33 | # This mask will be used to select which frequencies to manipulate. 34 | m = torch.rand_like(xy_f, dtype=xy.dtype) < rate 35 | 36 | # Calculate the amplitude of the frequency components. 37 | amp = abs(xy_f) 38 | 39 | # Sort the amplitudes and create a mask to ignore the most dominant frequencies. 40 | _, index = amp.sort(dim=dim, descending=True) 41 | dominant_mask = index > 2 42 | m = torch.bitwise_and(m, dominant_mask) 43 | 44 | # Apply the mask to the real and imaginary parts of the frequency data, 45 | # setting masked frequencies to zero. 46 | freal = xy_f.real.masked_fill(m, 0) 47 | fimag = xy_f.imag.masked_fill(m, 0) 48 | 49 | # Shuffle the batches in x and y to mix data from different sequences. 50 | b_idx = np.arange(x.shape[0]) 51 | np.random.shuffle(b_idx) 52 | x2, y2 = x[b_idx], y[b_idx] 53 | 54 | # Concatenate the shuffled tensors and perform RFFT. 
55 | xy2 = torch.cat([x2, y2], dim=dim) 56 | xy2_f = torch.fft.rfft(xy2, dim=dim) 57 | 58 | # Invert the mask and apply it to the shuffled frequency data. 59 | m = torch.bitwise_not(m) 60 | freal2 = xy2_f.real.masked_fill(m, 0) 61 | fimag2 = xy2_f.imag.masked_fill(m, 0) 62 | 63 | # Combine the original and shuffled frequency data. 64 | freal += freal2 65 | fimag += fimag2 66 | 67 | # Reconstruct the complex frequency data and perform an inverse RFFT. 68 | xy_f = torch.complex(freal, fimag) 69 | xy = torch.fft.irfft(xy_f, dim=dim) 70 | 71 | # If the reconstructed data length differs from the original concatenated length, 72 | # adjust it to maintain consistency. 73 | if x_len + y_len != xy.shape[dim]: 74 | xy = torch.cat([x[:, 0:1, ...], xy], 1) 75 | 76 | # Split the reconstructed data back into two parts corresponding to the original x and y. 77 | return torch.split(xy, [x_len, y_len], dim=dim) 78 | -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

import copy
import json
import os
import random
import warnings
from pathlib import Path

# Silence noisy deprecation chatter from pandas/gluonts before importing them.
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import numpy as np
import pandas as pd
from tqdm import tqdm
from gluonts.dataset.common import ListDataset
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.transform import InstanceSampler
from pandas.tseries.frequencies import to_offset

from data.read_new_dataset import (
    MetaData,
    TrainDatasets,
    create_train_dataset_without_last_k_timesteps,
    get_ett_dataset,
)


class CombinedDatasetIterator:
    """Iterator that yields items from several datasets, choosing the source
    dataset at random (with the given weights) for every item."""

    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __next__(self):
        # Pick one dataset according to the weights and take its next item.
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)


class CombinedDataset:
    """Virtual combination of several datasets, sampled from at random
    (uniformly unless `weights` is given)."""

    def __init__(self, datasets, seed=None, weights=None):
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        n_datasets = len(datasets)
        if weights is None:
            # Default to uniform sampling across datasets.
            self._weights = [1 / n_datasets] * n_datasets

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)

    def __len__(self):
        return sum([len(ds) for ds in self._datasets])


class SingleInstanceSampler(InstanceSampler):
    """
    Randomly pick a single valid window in the given time series.
    This fixes the bias in ExpectedNumInstanceSampler, which leads to a varying
    sampling frequency for time series of unequal length, based not only on
    their length but also on when they were sampled.
    """

    def __call__(self, ts: np.ndarray) -> np.ndarray:
        a, b = self._get_bounds(ts)
        window_size = b - a + 1
        if window_size <= 0:
            # No valid window available in this series.
            return np.array([], dtype=int)
        # One uniformly random window start within the valid range.
        indices = np.random.randint(window_size, size=1)
        return indices + a


def _count_timesteps(
    left: pd.Timestamp, right: pd.Timestamp, delta: pd.DateOffset
) -> int:
    """
    Count how many timesteps there are between left and right, according to
    the given timestep delta. If the number is not an integer, round down.
    """
    # GluonTS >= 0.10.0 uses pd.Period instead of pd.Timestamp (the original
    # code was tested on 0.9.4), so normalize both to Timestamp.
    if isinstance(left, pd.Period):
        left = left.to_timestamp()
    if isinstance(right, pd.Period):
        right = right.to_timestamp()
    assert (
        right >= left
    ), f"Case where left ({left}) is after right ({right}) is not implemented in _count_timesteps()."
    try:
        return (right - left) // delta
    except TypeError:
        # For MonthEnd-style offsets the division does not work,
        # so count offsets one by one.
        for i in range(10000):
            if left + (i + 1) * delta > right:
                return i
        else:
            raise RuntimeError(
                f"Too large difference between both timestamps ({left} and {right}) for _count_timesteps()."
            )


def create_train_dataset_last_k_percentage(
    raw_train_dataset,
    freq,
    k=100
):
    """Return a ListDataset keeping only the last `k` percent of every
    series in `raw_train_dataset`."""
    train_data = []
    for i, series in enumerate(raw_train_dataset):
        s_train = series.copy()
        number_of_values = int(len(s_train["target"]) * k / 100)
        train_start_index = len(s_train["target"]) - number_of_values
        s_train["target"] = s_train["target"][train_start_index:]
        train_data.append(s_train)

    train_data = ListDataset(train_data, freq=freq)

    return train_data

def create_train_and_val_datasets_with_dates(
    name,
    dataset_path,
    data_id,
    history_length,
    prediction_length=None,
    num_val_windows=None,
    val_start_date=None,
    train_start_date=None,
    freq=None,
    last_k_percentage=None
):
    """
    Build train and validation ListDatasets for the named dataset.

    Train start date is assumed to be the start of the series if not provided.
    Freq is inferred from the data if not given.
    We can use ListDataset to just group multiple time series -
    https://github.com/awslabs/gluonts/issues/695

    Exactly one of `num_val_windows` / `val_start_date` must be provided.
    Returns (train_data, val_data, total_train_points, total_val_points,
    total_val_windows, max_train_end_date, total_points).
    """

    if name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"):
        path = os.path.join(dataset_path, "ett_datasets")
        raw_dataset = get_ett_dataset(name, path)
    elif name in ("cpu_limit_minute", "cpu_usage_minute", \
                "function_delay_minute", "instances_minute", \
                "memory_limit_minute", "memory_usage_minute", \
                "platform_delay_minute", "requests_minute"):
        path = os.path.join(dataset_path, "huawei/" + name + ".json")
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        # Drop series whose targets were serialized as strings (corrupt rows).
        train_data = [x for x in data["train"] if type(x["target"][0]) != str]
        test_data = [x for x in data["test"] if type(x["target"][0]) != str]
        train_ds = ListDataset(train_data, freq=metadata.freq)
        test_ds = ListDataset(test_data, freq=metadata.freq)
        raw_dataset = TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)
    elif name in ("beijing_pm25", "AirQualityUCI", "beijing_multisite"):
        path = os.path.join(dataset_path, "air_quality/" + name + ".json")
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        train_test_data = [x for x in data["data"] if type(x["target"][0]) != str]
        full_dataset = ListDataset(train_test_data, freq=metadata.freq)
        # Hold out the last day (24 hourly steps) for testing.
        train_ds = create_train_dataset_without_last_k_timesteps(full_dataset, freq=metadata.freq, k=24)
        raw_dataset = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset)
    else:
        raw_dataset = get_dataset(name, path=Path(dataset_path))

    if prediction_length is None:
        prediction_length = raw_dataset.metadata.prediction_length
    if freq is None:
        freq = raw_dataset.metadata.freq
    timestep_delta = pd.tseries.frequencies.to_offset(freq)
    raw_train_dataset = raw_dataset.train

    if not num_val_windows and not val_start_date:
        raise Exception("Either num_val_windows or val_start_date must be provided")
    if num_val_windows and val_start_date:
        # FIX: the original message wrongly repeated "Either ... must be provided".
        raise Exception("Only one of num_val_windows or val_start_date must be provided")

    max_train_end_date = None

    # Get training data.
    total_train_points = 0
    train_data = []
    for i, series in enumerate(raw_train_dataset):
        s_train = series.copy()
        if val_start_date is not None:
            train_end_index = _count_timesteps(
                series["start"] if not train_start_date else train_start_date,
                val_start_date,
                timestep_delta,
            )
        else:
            train_end_index = len(series["target"]) - num_val_windows
        # Compute train_start_index based on last_k_percentage.
        if last_k_percentage:
            number_of_values = int(len(s_train["target"]) * last_k_percentage / 100)
            train_start_index = train_end_index - number_of_values
        else:
            train_start_index = 0
        s_train["target"] = series["target"][train_start_index:train_end_index]
        s_train["item_id"] = i
        s_train["data_id"] = data_id
        train_data.append(s_train)
        total_train_points += len(s_train["target"])

        # Track the latest end date over all training series.
        end_date = s_train["start"] + to_offset(freq) * (len(s_train["target"]) - 1)
        if max_train_end_date is None or end_date > max_train_end_date:
            max_train_end_date = end_date

    train_data = ListDataset(train_data, freq=freq)

    # Get validation data: each series keeps `history_length` context before
    # the validation region.
    total_val_points = 0
    total_val_windows = 0
    val_data = []
    for i, series in enumerate(raw_train_dataset):
        s_val = series.copy()
        if val_start_date is not None:
            train_end_index = _count_timesteps(
                series["start"], val_start_date, timestep_delta
            )
        else:
            train_end_index = len(series["target"]) - num_val_windows
        val_start_index = train_end_index - prediction_length - history_length
        s_val["start"] = series["start"] + val_start_index * timestep_delta
        s_val["target"] = series["target"][val_start_index:]
        s_val["item_id"] = i
        s_val["data_id"] = data_id
        val_data.append(s_val)
        total_val_points += len(s_val["target"])
        total_val_windows += len(s_val["target"]) - prediction_length - history_length
    val_data = ListDataset(val_data, freq=freq)

    # Total unique points, discounting the context overlap between splits.
    total_points = (
        total_train_points
        + total_val_points
        - (len(raw_train_dataset) * (prediction_length + history_length))
    )

    return (
        train_data,
        val_data,
        total_train_points,
        total_val_points,
        total_val_windows,
        max_train_end_date,
        total_points,
    )


def create_test_dataset(
    name, dataset_path, history_length, freq=None, data_id=None
):
    """
    Build the test ListDataset for the named dataset.

    For now, only one window per series is used.
    make_evaluation_predictions automatically only predicts for the last
    "prediction_length" timesteps.
    NOTE / TODO: For datasets where the test set has more series (possibly due
    to more timestamps), we should check if we only use the last N series
    where N = series per single timestamp, or if we should do something else.

    Returns (test_dataset, prediction_length, total_points).
    """

    if name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"):
        path = os.path.join(dataset_path, "ett_datasets")
        dataset = get_ett_dataset(name, path)
    elif name in ("cpu_limit_minute", "cpu_usage_minute", \
                "function_delay_minute", "instances_minute", \
                "memory_limit_minute", "memory_usage_minute", \
                "platform_delay_minute", "requests_minute"):
        path = os.path.join(dataset_path, "huawei/" + name + ".json")
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        # Drop series whose targets were serialized as strings (corrupt rows).
        train_data = [x for x in data["train"] if type(x["target"][0]) != str]
        test_data = [x for x in data["test"] if type(x["target"][0]) != str]
        train_ds = ListDataset(train_data, freq=metadata.freq)
        test_ds = ListDataset(test_data, freq=metadata.freq)
        dataset = TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)
    elif name in ("beijing_pm25", "AirQualityUCI", "beijing_multisite"):
        path = os.path.join(dataset_path, "air_quality/" + name + ".json")
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        train_test_data = [x for x in data["data"] if type(x["target"][0]) != str]
        full_dataset = ListDataset(train_test_data, freq=metadata.freq)
        # Hold out the last day (24 hourly steps) for training.
        train_ds = create_train_dataset_without_last_k_timesteps(full_dataset, freq=metadata.freq, k=24)
        dataset = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset)
    else:
        dataset = get_dataset(name, path=Path(dataset_path))

    if freq is None:
        freq = dataset.metadata.freq
    prediction_length = dataset.metadata.prediction_length
    data = []
    total_points = 0
    for i, series in enumerate(dataset.test):
        offset = len(series["target"]) - (history_length + prediction_length)
        if offset > 0:
            # Keep only the last (history + prediction) steps of the series.
            target = series["target"][-(history_length + prediction_length) :]
            data.append(
                {
                    "target": target,
                    "start": series["start"] + offset,
                    "item_id": i,
                    "data_id": data_id,
                }
            )
        else:
            # Series shorter than one window: keep it whole.
            series_copy = copy.deepcopy(series)
            series_copy["item_id"] = i
            series_copy["data_id"] = data_id
            data.append(series_copy)
        total_points += len(data[-1]["target"])
    return ListDataset(data, freq=freq), prediction_length, total_points

# ------------------------------------------------------------------------------
# /data/dataset_list.py:
# ------------------------------------------------------------------------------
# Copyright 2024 Arjun Ashok
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Names of all datasets supported for pretraining/evaluation,
# grouped by source collection.
ALL_DATASETS = [
    # Monash / GluonTS repository datasets
    "australian_electricity_demand",
    "electricity_hourly",
    "london_smart_meters_without_missing",
    "solar_10_minutes",
    "wind_farms_without_missing",
    "pedestrian_counts",
    "uber_tlc_hourly",
    "traffic",
    "kdd_cup_2018_without_missing",
    "saugeenday",
    "sunspot_without_missing",
    "exchange_rate",
    # Huawei cloud trace datasets
    "cpu_limit_minute",
    "cpu_usage_minute",
    "function_delay_minute",
    "instances_minute",
    "memory_limit_minute",
    "memory_usage_minute",
    "platform_delay_minute",
    "requests_minute",
    # ETT (electricity transformer temperature) datasets
    "ett_h1",
    "ett_h2",
    "ett_m1",
    "ett_m2",
    # Air-quality datasets
    "beijing_pm25",
    "AirQualityUCI",
    "beijing_multisite",
]

# ------------------------------------------------------------------------------
# /data/read_new_dataset.py:
# ------------------------------------------------------------------------------
# Copyright 2024 Arjun Ashok
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import gzip, json
from gluonts.dataset.common import ListDataset, TrainDatasets, MetaData
from pathlib import Path
from gluonts.dataset.repository.datasets import get_dataset
import os


def create_train_dataset_without_last_k_timesteps(raw_train_dataset, freq, k=0):
    """Return a ``ListDataset`` whose series drop the last ``k`` timesteps.

    Parameters
    ----------
    raw_train_dataset
        Iterable of GluonTS-style entries (dicts with at least a "target").
    freq
        Pandas frequency string for the resulting ``ListDataset``.
    k
        Number of trailing timesteps to remove from every series
        (``k=0`` keeps each series unchanged).
    """
    train_data = []
    for series in raw_train_dataset:
        s_train = series.copy()
        n = len(s_train["target"])
        # Slice with an explicit end index so k=0 keeps the full series
        # (a plain ``[:-k]`` would return an empty array for k == 0).
        s_train["target"] = s_train["target"][: n - k]
        train_data.append(s_train)
    return ListDataset(train_data, freq=freq)


def load_jsonl_gzip_file(file_path):
    """Read a gzipped JSON-lines file and return the decoded records."""
    with gzip.open(file_path, "rt") as f:
        return [json.loads(line) for line in f]


def get_ett_dataset(dataset_name, path):
    """Load one ETT dataset from disk as a GluonTS ``TrainDatasets``.

    Expects ``<path>/<dataset_name>/metadata.json`` and
    ``<path>/<dataset_name>/test/data.json.gz``. The train split is the
    test split with the last 24 timesteps held out.
    """
    dataset_path = Path(path) / dataset_name
    metadata_path = dataset_path / "metadata.json"
    with open(metadata_path, "r") as f:
        metadata_dict = json.load(f)
    metadata = MetaData(**metadata_dict)
    # Only the test file is read; the train split is derived from it below.
    test_data_path = dataset_path / "test" / "data.json.gz"
    test_data = load_jsonl_gzip_file(test_data_path)
    test_ds = ListDataset(test_data, freq=metadata.freq)
    train_ds = create_train_dataset_without_last_k_timesteps(
        test_ds, freq=metadata.freq, k=24
    )
    return TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)


if __name__ == "__main__":
    dataset_name = "ett_h1"

    if dataset_name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"):
        path = "data/datasets/ett_datasets"
        ds = get_ett_dataset(dataset_name, path)
    elif dataset_name in (
        "cpu_limit_minute",
        "cpu_usage_minute",
        "function_delay_minute",
        "instances_minute",
        "memory_limit_minute",
        "memory_usage_minute",
        "platform_delay_minute",
        "requests_minute",
    ):
        path = "data/datasets/huawei/" + dataset_name + ".json"
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        # Drop malformed series whose target values were serialized as strings.
        train_data = [x for x in data["train"] if not isinstance(x["target"][0], str)]
        test_data = [x for x in data["test"] if not isinstance(x["target"][0], str)]
        train_ds = ListDataset(train_data, freq=metadata.freq)
        test_ds = ListDataset(test_data, freq=metadata.freq)
        ds = TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)
    elif dataset_name in ("beijing_pm25", "AirQualityUCI", "beijing_multisite"):
        path = "data/datasets/air_quality/" + dataset_name + ".json"
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        train_test_data = [x for x in data["data"] if not isinstance(x["target"][0], str)]
        full_dataset = ListDataset(train_test_data, freq=metadata.freq)
        # BUG FIX: the train split must be derived from ``full_dataset``.
        # The original passed ``test_ds``, which is never bound on this
        # branch and raised NameError for the air-quality datasets.
        train_ds = create_train_dataset_without_last_k_timesteps(
            full_dataset, freq=metadata.freq, k=24
        )
        ds = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset)

# ------ gluon_utils/gluon_ts_distributions/implicit_quantile_network.py ------
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied.
# See the License for the specific language governing
# permissions and limitations under the License.

from functools import partial
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta, Distribution, constraints

from gluonts.core.component import validated
from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.modules.lambda_layer import LambdaLayer


class QuantileLayer(nn.Module):
    r"""
    Implicit Quantile Layer from the paper ``IQN for Distributional
    Reinforcement Learning`` (https://arxiv.org/abs/1806.06923) by
    Dabney et al. 2018.

    Embeds a quantile level ``tau`` through a fixed cosine basis followed
    by a small two-layer MLP.
    """

    def __init__(self, num_output: int, cos_embedding_dim: int = 128):
        super().__init__()

        self.output_layer = nn.Sequential(
            nn.Linear(cos_embedding_dim, cos_embedding_dim),
            nn.PReLU(),
            nn.Linear(cos_embedding_dim, num_output),
        )

        # Fixed (non-trainable) integer frequencies for the cosine embedding;
        # registered as a buffer so it follows the module across devices.
        self.register_buffer("integers", torch.arange(0, cos_embedding_dim))

    def forward(self, tau: torch.Tensor) -> torch.Tensor:  # tau: [B, T]
        # cos(tau * k * pi) for k = 0 .. cos_embedding_dim - 1.
        angles = tau.unsqueeze(-1) * self.integers * torch.pi
        return self.output_layer(torch.cos(angles))


class ImplicitQuantileModule(nn.Module):
    r"""
    Implicit Quantile Network from the paper ``IQN for Distributional
    Reinforcement Learning`` (https://arxiv.org/abs/1806.06923) by
    Dabney et al. 2018.
    """

    def __init__(
        self,
        in_features: int,
        args_dim: Dict[str, int],
        domain_map: Callable[..., Tuple[torch.Tensor]],
        concentration1: float = 1.0,
        concentration0: float = 1.0,
        output_domain_map=None,
        cos_embedding_dim: int = 64,
    ):
        super().__init__()
        self.output_domain_map = output_domain_map
        self.domain_map = domain_map
        # Quantile levels are drawn from this Beta at training time.
        self.beta = Beta(concentration1=concentration1, concentration0=concentration0)

        self.quantile_layer = QuantileLayer(
            in_features, cos_embedding_dim=cos_embedding_dim
        )
        self.output_layer = nn.Sequential(
            nn.Linear(in_features, in_features), nn.PReLU()
        )

        # One projection head per distribution argument.
        self.proj = nn.ModuleList(
            [nn.Linear(in_features, dim) for dim in args_dim.values()]
        )

    def forward(self, inputs: torch.Tensor):
        # Training: taus ~ Beta(c1, c0); evaluation: taus ~ Uniform(0, 1).
        if self.training:
            taus = self.beta.sample(sample_shape=inputs.shape[:-1]).to(inputs.device)
        else:
            taus = torch.rand(size=inputs.shape[:-1], device=inputs.device)

        # Modulate the features with the embedded quantile levels.
        modulated = inputs * (1.0 + self.quantile_layer(taus))
        hidden = self.output_layer(modulated)

        outputs = [head(hidden).squeeze(-1) for head in self.proj]
        if self.output_domain_map is not None:
            outputs = [self.output_domain_map(out) for out in outputs]
        return (*self.domain_map(*outputs), taus)


class ImplicitQuantileNetwork(Distribution):
    r"""
    Distribution class for the Implicit Quantile from which
    we can sample or calculate the quantile loss.

    Parameters
    ----------
    outputs
        Outputs from the Implicit Quantile Network.
    taus
        Tensor random numbers from the Beta or Uniform distribution for the
        corresponding outputs.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, outputs: torch.Tensor, taus: torch.Tensor, validate_args=None):
        self.taus = taus
        self.outputs = outputs

        super().__init__(batch_shape=outputs.shape, validate_args=validate_args)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        # The network outputs already are quantile draws for the sampled taus.
        return self.outputs

    def quantile_loss(self, value: torch.Tensor) -> torch.Tensor:
        # Pinball loss: under-predictions are weighted by tau,
        # over-predictions by (1 - tau).
        under = (value < self.outputs).float()
        return (self.taus - under) * (value - self.outputs)


class ImplicitQuantileNetworkOutput(DistributionOutput):
    r"""
    DistributionOutput class for the IQN from the paper
    ``Probabilistic Time Series Forecasting with Implicit Quantile Networks``
    (https://arxiv.org/abs/2107.03743) by Gouttes et al. 2021.

    Parameters
    ----------
    output_domain
        Optional domain mapping of the output. Can be "positive", "unit"
        or None.
    concentration1
        Alpha parameter of the Beta distribution when sampling the taus
        during training.
    concentration0
        Beta parameter of the Beta distribution when sampling the taus
        during training.
    cos_embedding_dim
        The embedding dimension for the taus embedding layer of IQN.
        Default is 64.
    """

    distr_cls = ImplicitQuantileNetwork
    args_dim = {"quantile_function": 1}

    @validated()
    def __init__(
        self,
        output_domain: Optional[str] = None,
        concentration1: float = 1.0,
        concentration0: float = 1.0,
        cos_embedding_dim: int = 64,
    ) -> None:
        super().__init__()

        self.concentration1 = concentration1
        self.concentration0 = concentration0
        self.cos_embedding_dim = cos_embedding_dim

        # Constrain the sampled quantile values to the requested domain.
        if output_domain == "positive":
            self.output_domain_map = F.softplus
        elif output_domain == "unit":
            self.output_domain_map = partial(F.softmax, dim=-1)
        else:
            self.output_domain_map = None

    def get_args_proj(self, in_features: int) -> nn.Module:
        return ImplicitQuantileModule(
            in_features=in_features,
            args_dim=self.args_dim,
            output_domain_map=self.output_domain_map,
            domain_map=LambdaLayer(self.domain_map),
            concentration1=self.concentration1,
            concentration0=self.concentration0,
            cos_embedding_dim=self.cos_embedding_dim,
        )

    @classmethod
    def domain_map(cls, *args):
        # Identity: the quantile function output needs no reparametrization.
        return args

    def distribution(self, distr_args, loc=0, scale=None) -> ImplicitQuantileNetwork:
        (outputs, taus) = distr_args

        # Undo the input scaling on the quantile outputs.
        if scale is not None:
            outputs = outputs * scale
        if loc is not None:
            outputs = outputs + loc
        return self.distr_cls(outputs=outputs, taus=taus)

    @property
    def event_shape(self):
        return ()

    def loss(
        self,
        target: torch.Tensor,
        distr_args: Tuple[torch.Tensor, ...],
        loc: Optional[torch.Tensor] = None,
        scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        distribution = self.distribution(distr_args, loc=loc, scale=scale)
        return distribution.quantile_loss(target)


iqn = ImplicitQuantileNetworkOutput()
ImplicitQuantileNetworkOutput() 219 | -------------------------------------------------------------------------------- /gluon_utils/scalers/robust_scaler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import torch 18 | from gluonts.core.component import validated 19 | from gluonts.torch.scaler import Scaler 20 | 21 | 22 | class RobustScaler(Scaler): 23 | """ 24 | Computes a scaling factor by removing the median and scaling by the 25 | interquartile range (IQR). 26 | 27 | Parameters 28 | ---------- 29 | dim 30 | dimension along which to compute the scale 31 | keepdim 32 | controls whether to retain dimension ``dim`` (of length 1) in the 33 | scale tensor, or suppress it. 34 | minimum_scale 35 | minimum possible scale that is used for any item. 
36 | """ 37 | 38 | @validated() 39 | def __init__( 40 | self, 41 | dim: int = -1, 42 | keepdim: bool = False, 43 | minimum_scale: float = 1e-10, 44 | ) -> None: 45 | self.dim = dim 46 | self.keepdim = keepdim 47 | self.minimum_scale = minimum_scale 48 | 49 | def __call__( 50 | self, data: torch.Tensor, weights: torch.Tensor 51 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 52 | assert ( 53 | data.shape == weights.shape 54 | ), "data and observed_indicator must have same shape" 55 | 56 | with torch.no_grad(): 57 | observed_data = torch.where(weights == 1, data, torch.nan) 58 | 59 | med = torch.nanmedian(observed_data, dim=self.dim, keepdim=True).values 60 | q1 = torch.nanquantile(observed_data, 0.25, dim=self.dim, keepdim=True) 61 | q3 = torch.nanquantile(observed_data, 0.75, dim=self.dim, keepdim=True) 62 | iqr = q3 - q1 63 | 64 | # if observed data is all zeros, nanmedian returns nan 65 | loc = torch.where(torch.isnan(med), torch.zeros_like(med), med) 66 | scale = torch.where(torch.isnan(iqr), torch.ones_like(iqr), iqr) 67 | scale = torch.maximum(scale, torch.full_like(iqr, self.minimum_scale)) 68 | 69 | scaled_data = (data - loc) / scale 70 | 71 | if not self.keepdim: 72 | loc = torch.squeeze(loc, dim=self.dim) 73 | scale = torch.squeeze(scale, dim=self.dim) 74 | 75 | # assert no nans in scaled data, loc or scale 76 | assert not torch.any(torch.isnan(scaled_data)) 77 | assert not torch.any(torch.isnan(loc)) 78 | assert not torch.any(torch.isnan(scale)) 79 | assert not torch.any(scale == 0) 80 | 81 | return scaled_data, loc, scale 82 | -------------------------------------------------------------------------------- /images/lagllama.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/time-series-foundation-models/lag-llama/013ff0c786c6fee677a360b76df95685c9dac25d/images/lagllama.webp -------------------------------------------------------------------------------- 
/lag_llama/gluon/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. -------------------------------------------------------------------------------- /lag_llama/gluon/estimator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Any, Dict, Iterable, Optional 16 | 17 | import pytorch_lightning as pl 18 | import torch 19 | 20 | from gluonts.core.component import validated 21 | from gluonts.dataset.common import Dataset 22 | from gluonts.dataset.field_names import FieldName 23 | from gluonts.dataset.loader import as_stacked_batches 24 | from gluonts.dataset.stat import calculate_dataset_statistics 25 | from gluonts.itertools import Cyclic 26 | from gluonts.time_feature import ( 27 | get_lags_for_frequency, 28 | time_features_from_frequency_str, 29 | ) 30 | from gluonts.torch.distributions import StudentTOutput, NegativeBinomialOutput 31 | from gluonts.torch.model.estimator import PyTorchLightningEstimator 32 | from gluonts.torch.model.predictor import PyTorchPredictor 33 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood 34 | from gluonts.transform import ( 35 | AddObservedValuesIndicator, 36 | AddTimeFeatures, 37 | Chain, 38 | DummyValueImputation, 39 | ExpectedNumInstanceSampler, 40 | InstanceSampler, 41 | InstanceSplitter, 42 | TestSplitSampler, 43 | Transformation, 44 | ValidationSplitSampler, 45 | ) 46 | 47 | from gluon_utils.gluon_ts_distributions.implicit_quantile_network import ( 48 | ImplicitQuantileNetworkOutput, 49 | ) 50 | from lag_llama.gluon.lightning_module import LagLlamaLightningModule 51 | 52 | PREDICTION_INPUT_NAMES = [ 53 | "past_target", 54 | "past_observed_values", 55 | ] 56 | TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ 57 | "future_target", 58 | "future_observed_values", 59 | ] 60 | 61 | 62 | class LagLlamaEstimator(PyTorchLightningEstimator): 63 | """ 64 | An estimator training a ConvTSMixer model for forecasting. 65 | 66 | This class is uses the model defined in ``ConvTSMixerModel``, 67 | and wraps it into a ``ConvTSMixerLightningModule`` for training 68 | purposes: training is performed using PyTorch Lightning's ``pl.Trainer`` 69 | class. 
70 | 71 | Parameters 72 | ---------- 73 | prediction_length 74 | Length of the prediction horizon. 75 | context_length 76 | Number of time steps prior to prediction time that the model 77 | takes as inputs (default: ``10 * prediction_length``). 78 | lr 79 | Learning rate (default: ``1e-3``). 80 | weight_decay 81 | Weight decay regularization parameter (default: ``1e-8``). 82 | distr_output 83 | Distribution to use to evaluate observations and sample predictions 84 | (default: StudentTOutput()). 85 | loss 86 | Loss to be optimized during training 87 | (default: ``NegativeLogLikelihood()``). 88 | batch_norm 89 | Whether to apply batch normalization. 90 | batch_size 91 | The size of the batches to be used for training (default: 32). 92 | num_batches_per_epoch 93 | Number of batches to be processed in each training epoch 94 | (default: 50). 95 | trainer_kwargs 96 | Additional arguments to provide to ``pl.Trainer`` for construction. 97 | train_sampler 98 | Controls the sampling of windows during training. 99 | validation_sampler 100 | Controls the sampling of windows during validation. 101 | use_single_pass_sampling 102 | If True, use a single forward pass and sample N times from the saved distribution, much more efficient. 103 | If False, perform N forward passes and maintain N parallel prediction paths, this is true probalistic forecasting. 
104 | (default: False) 105 | """ 106 | 107 | @validated() 108 | def __init__( 109 | self, 110 | prediction_length: int, 111 | context_length: Optional[int] = None, 112 | input_size: int = 1, 113 | n_layer: int = 1, 114 | n_embd_per_head: int = 32, 115 | n_head: int = 4, 116 | max_context_length: int = 2048, 117 | rope_scaling=None, 118 | scaling: Optional[str] = "mean", 119 | lr: float = 1e-3, 120 | weight_decay: float = 1e-8, 121 | # Augmentations arguments 122 | aug_prob: float = 0.1, 123 | freq_mask_rate: float = 0.1, 124 | freq_mixing_rate: float = 0.1, 125 | jitter_prob: float = 0.0, 126 | jitter_sigma: float = 0.03, 127 | scaling_prob: float = 0.0, 128 | scaling_sigma: float = 0.1, 129 | rotation_prob: float = 0.0, 130 | permutation_prob: float = 0.0, 131 | permutation_max_segments: int = 5, 132 | permutation_seg_mode: str = "equal", 133 | magnitude_warp_prob: float = 0.0, 134 | magnitude_warp_sigma: float = 0.2, 135 | magnitude_warp_knot: int = 4, 136 | time_warp_prob: float = 0.0, 137 | time_warp_sigma: float = 0.2, 138 | time_warp_knot: int = 4, 139 | window_slice_prob: float = 0.0, 140 | window_slice_reduce_ratio: float = 0.9, 141 | window_warp_prob: float = 0.0, 142 | window_warp_window_ratio: float = 0.1, 143 | window_warp_scales: list = [0.5, 2.0], 144 | # Continuning model arguments 145 | distr_output: str = "studentT", 146 | loss: DistributionLoss = NegativeLogLikelihood(), 147 | num_parallel_samples: int = 100, 148 | batch_size: int = 32, 149 | num_batches_per_epoch: int = 50, 150 | trainer_kwargs: Optional[Dict[str, Any]] = None, 151 | train_sampler: Optional[InstanceSampler] = None, 152 | validation_sampler: Optional[InstanceSampler] = None, 153 | time_feat: bool = False, 154 | dropout: float = 0.0, 155 | lags_seq: list = ["Q", "M", "W", "D", "H", "T", "S"], 156 | data_id_to_name_map: dict = {}, 157 | use_cosine_annealing_lr: bool = False, 158 | cosine_annealing_lr_args: dict = {}, 159 | track_loss_per_series: bool = False, 160 | ckpt_path: 
Optional[str] = None, 161 | nonnegative_pred_samples: bool = False, 162 | use_single_pass_sampling: bool = False, 163 | device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 164 | ) -> None: 165 | default_trainer_kwargs = {"max_epochs": 100} 166 | if trainer_kwargs is not None: 167 | default_trainer_kwargs.update(trainer_kwargs) 168 | super().__init__(trainer_kwargs=default_trainer_kwargs) 169 | 170 | self.scaling = scaling 171 | self.input_size = input_size 172 | self.prediction_length = prediction_length 173 | self.context_length = context_length 174 | self.max_context_length = max_context_length 175 | 176 | lag_indices = [] 177 | for freq in lags_seq: 178 | lag_indices.extend( 179 | get_lags_for_frequency(freq_str=freq, num_default_lags=1) 180 | ) 181 | 182 | if len(lag_indices): 183 | self.lags_seq = sorted(set(lag_indices)) 184 | self.lags_seq = [lag_index - 1 for lag_index in self.lags_seq] 185 | else: 186 | self.lags_seq = [] 187 | 188 | self.n_head = n_head 189 | self.n_layer = n_layer 190 | self.n_embd_per_head = n_embd_per_head 191 | self.rope_scaling = rope_scaling 192 | 193 | self.lr = lr 194 | self.weight_decay = weight_decay 195 | if distr_output == "studentT": 196 | distr_output = StudentTOutput() 197 | elif distr_output == "neg_bin": 198 | distr_output = NegativeBinomialOutput() 199 | elif distr_output == "iqn": 200 | distr_output = ImplicitQuantileNetworkOutput() 201 | self.distr_output = distr_output 202 | self.num_parallel_samples = num_parallel_samples 203 | self.loss = loss 204 | self.batch_size = batch_size 205 | self.num_batches_per_epoch = num_batches_per_epoch 206 | self.nonnegative_pred_samples = nonnegative_pred_samples 207 | self.use_single_pass_sampling = use_single_pass_sampling 208 | 209 | self.train_sampler = train_sampler or ExpectedNumInstanceSampler( 210 | num_instances=1.0, 211 | min_future=prediction_length, 212 | min_instances=1, 213 | ) 214 | self.validation_sampler = 
validation_sampler or ValidationSplitSampler( 215 | min_future=prediction_length 216 | ) 217 | 218 | self.aug_prob = aug_prob 219 | self.freq_mask_rate = freq_mask_rate 220 | self.freq_mixing_rate = freq_mixing_rate 221 | self.jitter_prob = jitter_prob 222 | self.jitter_sigma = jitter_sigma 223 | self.scaling_prob = scaling_prob 224 | self.scaling_sigma = scaling_sigma 225 | self.rotation_prob = rotation_prob 226 | self.permutation_prob = permutation_prob 227 | self.permutation_max_segments = permutation_max_segments 228 | self.permutation_seg_mode = permutation_seg_mode 229 | self.magnitude_warp_prob = magnitude_warp_prob 230 | self.magnitude_warp_sigma = magnitude_warp_sigma 231 | self.magnitude_warp_knot = magnitude_warp_knot 232 | self.time_warp_prob = time_warp_prob 233 | self.time_warp_sigma = time_warp_sigma 234 | self.time_warp_knot = time_warp_knot 235 | self.window_slice_prob = window_slice_prob 236 | self.window_slice_reduce_ratio = window_slice_reduce_ratio 237 | self.window_warp_prob = window_warp_prob 238 | self.window_warp_window_ratio = window_warp_window_ratio 239 | self.window_warp_scales = window_warp_scales 240 | self.track_loss_per_series = track_loss_per_series 241 | 242 | self.time_feat = time_feat 243 | self.dropout = dropout 244 | self.data_id_to_name_map = data_id_to_name_map 245 | self.ckpt_path = ckpt_path 246 | 247 | self.use_cosine_annealing_lr = use_cosine_annealing_lr 248 | self.cosine_annealing_lr_args = cosine_annealing_lr_args 249 | self.device = device 250 | 251 | @classmethod 252 | def derive_auto_fields(cls, train_iter): 253 | stats = calculate_dataset_statistics(train_iter) 254 | 255 | return { 256 | "num_feat_dynamic_real": stats.num_feat_dynamic_real, 257 | "num_feat_static_cat": len(stats.feat_static_cat), 258 | "cardinality": [len(cats) for cats in stats.feat_static_cat], 259 | } 260 | 261 | def create_transformation(self) -> Transformation: 262 | if self.time_feat: 263 | return Chain( 264 | [ 265 | AddTimeFeatures( 266 | 
start_field=FieldName.START, 267 | target_field=FieldName.TARGET, 268 | output_field=FieldName.FEAT_TIME, 269 | time_features=time_features_from_frequency_str("S"), 270 | pred_length=self.prediction_length, 271 | ), 272 | AddObservedValuesIndicator( 273 | target_field=FieldName.TARGET, 274 | output_field=FieldName.OBSERVED_VALUES, 275 | imputation_method=DummyValueImputation(0.0), 276 | ), 277 | ] 278 | ) 279 | else: 280 | return Chain( 281 | [ 282 | AddObservedValuesIndicator( 283 | target_field=FieldName.TARGET, 284 | output_field=FieldName.OBSERVED_VALUES, 285 | imputation_method=DummyValueImputation(0.0), 286 | ), 287 | ] 288 | ) 289 | 290 | def create_lightning_module(self, use_kv_cache: bool = False) -> pl.LightningModule: 291 | model_kwargs = { 292 | "input_size": self.input_size, 293 | "context_length": self.context_length, 294 | "max_context_length": self.max_context_length, 295 | "lags_seq": self.lags_seq, 296 | "n_layer": self.n_layer, 297 | "n_embd_per_head": self.n_embd_per_head, 298 | "n_head": self.n_head, 299 | "scaling": self.scaling, 300 | "distr_output": self.distr_output, 301 | "num_parallel_samples": self.num_parallel_samples, 302 | "rope_scaling": self.rope_scaling, 303 | "time_feat": self.time_feat, 304 | "dropout": self.dropout, 305 | } 306 | if self.ckpt_path is not None: 307 | return LagLlamaLightningModule.load_from_checkpoint( 308 | checkpoint_path=self.ckpt_path, 309 | map_location=self.device, 310 | strict=False, 311 | loss=self.loss, 312 | lr=self.lr, 313 | weight_decay=self.weight_decay, 314 | context_length=self.context_length, 315 | prediction_length=self.prediction_length, 316 | model_kwargs=model_kwargs, 317 | # Augmentations 318 | aug_prob=self.aug_prob, 319 | freq_mask_rate=self.freq_mask_rate, 320 | freq_mixing_rate=self.freq_mixing_rate, 321 | jitter_prob=self.jitter_prob, 322 | jitter_sigma=self.jitter_sigma, 323 | scaling_prob=self.scaling_prob, 324 | scaling_sigma=self.scaling_sigma, 325 | rotation_prob=self.rotation_prob, 
326 | permutation_prob=self.permutation_prob, 327 | permutation_max_segments=self.permutation_max_segments, 328 | permutation_seg_mode=self.permutation_seg_mode, 329 | magnitude_warp_prob=self.magnitude_warp_prob, 330 | magnitude_warp_sigma=self.magnitude_warp_sigma, 331 | magnitude_warp_knot=self.magnitude_warp_knot, 332 | time_warp_prob=self.time_warp_prob, 333 | time_warp_sigma=self.time_warp_sigma, 334 | time_warp_knot=self.time_warp_knot, 335 | window_slice_prob=self.window_slice_prob, 336 | window_slice_reduce_ratio=self.window_slice_reduce_ratio, 337 | window_warp_prob=self.window_warp_prob, 338 | window_warp_window_ratio=self.window_warp_window_ratio, 339 | window_warp_scales=self.window_warp_scales, 340 | use_kv_cache=use_kv_cache, 341 | data_id_to_name_map=self.data_id_to_name_map, 342 | use_cosine_annealing_lr=self.use_cosine_annealing_lr, 343 | cosine_annealing_lr_args=self.cosine_annealing_lr_args, 344 | track_loss_per_series=self.track_loss_per_series, 345 | nonnegative_pred_samples=self.nonnegative_pred_samples, 346 | ) 347 | else: 348 | return LagLlamaLightningModule( 349 | loss=self.loss, 350 | lr=self.lr, 351 | weight_decay=self.weight_decay, 352 | context_length=self.context_length, 353 | prediction_length=self.prediction_length, 354 | model_kwargs=model_kwargs, 355 | # Augmentations 356 | aug_prob=self.aug_prob, 357 | freq_mask_rate=self.freq_mask_rate, 358 | freq_mixing_rate=self.freq_mixing_rate, 359 | jitter_prob=self.jitter_prob, 360 | jitter_sigma=self.jitter_sigma, 361 | scaling_prob=self.scaling_prob, 362 | scaling_sigma=self.scaling_sigma, 363 | rotation_prob=self.rotation_prob, 364 | permutation_prob=self.permutation_prob, 365 | permutation_max_segments=self.permutation_max_segments, 366 | permutation_seg_mode=self.permutation_seg_mode, 367 | magnitude_warp_prob=self.magnitude_warp_prob, 368 | magnitude_warp_sigma=self.magnitude_warp_sigma, 369 | magnitude_warp_knot=self.magnitude_warp_knot, 370 | time_warp_prob=self.time_warp_prob, 371 
| time_warp_sigma=self.time_warp_sigma, 372 | time_warp_knot=self.time_warp_knot, 373 | window_slice_prob=self.window_slice_prob, 374 | window_slice_reduce_ratio=self.window_slice_reduce_ratio, 375 | window_warp_prob=self.window_warp_prob, 376 | window_warp_window_ratio=self.window_warp_window_ratio, 377 | window_warp_scales=self.window_warp_scales, 378 | use_kv_cache=use_kv_cache, 379 | data_id_to_name_map=self.data_id_to_name_map, 380 | use_cosine_annealing_lr=self.use_cosine_annealing_lr, 381 | cosine_annealing_lr_args=self.cosine_annealing_lr_args, 382 | track_loss_per_series=self.track_loss_per_series, 383 | nonnegative_pred_samples=self.nonnegative_pred_samples, 384 | ) 385 | 386 | def _create_instance_splitter(self, module: LagLlamaLightningModule, mode: str): 387 | assert mode in ["training", "validation", "test"] 388 | 389 | instance_sampler = { 390 | "training": self.train_sampler, 391 | "validation": self.validation_sampler, 392 | "test": TestSplitSampler(), 393 | }[mode] 394 | 395 | return InstanceSplitter( 396 | target_field=FieldName.TARGET, 397 | is_pad_field=FieldName.IS_PAD, 398 | start_field=FieldName.START, 399 | forecast_start_field=FieldName.FORECAST_START, 400 | instance_sampler=instance_sampler, 401 | past_length=self.context_length + max(self.lags_seq), 402 | future_length=self.prediction_length, 403 | time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES] 404 | if self.time_feat 405 | else [FieldName.OBSERVED_VALUES], 406 | dummy_value=self.distr_output.value_in_support, 407 | ) 408 | 409 | def create_training_data_loader( 410 | self, 411 | data: Dataset, 412 | module: LagLlamaLightningModule, 413 | shuffle_buffer_length: Optional[int] = None, 414 | **kwargs, 415 | ) -> Iterable: 416 | data = Cyclic(data).stream() 417 | instances = self._create_instance_splitter(module, "training").apply( 418 | data, is_train=True 419 | ) 420 | if self.time_feat: 421 | return as_stacked_batches( 422 | instances, 423 | 
batch_size=self.batch_size, 424 | shuffle_buffer_length=shuffle_buffer_length, 425 | field_names=TRAINING_INPUT_NAMES 426 | + ["past_time_feat", "future_time_feat"], 427 | output_type=torch.tensor, 428 | num_batches_per_epoch=self.num_batches_per_epoch, 429 | ) 430 | 431 | else: 432 | return as_stacked_batches( 433 | instances, 434 | batch_size=self.batch_size, 435 | shuffle_buffer_length=shuffle_buffer_length, 436 | field_names=TRAINING_INPUT_NAMES, 437 | output_type=torch.tensor, 438 | num_batches_per_epoch=self.num_batches_per_epoch, 439 | ) 440 | 441 | def create_validation_data_loader( 442 | self, 443 | data: Dataset, 444 | module: LagLlamaLightningModule, 445 | **kwargs, 446 | ) -> Iterable: 447 | instances = self._create_instance_splitter(module, "validation").apply( 448 | data, is_train=True 449 | ) 450 | if self.time_feat: 451 | return as_stacked_batches( 452 | instances, 453 | batch_size=self.batch_size, 454 | field_names=TRAINING_INPUT_NAMES 455 | + ["past_time_feat", "future_time_feat"], 456 | output_type=torch.tensor, 457 | ) 458 | else: 459 | return as_stacked_batches( 460 | instances, 461 | batch_size=self.batch_size, 462 | field_names=TRAINING_INPUT_NAMES, 463 | output_type=torch.tensor, 464 | ) 465 | 466 | def create_predictor( 467 | self, 468 | transformation: Transformation, 469 | module, 470 | ) -> PyTorchPredictor: 471 | prediction_splitter = self._create_instance_splitter(module, "test") 472 | if self.time_feat: 473 | return PyTorchPredictor( 474 | input_transform=transformation + prediction_splitter, 475 | input_names=PREDICTION_INPUT_NAMES 476 | + ["past_time_feat", "future_time_feat"], 477 | prediction_net=module, 478 | batch_size=self.batch_size, 479 | prediction_length=self.prediction_length, 480 | device="cuda" if torch.cuda.is_available() else "cpu", 481 | ) 482 | else: 483 | return PyTorchPredictor( 484 | input_transform=transformation + prediction_splitter, 485 | input_names=PREDICTION_INPUT_NAMES, 486 | prediction_net=module, 487 | 
                batch_size=self.batch_size,
                prediction_length=self.prediction_length,
                device="cuda" if torch.cuda.is_available() else "cpu",
            )
--------------------------------------------------------------------------------
/lag_llama/gluon/lightning_module.py:
--------------------------------------------------------------------------------
# Copyright 2024 Arjun Ashok
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np

from lightning import LightningModule
import torch
import torch.nn.functional as F

from gluonts.core.component import validated
from gluonts.itertools import prod
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.torch.util import repeat_along_dim, take_last

from data.augmentations.freq_mask import freq_mask
from data.augmentations.freq_mix import freq_mix
from data.augmentations.augmentations import (
    ApplyAugmentations,
    Jitter,
    MagnitudeWarp,
    Permutation,
    Rotation,
    Scaling,
    TimeWarp,
    WindowSlice,
    WindowWarp,
)
from gluon_utils.gluon_ts_distributions.implicit_quantile_network import (
    ImplicitQuantileNetworkOutput,
)
from lag_llama.model.module import LagLlamaModel


class LagLlamaLightningModule(LightningModule):
    """
    A ``pl.LightningModule`` class that can be used to train a
    ``LagLlamaLightningModule`` with PyTorch Lightning.

    This is a thin layer around a (wrapped) ``LagLlamaLightningModule`` object,
    that exposes the methods to evaluate training and validation loss.

    Parameters
    ----------
    model
        ``LagLlamaLightningModule`` to be trained.
    loss
        Loss function to be used for training,
        default: ``NegativeLogLikelihood()``.
    lr
        Learning rate, default: ``1e-3``.
    weight_decay
        Weight decay regularization parameter, default: ``1e-8``.
    """

    # NOTE(review): several defaults below are mutable objects ({} / [0.5, 2.0]).
    # They appear to be read-only here, and gluonts' @validated() snapshots the
    # call arguments, but confirm nothing ever mutates them in place.
    @validated()
    def __init__(
        self,
        model_kwargs: dict,
        context_length: int,
        prediction_length: int,
        loss: DistributionLoss = NegativeLogLikelihood(),
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        aug_prob: float = 0.1,
        freq_mask_rate: float = 0.1,
        freq_mixing_rate: float = 0.1,
        jitter_prob: float = 0.0,
        jitter_sigma: float = 0.03,
        scaling_prob: float = 0.0,
        scaling_sigma: float = 0.1,
        rotation_prob: float = 0.0,
        permutation_prob: float = 0.0,
        permutation_max_segments: int = 5,
        permutation_seg_mode: str = "equal",
        magnitude_warp_prob: float = 0.0,
        magnitude_warp_sigma: float = 0.2,
        magnitude_warp_knot: int = 4,
        time_warp_prob: float = 0.0,
        time_warp_sigma: float = 0.2,
        time_warp_knot: int = 4,
        window_slice_prob: float = 0.0,
        window_slice_reduce_ratio: float = 0.9,
        window_warp_prob: float = 0.0,
        window_warp_window_ratio: float = 0.1,
        window_warp_scales: list = [0.5, 2.0],
        data_id_to_name_map: dict = {},
        use_cosine_annealing_lr: bool = False,
        cosine_annealing_lr_args: dict = {},
        track_loss_per_series: bool = False,
        nonnegative_pred_samples: bool = False,
        use_kv_cache: bool = True,
        use_single_pass_sampling: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Mirror every hyperparameter onto the instance for convenient access.
        self.context_length = self.hparams.context_length
        self.prediction_length = self.hparams.prediction_length
        self.model = LagLlamaModel(**self.hparams.model_kwargs)
        self.loss = self.hparams.loss
        self.lr = self.hparams.lr
        self.weight_decay = self.hparams.weight_decay
        self.aug_prob = self.hparams.aug_prob
        self.freq_mask_rate = self.hparams.freq_mask_rate
        self.freq_mixing_rate = self.hparams.freq_mixing_rate
        self.jitter_prob = self.hparams.jitter_prob
        self.jitter_sigma = self.hparams.jitter_sigma
        self.scaling_prob = self.hparams.scaling_prob
        self.scaling_sigma = self.hparams.scaling_sigma
        self.rotation_prob = self.hparams.rotation_prob
        self.permutation_prob = self.hparams.permutation_prob
        self.permutation_max_segments = self.hparams.permutation_max_segments
        self.permutation_seg_mode = self.hparams.permutation_seg_mode
        self.magnitude_warp_prob = self.hparams.magnitude_warp_prob
        self.magnitude_warp_sigma = self.hparams.magnitude_warp_sigma
        self.magnitude_warp_knot = self.hparams.magnitude_warp_knot
        self.time_warp_prob = self.hparams.time_warp_prob
        self.time_warp_sigma = self.hparams.time_warp_sigma
        self.time_warp_knot = self.hparams.time_warp_knot
        self.window_slice_prob = self.hparams.window_slice_prob
        self.window_slice_reduce_ratio = self.hparams.window_slice_reduce_ratio
        self.window_warp_prob = self.hparams.window_warp_prob
        self.window_warp_window_ratio = self.hparams.window_warp_window_ratio
        self.window_warp_scales = self.hparams.window_warp_scales
        self.data_id_to_name_map = self.hparams.data_id_to_name_map
        self.use_cosine_annealing_lr = self.hparams.use_cosine_annealing_lr
        self.cosine_annealing_lr_args = self.hparams.cosine_annealing_lr_args
        self.track_loss_per_series = self.hparams.track_loss_per_series
        self.nonnegative_pred_samples = self.hparams.nonnegative_pred_samples

        self.time_feat =
self.hparams.model_kwargs["time_feat"]
        # data_id based
        self.train_loss_dict = {}
        self.val_loss_dict = {}
        # item_id based - to be used only in single-dataset mode
        self.train_loss_dict_per_series = {}
        self.val_loss_dict_per_series = {}
        self.use_kv_cache = use_kv_cache
        self.use_single_pass_sampling = use_single_pass_sampling
        # Build the list of enabled augmentations; an augmentation is active
        # only when its probability is strictly positive.
        self.transforms = []
        aug_probs = dict(
            Jitter=dict(prob=self.jitter_prob, sigma=self.jitter_sigma),
            Scaling=dict(prob=self.scaling_prob, sigma=self.scaling_sigma),
            Rotation=dict(prob=self.rotation_prob),
            Permutation=dict(
                prob=self.permutation_prob,
                max_segments=self.permutation_max_segments,
                seg_mode=self.permutation_seg_mode,
            ),
            MagnitudeWarp=dict(
                prob=self.magnitude_warp_prob,
                sigma=self.magnitude_warp_sigma,
                knot=self.magnitude_warp_knot,
            ),
            TimeWarp=dict(
                prob=self.time_warp_prob,
                sigma=self.time_warp_sigma,
                knot=self.time_warp_knot,
            ),
            WindowSlice=dict(
                prob=self.window_slice_prob, reduce_ratio=self.window_slice_reduce_ratio
            ),
            WindowWarp=dict(
                prob=self.window_warp_prob,
                window_ratio=self.window_warp_window_ratio,
                warp_slices=self.window_warp_scales,
            ),
        )
        for aug, params in aug_probs.items():
            if params["prob"] > 0:
                if aug == "Jitter":
                    self.transforms.append(Jitter(params["prob"], params["sigma"]))
                elif aug == "Scaling":
                    self.transforms.append(Scaling(params["prob"], params["sigma"]))
                elif aug == "Rotation":
                    self.transforms.append(Rotation(params["prob"]))
                elif aug == "Permutation":
                    self.transforms.append(
                        Permutation(
                            params["prob"], params["max_segments"], params["seg_mode"]
                        )
                    )
                elif aug == "MagnitudeWarp":
                    self.transforms.append(
                        MagnitudeWarp(params["prob"], params["sigma"], params["knot"])
                    )
                elif aug == "TimeWarp":
                    self.transforms.append(
                        TimeWarp(params["prob"], params["sigma"], params["knot"])
                    )
                elif aug == "WindowSlice":
                    self.transforms.append(
                        WindowSlice(params["prob"], params["reduce_ratio"])
                    )
                elif aug == "WindowWarp":
                    self.transforms.append(
                        WindowWarp(
                            params["prob"],
                            params["window_ratio"],
                            params["warp_slices"],
                        )
                    )

        self.augmentations = ApplyAugmentations(self.transforms)

    # greedy prediction
    def forward(self, *args, **kwargs):
        # Autoregressive sampling: one forecast step at a time, feeding the
        # generated value(s) back into the context.
        past_target = kwargs[
            "past_target"
        ]  # (bsz, model.context_length+max(model.lags_seq))
        past_observed_values = kwargs[
            "past_observed_values"
        ]  # (bsz, model.context_length+max(model.lags_seq))
        if self.time_feat:
            past_time_feat = kwargs["past_time_feat"]
            future_time_feat = kwargs["future_time_feat"]

        use_single_pass_sampling = self.use_single_pass_sampling

        future_samples = []

        if use_single_pass_sampling:
            # Single-pass sampling mode: Single forward pass per step, save distribution parameters, sample `num_parallel_samples` times, add mean to context.
            for t in range(self.prediction_length):
                params, loc, scale = self.model(
                    *args,
                    past_time_feat=past_time_feat if self.time_feat else None,
                    future_time_feat=future_time_feat[..., : t + 1, :] if self.time_feat else None,
                    past_target=past_target,
                    past_observed_values=past_observed_values,
                    use_kv_cache=self.use_kv_cache,
                )

                sliced_params = [
                    p[:, -1:] for p in params
                ]  # Take the last timestep predicted. Each tensor is of shape (#bsz, 1)
                # Singular distribution is used for getting the greedy prediction (mean)
                distr = self.model.distr_output.distribution(sliced_params, loc, scale)
                greedy_prediction = distr.mean  # (#bsz, 1)

                repeated_sliced_params = [
                    p[:, -1:].repeat_interleave(
                        self.model.num_parallel_samples, 0
                    ) for p in params
                ]  # Take the last timestep predicted and repeat for number of samples. Each tensor is of shape (#bsz*#parallel_samples, 1)
                repeated_loc = loc.repeat_interleave(
                    self.model.num_parallel_samples, 0
                )
                repeated_scale = scale.repeat_interleave(
                    self.model.num_parallel_samples, 0
                )
                # Repeated distribution is used for getting the parallel samples
                # (distr.sample([self.model.num_parallel_samples]) seems to give terrible results)
                repeated_distr = self.model.distr_output.distribution(repeated_sliced_params, repeated_loc, repeated_scale)
                sample = repeated_distr.sample()  # (#bsz*#parallel_samples, 1)
                if self.nonnegative_pred_samples:
                    sample = F.relu(sample)
                future_samples.append(sample)

                # Only the greedy mean is appended to the context, keeping the
                # context batch size constant across steps.
                past_target = torch.cat((past_target, greedy_prediction), dim=1)
                past_observed_values = torch.cat(
                    (past_observed_values, torch.ones_like(greedy_prediction)), dim=1
                )
        else:
            # Original probabilistic forecasting: Duplicate input, `num_parallel_samples` forward passes per step, sample each distribution once, add samples to context.
278 | repeated_past_target = past_target.repeat_interleave(self.model.num_parallel_samples, 0) 279 | repeated_past_observed_values = past_observed_values.repeat_interleave(self.model.num_parallel_samples, 0) 280 | if self.time_feat: 281 | repeated_past_time_feat = past_time_feat.repeat_interleave(self.model.num_parallel_samples, 0) 282 | repeated_future_time_feat = future_time_feat.repeat_interleave(self.model.num_parallel_samples, 0) 283 | 284 | for t in range(self.prediction_length): 285 | if self.time_feat: 286 | params, loc, scale = self.model( 287 | *args, 288 | past_time_feat=repeated_past_time_feat, 289 | future_time_feat=repeated_future_time_feat[..., : t + 1, :], 290 | past_target=repeated_past_target, 291 | past_observed_values=repeated_past_observed_values, 292 | use_kv_cache=self.use_kv_cache, 293 | ) 294 | else: 295 | params, loc, scale = self.model( 296 | *args, 297 | past_time_feat=None, 298 | future_time_feat=None, 299 | past_target=repeated_past_target, 300 | past_observed_values=repeated_past_observed_values, 301 | use_kv_cache=self.use_kv_cache, 302 | ) 303 | 304 | sliced_params = [p[:, -1:] for p in params] 305 | distr = self.model.distr_output.distribution(sliced_params, loc, scale) 306 | sample = distr.sample() 307 | if self.nonnegative_pred_samples: 308 | sample = F.relu(sample) 309 | future_samples.append(sample) 310 | 311 | repeated_past_target = torch.cat((repeated_past_target, sample), dim=1) 312 | repeated_past_observed_values = torch.cat( 313 | (repeated_past_observed_values, torch.ones_like(sample)), dim=1 314 | ) 315 | 316 | self.model.reset_cache() 317 | 318 | concat_future_samples = torch.cat(future_samples, dim=-1) 319 | return concat_future_samples.reshape( 320 | (-1, self.model.num_parallel_samples, self.prediction_length) 321 | + self.model.distr_output.event_shape, 322 | ) 323 | 324 | 325 | # train 326 | def _compute_loss(self, batch, do_not_average=False, return_observed_values=False): 327 | past_target = batch[ 328 | 
"past_target" 329 | ] # (bsz, model.context_length+max(model.lags_seq)) 330 | past_observed_values = batch[ 331 | "past_observed_values" 332 | ] # (bsz, model.context_length+max(model.lags_seq)) with 0s or 1s indicating available (1s) or missing (0s) 333 | future_target = batch["future_target"] # (bsz, model.prediction_length) 334 | future_observed_values = batch[ 335 | "future_observed_values" 336 | ] # (bsz, model.prediction_length) with 0s or 1s indicating available (1s) or missing (0s) 337 | if self.time_feat: 338 | past_time_feat = batch["past_time_feat"] 339 | future_time_feat = batch["future_time_feat"] 340 | else: 341 | past_time_feat = None 342 | future_time_feat = None 343 | 344 | extra_dims = len(future_target.shape) - len(past_target.shape) # usually 0 345 | extra_shape = future_target.shape[:extra_dims] # shape remains the same 346 | 347 | repeats = prod(extra_shape) # usually 1 348 | past_target = repeat_along_dim( 349 | past_target, 0, repeats 350 | ) # (bsz, model.context_length+max(model.lags_seq)) 351 | past_observed_values = repeat_along_dim( 352 | past_observed_values, 0, repeats 353 | ) # (bsz, model.context_length+max(model.lags_seq)) 354 | 355 | future_target_reshaped = future_target.reshape( 356 | -1, 357 | *future_target.shape[extra_dims + 1 :], 358 | ) # (bsz, model.prediction_length) 359 | future_observed_reshaped = future_observed_values.reshape( 360 | -1, 361 | *future_observed_values.shape[extra_dims + 1 :], 362 | ) # (bsz, model.prediction_length) 363 | 364 | distr_args, loc, scale = self.model( 365 | past_target=past_target, 366 | past_observed_values=past_observed_values, 367 | past_time_feat=past_time_feat, 368 | future_time_feat=future_time_feat, 369 | future_target=future_target_reshaped, 370 | ) # distr_args is a tuple with two tensors of shape (bsz, context_length+pred_len-1) 371 | context_target = take_last( 372 | past_target, dim=-1, num=self.context_length - 1 373 | ) # (bsz, context_length-1) # Basically removes the first 
value since it cannot be predicted 374 | target = torch.cat( 375 | (context_target, future_target_reshaped), 376 | dim=1, 377 | ) # (bsz, context_length-1+pred_len) # values that can be predicted 378 | context_observed = take_last( 379 | past_observed_values, dim=-1, num=self.context_length - 1 380 | ) # same as context_target, but for observed_values tensor 381 | observed_values = torch.cat( 382 | (context_observed, future_observed_reshaped), dim=1 383 | ) # same as target but for observed_values tensor 384 | 385 | if type(self.model.distr_output) == ImplicitQuantileNetworkOutput: 386 | if not do_not_average: 387 | loss = ( 388 | self.model.distr_output.loss(target, distr_args, loc, scale) 389 | * observed_values 390 | ).sum() / observed_values.sum().clamp_min(1.0) 391 | else: 392 | loss = ( 393 | self.model.distr_output.loss(target, distr_args, loc, scale) 394 | * observed_values 395 | ) 396 | else: 397 | distr = self.model.distr_output.distribution( 398 | distr_args, loc=loc, scale=scale 399 | ) # an object representing a distribution with the specified parameters. We need this to compute the NLL loss. 400 | if not do_not_average: 401 | loss = ( 402 | self.loss(distr, target) * observed_values 403 | ).sum() / observed_values.sum().clamp_min(1.0) 404 | else: 405 | loss = self.loss(distr, target) * observed_values 406 | 407 | if not return_observed_values: 408 | return loss 409 | else: 410 | return loss, observed_values 411 | 412 | def training_step(self, batch, batch_idx: int): # type: ignore 413 | """ 414 | Execute training step. 
415 | """ 416 | if random.random() < self.aug_prob: 417 | # Freq mix and Freq mask have separate functions 418 | if self.freq_mask_rate > 0: 419 | batch["past_target"], batch["future_target"] = freq_mask( 420 | batch["past_target"], 421 | batch["future_target"], 422 | rate=self.freq_mask_rate, 423 | ) 424 | if self.freq_mixing_rate: 425 | batch["past_target"], batch["future_target"] = freq_mix( 426 | batch["past_target"], 427 | batch["future_target"], 428 | rate=self.freq_mixing_rate, 429 | ) 430 | # Other augmentation 431 | if len(self.transforms): 432 | batch["past_target"], batch["future_target"] = self.augmentations( 433 | batch["past_target"], batch["future_target"] 434 | ) 435 | 436 | train_loss_per_sample, observed_values = self._compute_loss( 437 | batch, do_not_average=True, return_observed_values=True 438 | ) 439 | 440 | train_loss_avg = train_loss_per_sample.sum() / observed_values.sum().clamp_min( 441 | 1.0 442 | ) 443 | self.log( 444 | "train_loss", train_loss_avg, on_epoch=True, on_step=False, prog_bar=False 445 | ) 446 | return train_loss_avg 447 | 448 | def on_train_epoch_end(self): 449 | # Log all losses 450 | for key, value in self.train_loss_dict.items(): 451 | loss_avg = np.mean(value) 452 | self.log( 453 | f"train_loss_avg_per_train_dataset/{self.data_id_to_name_map[key]}", 454 | loss_avg, 455 | on_epoch=True, 456 | on_step=False, 457 | prog_bar=False, 458 | ) 459 | 460 | if self.track_loss_per_series: 461 | # Log all losses 462 | for key, value in self.train_loss_dict_per_series.items(): 463 | loss_avg = np.mean(value) 464 | self.log( 465 | f"train_loss_avg_per_train_series/{key}", 466 | loss_avg, 467 | on_epoch=True, 468 | on_step=False, 469 | prog_bar=False, 470 | ) 471 | 472 | # Reset loss_dict 473 | self.train_loss_dict = {} 474 | self.train_loss_dict_per_series = {} 475 | 476 | def validation_step(self, batch, batch_idx: int): # type: ignore 477 | """ 478 | Execute validation step. 
479 | """ 480 | val_loss_per_sample, observed_values = self._compute_loss( 481 | batch, do_not_average=True, return_observed_values=True 482 | ) 483 | 484 | val_loss_avg = val_loss_per_sample.sum() / observed_values.sum().clamp_min(1.0) 485 | self.log("val_loss", val_loss_avg, on_epoch=True, on_step=False, prog_bar=False) 486 | return val_loss_avg 487 | 488 | def on_validation_epoch_end(self): 489 | # Log all losses 490 | for key, value in self.val_loss_dict.items(): 491 | loss_avg = np.mean(value) 492 | if key >= 0: 493 | self.log( 494 | f"val_loss_avg_per_train_dataset/{self.data_id_to_name_map[key]}", 495 | loss_avg, 496 | on_epoch=True, 497 | on_step=False, 498 | prog_bar=False, 499 | ) 500 | else: 501 | self.log( 502 | f"val_loss_avg_per_test_dataset/{self.data_id_to_name_map[key]}", 503 | loss_avg, 504 | on_epoch=True, 505 | on_step=False, 506 | prog_bar=False, 507 | ) 508 | 509 | if self.track_loss_per_series: 510 | # Log all losses 511 | for key, value in self.val_loss_dict_per_series.items(): 512 | loss_avg = np.mean(value) 513 | self.log( 514 | f"val_loss_avg_per_train_series/{key}", 515 | loss_avg, 516 | on_epoch=True, 517 | on_step=False, 518 | prog_bar=False, 519 | ) 520 | 521 | # Reset loss_dict 522 | self.val_loss_dict = {} 523 | self.val_loss_dict_per_series = {} 524 | 525 | def configure_optimizers(self): 526 | """ 527 | Returns the optimizer to use. 
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        if self.use_cosine_annealing_lr:
            # NOTE(review): the `verbose` argument of CosineAnnealingLR is
            # deprecated in torch 2.2+ and removed in later releases — confirm
            # the pinned torch version still accepts it.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, **self.cosine_annealing_lr_args, verbose=True
            )
            return {"optimizer": optimizer, "lr_scheduler": scheduler}
        else:
            return optimizer
--------------------------------------------------------------------------------
/lag_llama/model/__init__.py:
--------------------------------------------------------------------------------
# Copyright 2024 Arjun Ashok
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
--------------------------------------------------------------------------------
/lag_llama/model/module.py:
--------------------------------------------------------------------------------
# Copyright 2024 Arjun Ashok
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch import nn
from torch.nn import functional as F

from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.scaler import MeanScaler, NOPScaler, StdScaler
from gluonts.torch.util import lagged_sequence_values, unsqueeze_expand

from gluon_utils.scalers.robust_scaler import RobustScaler


@dataclass
class LTSMConfig:
    """Hyper-parameters of the Lag-Llama transformer backbone."""

    feature_size: int = 3 + 6  # target + loc + scale + time features
    block_size: int = 2048  # maximum sequence length the model supports
    n_layer: int = 32
    n_head: int = 32
    n_embd_per_head: int = 128
    rope_scaling: Optional[dict] = None  # e.g. {"type": "linear", "factor": 2.0}
    dropout: float = 0.0


class Block(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention and RMSNorm -> MLP,
    each wrapped in a residual connection."""

    def __init__(self, config: LTSMConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd_per_head * config.n_head)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd_per_head * config.n_head)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, use_kv_cache: bool) -> torch.Tensor:
        x = x + self.attn(self.rms_1(x), use_kv_cache)
        y = x + self.mlp(self.rms_2(x))
        return y


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Standard RoPE frequencies: inv_freq[i] = 1 / base^(2i/dim).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Precompute cos/sin tables for all positions up to `seq_len`.
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
        )

    def forward(self, device, dtype, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Lazily grow the cached tables when a longer sequence appears.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        # Linear scaling: positions are divided by the scaling factor.
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
        )


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware rebasing: stretch the RoPE base when the sequence
            # exceeds the trained maximum length.
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
184 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 185 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 186 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 187 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 188 | q_embed = (q * cos) + (rotate_half(q) * sin) 189 | k_embed = (k * cos) + (rotate_half(k) * sin) 190 | return q_embed, k_embed 191 | 192 | 193 | class CausalSelfAttention(nn.Module): 194 | def __init__(self, config: LTSMConfig) -> None: 195 | super().__init__() 196 | # query projections for all heads, but in a batch 197 | self.q_proj = nn.Linear( 198 | config.n_embd_per_head * config.n_head, 199 | config.n_embd_per_head * config.n_head, 200 | bias=False, 201 | ) 202 | # key, value projections 203 | self.kv_proj = nn.Linear( 204 | config.n_embd_per_head * config.n_head, 205 | 2 * config.n_embd_per_head * config.n_head, 206 | bias=False, 207 | ) 208 | # output projection 209 | self.c_proj = nn.Linear( 210 | config.n_embd_per_head * config.n_head, 211 | config.n_embd_per_head * config.n_head, 212 | bias=False, 213 | ) 214 | 215 | self.n_head = config.n_head 216 | self.n_embd_per_head = config.n_embd_per_head 217 | self.block_size = config.block_size 218 | self.dropout = config.dropout 219 | 220 | self.rope_scaling = config.rope_scaling 221 | self._rope_scaling_validation() 222 | 223 | self._init_rope() 224 | self.kv_cache = None 225 | 226 | def _init_rope(self): 227 | if self.rope_scaling is None: 228 | self.rotary_emb = LlamaRotaryEmbedding( 229 | self.n_embd_per_head, max_position_embeddings=self.block_size 230 | ) 231 | else: 232 | scaling_type = self.rope_scaling["type"] 233 | scaling_factor = self.rope_scaling["factor"] 234 | if scaling_type == "nope": 235 | self.rotary_emb = None 236 | elif scaling_type == "linear": 237 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding( 238 | self.n_embd_per_head, 239 | max_position_embeddings=self.block_size, 240 | scaling_factor=scaling_factor, 241 | ) 242 | elif scaling_type == 
"dynamic": 243 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( 244 | self.n_embd_per_head, 245 | max_position_embeddings=self.block_size, 246 | scaling_factor=scaling_factor, 247 | ) 248 | else: 249 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 250 | 251 | def _rope_scaling_validation(self): 252 | """ 253 | Validate the `rope_scaling` configuration. 254 | """ 255 | if self.rope_scaling is None: 256 | return 257 | 258 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 259 | raise ValueError( 260 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " 261 | f"got {self.rope_scaling}" 262 | ) 263 | rope_scaling_type = self.rope_scaling.get("type", None) 264 | rope_scaling_factor = self.rope_scaling.get("factor", None) 265 | if rope_scaling_type is None or rope_scaling_type not in [ 266 | "linear", 267 | "dynamic", 268 | "nope", 269 | ]: 270 | raise ValueError( 271 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 272 | ) 273 | if rope_scaling_type in ["linear", "dynamic"]: 274 | if ( 275 | rope_scaling_factor is None 276 | or not isinstance(rope_scaling_factor, float) 277 | or rope_scaling_factor < 1.0 278 | ): 279 | raise ValueError( 280 | f"`rope_scaling`'s factor field must be an float >= 1, got {rope_scaling_factor}" 281 | ) 282 | 283 | def forward(self, x: torch.Tensor, use_kv_cache: bool) -> torch.Tensor: 284 | # batch size, sequence length, embedding dimensionality (n_embd) 285 | # sequence length will be 1 when using kv_cache and kv_cache has been initialised 286 | (B, T, C) = x.size() 287 | 288 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 289 | q = self.q_proj(x) 290 | k, v = self.kv_proj(x).split(self.n_embd_per_head * self.n_head, dim=2) 291 | 292 | cache_initialized = self.kv_cache is not None 293 | if use_kv_cache: 294 | # Optimized for single next prediction 295 | if 
cache_initialized: 296 | # Update cache 297 | k = torch.cat([self.kv_cache[0], k], dim=1)[:, 1:] 298 | v = torch.cat([self.kv_cache[1], v], dim=1)[:, 1:] 299 | self.kv_cache = k, v 300 | else: 301 | # Build cache 302 | self.kv_cache = k, v 303 | 304 | k = k.view(B, -1, self.n_head, self.n_embd_per_head).transpose( 305 | 1, 2 306 | ) # (B, nh, T, hs) 307 | q = q.view(B, -1, self.n_head, self.n_embd_per_head).transpose( 308 | 1, 2 309 | ) # (B, nh, T, hs) 310 | v = v.view(B, -1, self.n_head, self.n_embd_per_head).transpose( 311 | 1, 2 312 | ) # (B, nh, T, hs) 313 | 314 | # This the true sequence length after concatenation with kv_cache, 315 | # will be the same as `T` when kv_cache is not in use 316 | true_seq_len = k.size(2) 317 | if self.rotary_emb is not None: 318 | if use_kv_cache and cache_initialized: 319 | # When kv_cache is in use and we're working with only the last token (T = 1 instead of full sequence length `true_seq_len``) 320 | # Use the full sequence length for positional embeddings (true_seq_len) 321 | # q is the query vector for the last token, so it's position is the last index (-1) 322 | cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=true_seq_len) 323 | q, _ = apply_rotary_pos_emb(q, k, cos, sin, position_ids=[-1]) 324 | 325 | # k is the key matrix after concatenation with cache, so no position_ids 326 | cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=true_seq_len) 327 | _, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None) 328 | else: 329 | cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=T) 330 | q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None) 331 | 332 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 333 | # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 334 | # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 335 | # att = F.softmax(att, dim=-1) 336 | # y = att @ v # (B, nh, T, T) x (B, nh, T, 
def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of `k` that is >= `n`."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)


class MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: c_proj(silu(c_fc1(x)) * c_fc2(x))."""

    def __init__(self, config: LTSMConfig) -> None:
        super().__init__()
        embed_dim = config.n_embd_per_head * config.n_head
        # Two-thirds of the canonical 4x expansion, rounded up to a multiple of 256.
        hidden = find_multiple(int(2 * (4 * embed_dim) / 3), 256)

        self.c_fc1 = nn.Linear(embed_dim, hidden, bias=False)
        self.c_fc2 = nn.Linear(embed_dim, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        return self.c_proj(gated)
394 | """ 395 | 396 | def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: 397 | super().__init__() 398 | self.scale = nn.Parameter(torch.ones(size)) 399 | self.eps = eps 400 | self.dim = dim 401 | 402 | def forward(self, x: torch.Tensor) -> torch.Tensor: 403 | # NOTE: the original RMSNorm paper implementation is not equivalent 404 | # norm_x = x.norm(2, dim=self.dim, keepdim=True) 405 | # rms_x = norm_x * d_x ** (-1. / 2) 406 | # x_normed = x / (rms_x + self.eps) 407 | # keep RMSNorm in float32 408 | norm_x = x.to(torch.float32).pow(2).mean(dim=self.dim, keepdim=True) 409 | x_normed = x * torch.rsqrt(norm_x + self.eps) 410 | return (self.scale * x_normed).type_as(x) 411 | 412 | 413 | class LagLlamaModel(nn.Module): 414 | def __init__( 415 | self, 416 | context_length: int, 417 | max_context_length: int, 418 | scaling: str, 419 | input_size: int, 420 | n_layer: int, 421 | n_embd_per_head: int, 422 | n_head: int, 423 | lags_seq: List[int], 424 | distr_output: DistributionOutput, 425 | rope_scaling=None, 426 | num_parallel_samples: int = 100, 427 | time_feat: bool = True, 428 | dropout: float = 0.0, 429 | ) -> None: 430 | super().__init__() 431 | self.context_length = context_length 432 | self.lags_seq = lags_seq 433 | if time_feat: 434 | feature_size = input_size * (len(self.lags_seq)) + 2 * input_size + 6 435 | else: 436 | feature_size = input_size * (len(self.lags_seq)) + 2 * input_size 437 | 438 | config = LTSMConfig( 439 | n_layer=n_layer, 440 | n_embd_per_head=n_embd_per_head, 441 | n_head=n_head, 442 | block_size=max_context_length, 443 | feature_size=feature_size, 444 | rope_scaling=rope_scaling, 445 | dropout=dropout, 446 | ) 447 | self.num_parallel_samples = num_parallel_samples 448 | 449 | if scaling == "mean": 450 | self.scaler = MeanScaler(keepdim=True, dim=1) 451 | elif scaling == "std": 452 | self.scaler = StdScaler(keepdim=True, dim=1) 453 | elif scaling == "robust": 454 | self.scaler = RobustScaler(keepdim=True, dim=1) 455 | else: 
456 | self.scaler = NOPScaler(keepdim=True, dim=1) 457 | 458 | self.distr_output = distr_output 459 | self.param_proj = self.distr_output.get_args_proj( 460 | config.n_embd_per_head * config.n_head 461 | ) 462 | 463 | self.transformer = nn.ModuleDict( 464 | dict( 465 | wte=nn.Linear( 466 | config.feature_size, config.n_embd_per_head * config.n_head 467 | ), 468 | h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 469 | ln_f=RMSNorm(config.n_embd_per_head * config.n_head), 470 | ) 471 | ) 472 | self.y_cache = False # used at time of inference when kv cached is used 473 | 474 | def _init_weights(self, module: nn.Module) -> None: 475 | if isinstance(module, nn.Linear): 476 | torch.nn.init.normal_( 477 | module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer) 478 | ) 479 | elif isinstance(module, nn.Embedding): 480 | torch.nn.init.normal_( 481 | module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer) 482 | ) 483 | 484 | def prepare_input( 485 | self, 486 | past_target: torch.Tensor, 487 | past_observed_values: torch.Tensor, 488 | past_time_feat: Optional[torch.Tensor] = None, 489 | future_time_feat: Optional[torch.Tensor] = None, 490 | future_target: Optional[torch.Tensor] = None, 491 | ): 492 | scaled_past_target, loc, scale = self.scaler( 493 | past_target, past_observed_values 494 | ) # Data is standardized (past_observed_values is passed as "weights" parameter) # (bsz, context_length+max(self.lags_seq) 495 | 496 | # In the below code, instead of max(self.lags_seq), it was previously -self.context_length 497 | if future_target is not None: 498 | input = torch.cat( 499 | ( 500 | scaled_past_target[..., max(self.lags_seq) :], # Just the context 501 | (future_target[..., :-1] - loc) 502 | / scale, # Not sure about the -1 here. Maybe so since the last value isn't used in the model for prediction of any new values. 
also if the prediction length is 1, this doesn't really affect anything 503 | ), 504 | dim=-1, 505 | ) # Shape is (bsz, context_length+(pred_len-1)) 506 | else: 507 | input = scaled_past_target[..., max(self.lags_seq) :] 508 | if (past_time_feat is not None) and (future_time_feat is not None): 509 | time_feat = ( 510 | torch.cat( 511 | ( 512 | past_time_feat[..., max(self.lags_seq) :, :], 513 | future_time_feat[..., :-1, :], 514 | ), 515 | dim=1, 516 | ) 517 | if future_time_feat is not None 518 | else past_time_feat[..., max(self.lags_seq) :, :] 519 | ) 520 | 521 | prior_input = ( 522 | past_target[..., : max(self.lags_seq)] - loc 523 | ) / scale # This the history used to construct lags. # bsz, max(self.lags_seq) 524 | 525 | lags = lagged_sequence_values( 526 | self.lags_seq, prior_input, input, dim=-1 527 | ) # Lags are added as an extra dim. Shape is (bsz, context_length+(pred_len-1), len(self.lags_seq)) 528 | 529 | static_feat = torch.cat( 530 | (loc.abs().log1p(), scale.log()), dim=-1 531 | ) # (bsz, 2) (loc and scale are concatenated) 532 | expanded_static_feat = unsqueeze_expand( 533 | static_feat, dim=-2, size=lags.shape[-2] 534 | ) # (bsz, context_length+(pred_len-1), 2) 535 | # expanded_static_feat: (bsz, context_length+(pred_len-1), len(self.lags_seq) + 2); (bsz, 1); (bsz, 1) 536 | 537 | if past_time_feat is not None: 538 | return ( 539 | torch.cat((lags, expanded_static_feat, time_feat), dim=-1), 540 | loc, 541 | scale, 542 | ) 543 | else: 544 | return torch.cat((lags, expanded_static_feat), dim=-1), loc, scale 545 | 546 | def forward( 547 | self, 548 | past_target: torch.Tensor, 549 | past_observed_values: torch.Tensor, 550 | past_time_feat: Optional[torch.Tensor] = None, 551 | future_time_feat: Optional[torch.Tensor] = None, 552 | future_target: Optional[torch.Tensor] = None, 553 | use_kv_cache: bool = False, 554 | ) -> torch.Tensor: 555 | # if past_time_feat is not None: 556 | transformer_input, loc, scale = self.prepare_input( 557 | 
    def forward(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_time_feat: Optional[torch.Tensor] = None,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
        use_kv_cache: bool = False,
    ) -> torch.Tensor:
        """
        Run the model and return the projected distribution parameters.

        Returns a tuple (params, loc, scale): `params` are the distribution
        arguments produced by `param_proj` for every position, and loc/scale
        are the scaling statistics from `prepare_input`.

        With `use_kv_cache=True`, after the first call (`self.y_cache` set)
        only the most recent input position is fed through the transformer;
        earlier positions are served from the attention kv caches.
        """
        # if past_time_feat is not None:
        transformer_input, loc, scale = self.prepare_input(
            past_target=past_target,
            past_observed_values=past_observed_values,
            future_target=future_target,
            past_time_feat=past_time_feat,
            future_time_feat=future_time_feat,
        )  # return: (bsz, context_length+(pred_len-1), len(self.lags_seq) + 2); (bsz, 1); (bsz, 1)
        # To use kv cache for inference and pass recent token to transformer
        if use_kv_cache and self.y_cache:
            # Only use the most recent one, rest is in cache
            transformer_input = transformer_input[:, -1:]

        # forward the LLaMA model itself
        x = self.transformer.wte(
            transformer_input
        )  # token embeddings of shape (b, t, n_embd_per_head*n_head) # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head)

        for block in self.transformer.h:
            x = block(x, use_kv_cache)
        x = self.transformer.ln_f(
            x
        )  # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head)
        if use_kv_cache:
            # Mark the cache warm so subsequent calls only feed the last token.
            self.y_cache = True
        params = self.param_proj(
            x
        )  # (bsz, context_length+(pred_len-1)) ; (bsz, context_length+(pred_len-1))
        return params, loc, scale

    def reset_cache(self) -> None:
        """
        Resets all cached key-values in attention.
        Has to be called after prediction loop in predictor
        """
        self.y_cache = False
        for block in self.transformer.h:
            block.attn.kv_cache = None
10 | readme = "README.md" 11 | license = {file = "LICENSE"} 12 | authors = [ 13 | {name = "Arjun Ashok", email = "arjun.ashok@servicenow.com"}, 14 | {name = "Kashif Rasul", email = "kashif.rasul@gmail.com"} 15 | ] 16 | keywords = ["llama", "time series", "forecasting", "machine learning", "open-source", "foundation model", "lag-llama"] 17 | 18 | 19 | [project.urls] 20 | "Homepage" = "https://github.com/time-series-foundation-models/lag-llama" 21 | 22 | 23 | [tool.setuptools.dynamic] 24 | dependencies = {file = "requirements.txt"} 25 | version = {file = "VERSION"} 26 | 27 | [tool.setuptools] 28 | package-data = {"data" = ["data/*"], "configs" = ["configs/*"]} 29 | 30 | [tool.setuptools.packages.find] 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gluonts[torch]<=0.14.4 2 | numpy>=1.23.5 3 | torch>=2.0.0 4 | wandb 5 | scipy 6 | pandas 7 | huggingface_hub[cli] 8 | matplotlib 9 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def train(args):
    """Train and/or evaluate a Lag-Llama model as configured by the parsed CLI `args`."""
    # Set seed
    set_seed(args.seed)
    lightning.seed_everything(args.seed)

    # Create a directory to store the results in.
    # This string is made independent of hyperparameters here, as more
    # hyperparameters / arguments may be added later. The name should be created
    # in the calling bash script; when that same script is executed again, model
    # training automatically resumes from a checkpoint if available.
    experiment_name = args.experiment_name
    fulldir_experiments = os.path.join(args.results_dir, experiment_name, str(args.seed))
    if os.path.exists(fulldir_experiments):
        print(fulldir_experiments, "already exists.")
    os.makedirs(fulldir_experiments, exist_ok=True)

    # Create directory for checkpoints
    checkpoint_dir = os.path.join(fulldir_experiments, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Retrieve the checkpoint to resume / finetune from.
    # FIX: initialise `ckpt_path` up-front; previously some branches could read
    # it before assignment, raising NameError instead of the intended error.
    ckpt_path = None
    if args.ckpt_path:
        ckpt_path = args.ckpt_path
    elif args.get_ckpt_path_from_experiment_name:
        fulldir_experiments_for_ckpt_path = os.path.join(
            args.results_dir, args.get_ckpt_path_from_experiment_name, str(args.seed)
        )
        full_experiment_name_original = (
            args.get_ckpt_path_from_experiment_name + "-seed-" + str(args.seed)
        )
        experiment_id_original = sha1(
            full_experiment_name_original.encode("utf-8")
        ).hexdigest()[:8]
        checkpoint_dir_wandb = os.path.join(
            fulldir_experiments_for_ckpt_path,
            "lag-llama",
            experiment_id_original,
            "checkpoints",
        )
        # FIX: os.listdir order is arbitrary and [-1] IndexErrors on an empty
        # dir; sort so the highest-epoch checkpoint is picked deterministically.
        ckpt_files = sorted(os.listdir(checkpoint_dir_wandb))
        if ckpt_files:
            ckpt_path = os.path.join(checkpoint_dir_wandb, ckpt_files[-1])
        if not ckpt_path:
            raise Exception("ckpt_path not found from experiment name")
        # Delete the EarlyStopping callback state and save the checkpoint
        # under the current checkpoint_dir.
        new_ckpt_path = checkpoint_dir + "/pretrained_ckpt.ckpt"
        print("Moving", ckpt_path, "to", new_ckpt_path)
        # NOTE(review): torch.load on a pickled checkpoint executes arbitrary
        # code — only load checkpoints produced by this project.
        ckpt_loaded = torch.load(ckpt_path)
        del ckpt_loaded['callbacks']["EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"]
        ckpt_loaded['callbacks']["ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"]["best_model_path"] = new_ckpt_path
        ckpt_loaded['callbacks']["ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"]["dirpath"] = checkpoint_dir
        del ckpt_loaded['callbacks']["ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"]["last_model_path"]
        torch.save(ckpt_loaded, new_ckpt_path)
        ckpt_path = new_ckpt_path
    else:
        if not args.evaluate_only:
            ckpt_path = checkpoint_dir + "/last.ckpt"
            if not os.path.isfile(ckpt_path):
                ckpt_path = None
        else:
            if args.evaluate_only:
                full_experiment_name_original = experiment_name + "-seed-" + str(args.seed)
                experiment_id_original = sha1(
                    full_experiment_name_original.encode("utf-8")
                ).hexdigest()[:8]
                checkpoint_dir_wandb = os.path.join(
                    fulldir_experiments, "lag-llama", experiment_id_original, "checkpoints"
                )
                ckpt_files = sorted(os.listdir(checkpoint_dir_wandb))
                if ckpt_files:
                    ckpt_path = os.path.join(checkpoint_dir_wandb, ckpt_files[-1])
            elif args.evaluate_only:
                # NOTE(review): this branch repeats the condition above and is
                # unreachable as written — kept for fidelity; confirm intent.
                for file in os.listdir(checkpoint_dir):
                    if "best" in file:
                        ckpt_path = checkpoint_dir + "/" + file
                        break

    if ckpt_path:
        print("Checkpoint", ckpt_path, "retrieved from experiment directory")
    else:
        print("No checkpoints found. Training from scratch.")

    # W&B logging
    # NOTE: Caution when using `full_experiment_name` after this
    if args.eval_prefix and (args.evaluate_only):
        experiment_name = args.eval_prefix + "_" + experiment_name
    full_experiment_name = experiment_name + "-seed-" + str(args.seed)
    experiment_id = sha1(full_experiment_name.encode("utf-8")).hexdigest()[:8]
    logger = WandbLogger(
        name=full_experiment_name,
        save_dir=fulldir_experiments,
        group=experiment_name,
        tags=args.wandb_tags,
        entity=args.wandb_entity,
        project=args.wandb_project,
        allow_val_change=True,
        config=vars(args),
        id=experiment_id,
        mode=args.wandb_mode,
        settings=wandb.Settings(code_dir="."),
    )

    # Callbacks
    swa_callbacks = StochasticWeightAveraging(
        swa_lrs=args.swa_lrs,
        swa_epoch_start=args.swa_epoch_start,
        annealing_epochs=args.annealing_epochs,
        annealing_strategy=args.annealing_strategy,
    )
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=int(args.early_stopping_patience),
        verbose=True,
        mode="min",
    )
    model_checkpointing = ModelCheckpoint(
        dirpath=checkpoint_dir,
        save_last=True,
        save_top_k=1,
        filename="best-{epoch}-{val_loss:.2f}",
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [early_stop_callback, lr_monitor, model_checkpointing]
    if args.swa:
        print("Using SWA")
        callbacks.append(swa_callbacks)

    # Create train and test datasets
    if not args.single_dataset:
        # FIX: copy before .remove() — args.all_datasets may alias the shared
        # ALL_DATASETS default list, which must not be mutated.
        train_dataset_names = list(args.all_datasets)
        for test_dataset in args.test_datasets:
            train_dataset_names.remove(test_dataset)
        print("Training datasets:", train_dataset_names)
        print("Test datasets:", args.test_datasets)
        data_id_to_name_map = {}
        name_to_data_id_map = {}
        for data_id, name in enumerate(train_dataset_names):
            data_id_to_name_map[data_id] = name
            name_to_data_id_map[name] = data_id
        # Test datasets get negative ids so they never collide with train ids.
        test_data_id = -1
        for name in args.test_datasets:
            data_id_to_name_map[test_data_id] = name
            name_to_data_id_map[name] = test_data_id
            test_data_id -= 1
    else:
        print("Training and test on", args.single_dataset)
        data_id_to_name_map = {}
        name_to_data_id_map = {}
        data_id_to_name_map[0] = args.single_dataset
        name_to_data_id_map[args.single_dataset] = 0

    # Get prediction length and set it if we are in the single dataset case
    if args.single_dataset and args.use_dataset_prediction_length:
        _, prediction_length, _ = create_test_dataset(
            args.single_dataset, args.dataset_path, 0
        )
        args.prediction_length = prediction_length

    # Cosine Annealing LR
    if args.use_cosine_annealing_lr:
        cosine_annealing_lr_args = {
            "T_max": args.cosine_annealing_lr_t_max,
            "eta_min": args.cosine_annealing_lr_eta_min,
        }
    else:
        cosine_annealing_lr_args = {}

    # Create the estimator
    estimator = LagLlamaEstimator(
        prediction_length=args.prediction_length,
        context_length=args.context_length,
        input_size=1,
        batch_size=args.batch_size,
        n_layer=args.n_layer,
        n_embd_per_head=args.n_embd_per_head,
        n_head=args.n_head,
        max_context_length=2048,
        rope_scaling=None,
        scaling=args.data_normalization,
        lr=args.lr,
        weight_decay=args.weight_decay,
        distr_output=args.distr_output,
        # augmentations
        aug_prob=args.aug_prob,
        freq_mask_rate=args.freq_mask_rate,
        freq_mixing_rate=args.freq_mixing_rate,
        jitter_prob=args.jitter_prob,
        jitter_sigma=args.jitter_sigma,
        scaling_prob=args.scaling_prob,
        scaling_sigma=args.scaling_sigma,
        rotation_prob=args.rotation_prob,
        permutation_prob=args.permutation_prob,
        permutation_max_segments=args.permutation_max_segments,
        permutation_seg_mode=args.permutation_seg_mode,
        magnitude_warp_prob=args.magnitude_warp_prob,
        magnitude_warp_sigma=args.magnitude_warp_sigma,
        magnitude_warp_knot=args.magnitude_warp_knot,
        time_warp_prob=args.time_warp_prob,
        time_warp_sigma=args.time_warp_sigma,
        time_warp_knot=args.time_warp_knot,
        window_slice_prob=args.window_slice_prob,
        window_slice_reduce_ratio=args.window_slice_reduce_ratio,
        window_warp_prob=args.window_warp_prob,
        window_warp_window_ratio=args.window_warp_window_ratio,
        window_warp_scales=args.window_warp_scales,
        # others
        num_batches_per_epoch=args.num_batches_per_epoch,
        num_parallel_samples=args.num_parallel_samples,
        time_feat=args.time_feat,
        dropout=args.dropout,
        lags_seq=args.lags_seq,
        data_id_to_name_map=data_id_to_name_map,
        use_cosine_annealing_lr=args.use_cosine_annealing_lr,
        cosine_annealing_lr_args=cosine_annealing_lr_args,
        track_loss_per_series=args.single_dataset is not None,  # idiom fix: was `!= None`
        ckpt_path=ckpt_path,
        trainer_kwargs=dict(
            max_epochs=args.max_epochs,
            accelerator="gpu",
            devices=[args.gpu],
            limit_val_batches=args.limit_val_batches,
            logger=logger,
            callbacks=callbacks,
            default_root_dir=fulldir_experiments,
        ),
    )

    # Save the args as config to the directory
    config_filepath = fulldir_experiments + "/args.json"
    with open(config_filepath, "w") as config_savefile:
        json.dump(vars(args), config_savefile, indent=4)

    # Save the number of parameters to the directory for easy retrieval
    num_parameters = sum(
        p.numel() for p in estimator.create_lightning_module().parameters()
    )
    num_parameters_path = fulldir_experiments + "/num_parameters.txt"
    with open(num_parameters_path, "w") as num_parameters_savefile:
        num_parameters_savefile.write(str(num_parameters))
    # Log num_parameters
    logger.log_metrics({"num_parameters": num_parameters})

    # Create samplers.
    # The window is made slightly bigger so the instance sampler can sample
    # from each window; an alternative would be an exact-size window with a
    # different instance sampler (e.g. ValidationSplitSampler with min_past).
    history_length = estimator.context_length + max(estimator.lags_seq)
    prediction_length = args.prediction_length
    window_size = history_length + prediction_length
    print(
        "Context length:",
        estimator.context_length,
        "Prediction Length:",
        estimator.prediction_length,
        "max(lags_seq):",
        max(estimator.lags_seq),
        "Therefore, window size:",
        window_size,
    )

    # Remove max(estimator.lags_seq) if the dataset is too small
    if args.use_single_instance_sampler:
        estimator.train_sampler = SingleInstanceSampler(
            min_past=estimator.context_length + max(estimator.lags_seq),
            min_future=estimator.prediction_length,
        )
        estimator.validation_sampler = SingleInstanceSampler(
            min_past=estimator.context_length + max(estimator.lags_seq),
            min_future=estimator.prediction_length,
        )
    else:
        estimator.train_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=estimator.context_length + max(estimator.lags_seq),
            min_future=estimator.prediction_length,
        )
        estimator.validation_sampler = ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_past=estimator.context_length + max(estimator.lags_seq),
            min_future=estimator.prediction_length,
        )

    ## Batch size
    batch_size = args.batch_size

    if args.evaluate_only:
        pass
    else:
        if not args.single_dataset:
            # Create training and validation data
            all_datasets, val_datasets, dataset_num_series = [], [], []
            dataset_train_num_points, dataset_val_num_points = [], []

            for data_id, name in enumerate(train_dataset_names):
                data_id = name_to_data_id_map[name]
                (
                    train_dataset,
                    val_dataset,
                    total_train_points,
                    total_val_points,
                    total_val_windows,
                    max_train_end_date,
                    total_points,
                ) = create_train_and_val_datasets_with_dates(
                    name,
                    args.dataset_path,
                    data_id,
                    history_length,
                    prediction_length,
                    num_val_windows=args.num_validation_windows,
                    last_k_percentage=args.single_dataset_last_k_percentage,
                )
                print(
                    "Dataset:",
                    name,
                    "Total train points:", total_train_points,
                    "Total val points:", total_val_points,
                )
                all_datasets.append(train_dataset)
                val_datasets.append(val_dataset)
                dataset_num_series.append(len(train_dataset))
                dataset_train_num_points.append(total_train_points)
                dataset_val_num_points.append(total_val_points)

            # Add test splits of test data to validation dataset, just for tracking purposes
            test_datasets_num_series = []
            test_datasets_num_points = []
            test_datasets = []

            if args.stratified_sampling:
                # NOTE: the "series" variants fail when a dataset has a single
                # series (e.g. airpassengers or saugeenday).
                if args.stratified_sampling == "series":
                    train_weights = dataset_num_series
                    val_weights = dataset_num_series + test_datasets_num_series
                elif args.stratified_sampling == "series_inverse":
                    train_weights = [1 / x for x in dataset_num_series]
                    val_weights = [1 / x for x in dataset_num_series + test_datasets_num_series]
                elif args.stratified_sampling == "timesteps":
                    train_weights = dataset_train_num_points
                    val_weights = dataset_val_num_points + test_datasets_num_points
                elif args.stratified_sampling == "timesteps_inverse":
                    train_weights = [1 / x for x in dataset_train_num_points]
                    val_weights = [1 / x for x in dataset_val_num_points + test_datasets_num_points]
            else:
                train_weights = val_weights = None

            train_data = CombinedDataset(all_datasets, weights=train_weights)
            val_data = CombinedDataset(val_datasets + test_datasets, weights=val_weights)
        else:
            (
                train_data,
                val_data,
                total_train_points,
                total_val_points,
                total_val_windows,
                max_train_end_date,
                total_points,
            ) = create_train_and_val_datasets_with_dates(
                args.single_dataset,
                args.dataset_path,
                0,
                history_length,
                prediction_length,
                num_val_windows=args.num_validation_windows,
                last_k_percentage=args.single_dataset_last_k_percentage,
            )
            print(
                "Dataset:",
                args.single_dataset,
                "Total train points:", total_train_points,
                "Total val points:", total_val_points,
            )

        # Batch size search since when we scale up, we might not be able to use
        # the same batch size for all models.
        if args.search_batch_size:
            estimator.num_batches_per_epoch = 10
            estimator.limit_val_batches = 10
            estimator.trainer_kwargs["max_epochs"] = 1
            estimator.trainer_kwargs["callbacks"] = []
            estimator.trainer_kwargs["logger"] = None
            fulldir_batchsize_search = os.path.join(
                fulldir_experiments, "batch-size-search"
            )
            os.makedirs(fulldir_batchsize_search, exist_ok=True)
            while batch_size >= 1:
                try:
                    print("Trying batch size:", batch_size)
                    batch_size_search_dir = os.path.join(
                        fulldir_batchsize_search, "batch-size-search-" + str(batch_size)
                    )
                    os.makedirs(batch_size_search_dir, exist_ok=True)
                    estimator.batch_size = batch_size
                    estimator.trainer_kwargs[
                        "default_root_dir"
                    ] = fulldir_batchsize_search
                    # Train
                    train_output = estimator.train_model(
                        training_data=train_data,
                        validation_data=val_data,
                        shuffle_buffer_length=None,
                        ckpt_path=None,
                    )
                    break
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        gc.collect()
                        torch.cuda.empty_cache()
                        if batch_size == 1:
                            print(
                                "Batch is already at the minimum. Cannot reduce further. Exiting..."
                            )
                            exit(0)
                        else:
                            print("Caught OutOfMemoryError. Reducing batch size...")
                            batch_size //= 2
                            continue
                    else:
                        print(e)
                        exit(1)
            # Restore the real training settings after the search run.
            estimator.num_batches_per_epoch = args.num_batches_per_epoch
            estimator.limit_val_batches = args.limit_val_batches
            estimator.trainer_kwargs["max_epochs"] = args.max_epochs
            estimator.trainer_kwargs["callbacks"] = callbacks
            estimator.trainer_kwargs["logger"] = logger
            estimator.trainer_kwargs["default_root_dir"] = fulldir_experiments
            # Halve once more as a safety margin against later OOM.
            if batch_size > 1:
                batch_size //= 2
            estimator.batch_size = batch_size
            print("\nUsing a batch size of", batch_size, "\n")
            wandb.config.update({"batch_size": batch_size}, allow_val_change=True)

        # Train
        train_output = estimator.train_model(
            training_data=train_data,
            validation_data=val_data,
            shuffle_buffer_length=None,
            ckpt_path=ckpt_path,
        )

        # Set checkpoint path before evaluating
        best_model_path = train_output.trainer.checkpoint_callback.best_model_path
        estimator.ckpt_path = best_model_path

    print("Using checkpoint:", estimator.ckpt_path, "for evaluation")
    # Make directory to store metrics
    metrics_dir = os.path.join(fulldir_experiments, "metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    # Evaluate
    evaluation_datasets = (
        args.test_datasets + train_dataset_names
        if not args.single_dataset
        else [args.single_dataset]
    )

    for name in evaluation_datasets:
        print("Evaluating on", name)
        test_data, prediction_length, total_points = create_test_dataset(
            name, args.dataset_path, window_size
        )
        print("# of Series in the test data:", len(test_data))

        # Adapt evaluator to new dataset
        estimator.prediction_length = prediction_length
        # Batch size loop just in case; this is mandatory as it involves sampling.
        # NOTE: if sampling fails even at batch size 1, num_parallel_samples
        # would have to be reduced instead (keeping batch size at 1).
        while batch_size >= 1:
            try:
                print("Trying batch size:", batch_size)
                estimator.batch_size = batch_size
                predictor = estimator.create_predictor(
                    estimator.create_transformation(),
                    estimator.create_lightning_module(),
                )
                # Make evaluations
                forecast_it, ts_it = make_evaluation_predictions(
                    dataset=test_data, predictor=predictor, num_samples=args.num_samples
                )
                forecasts = list(forecast_it)
                tss = list(ts_it)
                break
            except RuntimeError as e:
                if "out of memory" in str(e):
                    gc.collect()
                    torch.cuda.empty_cache()
                    if batch_size == 1:
                        print(
                            "Batch is already at the minimum. Cannot reduce further. Exiting..."
                        )
                        exit(0)
                    else:
                        print("Caught OutOfMemoryError. Reducing batch size...")
                        batch_size //= 2
                        continue
                else:
                    print(e)
                    exit(1)

        if args.plot_test_forecasts:
            print("Plotting forecasts")
            figure = plot_forecasts(forecasts, tss, prediction_length)
            wandb.log({f"Forecast plot of {name}": wandb.Image(figure)})

        # Get metrics
        evaluator = Evaluator(
            num_workers=args.num_workers, aggregation_strategy=aggregate_valid
        )
        agg_metrics, _ = evaluator(
            iter(tss), iter(forecasts), num_series=len(test_data)
        )
        # Save metrics
        metrics_savepath = metrics_dir + "/" + name + ".json"
        with open(metrics_savepath, "w") as metrics_savefile:
            json.dump(agg_metrics, metrics_savefile)

        # Log metrics. For now only CRPS is logged.
        wandb_metrics = {}
        wandb_metrics["test/" + name + "/" + "CRPS"] = agg_metrics["mean_wQuantileLoss"]
        logger.log_metrics(wandb_metrics)

    wandb.finish()
Frequency Masking 596 | parser.add_argument( 597 | "--freq_mask_rate", type=float, default=0.1, help="Rate of frequency masking" 598 | ) 599 | 600 | # Frequency Mixing 601 | parser.add_argument( 602 | "--freq_mixing_rate", type=float, default=0.1, help="Rate of frequency mixing" 603 | ) 604 | 605 | # Jitter 606 | parser.add_argument( 607 | "--jitter_prob", 608 | type=float, 609 | default=0, 610 | help="Probability of applying Jitter augmentation", 611 | ) 612 | parser.add_argument( 613 | "--jitter_sigma", 614 | type=float, 615 | default=0.03, 616 | help="Standard deviation for Jitter augmentation", 617 | ) 618 | 619 | # Scaling 620 | parser.add_argument( 621 | "--scaling_prob", 622 | type=float, 623 | default=0, 624 | help="Probability of applying Scaling augmentation", 625 | ) 626 | parser.add_argument( 627 | "--scaling_sigma", 628 | type=float, 629 | default=0.1, 630 | help="Standard deviation for Scaling augmentation", 631 | ) 632 | 633 | # Rotation 634 | parser.add_argument( 635 | "--rotation_prob", 636 | type=float, 637 | default=0, 638 | help="Probability of applying Rotation augmentation", 639 | ) 640 | 641 | # Permutation 642 | parser.add_argument( 643 | "--permutation_prob", 644 | type=float, 645 | default=0, 646 | help="Probability of applying Permutation augmentation", 647 | ) 648 | parser.add_argument( 649 | "--permutation_max_segments", 650 | type=int, 651 | default=5, 652 | help="Maximum segments for Permutation augmentation", 653 | ) 654 | parser.add_argument( 655 | "--permutation_seg_mode", 656 | type=str, 657 | default="equal", 658 | choices=["equal", "random"], 659 | help="Segment mode for Permutation augmentation", 660 | ) 661 | 662 | # MagnitudeWarp 663 | parser.add_argument( 664 | "--magnitude_warp_prob", 665 | type=float, 666 | default=0, 667 | help="Probability of applying MagnitudeWarp augmentation", 668 | ) 669 | parser.add_argument( 670 | "--magnitude_warp_sigma", 671 | type=float, 672 | default=0.2, 673 | help="Standard deviation for 
MagnitudeWarp augmentation", 674 | ) 675 | parser.add_argument( 676 | "--magnitude_warp_knot", 677 | type=int, 678 | default=4, 679 | help="Number of knots for MagnitudeWarp augmentation", 680 | ) 681 | 682 | # TimeWarp 683 | parser.add_argument( 684 | "--time_warp_prob", 685 | type=float, 686 | default=0, 687 | help="Probability of applying TimeWarp augmentation", 688 | ) 689 | parser.add_argument( 690 | "--time_warp_sigma", 691 | type=float, 692 | default=0.2, 693 | help="Standard deviation for TimeWarp augmentation", 694 | ) 695 | parser.add_argument( 696 | "--time_warp_knot", 697 | type=int, 698 | default=4, 699 | help="Number of knots for TimeWarp augmentation", 700 | ) 701 | 702 | # WindowSlice 703 | parser.add_argument( 704 | "--window_slice_prob", 705 | type=float, 706 | default=0, 707 | help="Probability of applying WindowSlice augmentation", 708 | ) 709 | parser.add_argument( 710 | "--window_slice_reduce_ratio", 711 | type=float, 712 | default=0.9, 713 | help="Reduce ratio for WindowSlice augmentation", 714 | ) 715 | 716 | # WindowWarp 717 | parser.add_argument( 718 | "--window_warp_prob", 719 | type=float, 720 | default=0, 721 | help="Probability of applying WindowWarp augmentation", 722 | ) 723 | parser.add_argument( 724 | "--window_warp_window_ratio", 725 | type=float, 726 | default=0.1, 727 | help="Window ratio for WindowWarp augmentation", 728 | ) 729 | parser.add_argument( 730 | "--window_warp_scales", 731 | nargs="+", 732 | type=float, 733 | default=[0.5, 2.0], 734 | help="Scales for WindowWarp augmentation", 735 | ) 736 | 737 | # Argument to include time-features 738 | parser.add_argument( 739 | "--time_feat", 740 | help="include time features", 741 | action="store_true", 742 | ) 743 | 744 | # Training arguments 745 | parser.add_argument("-b", "--batch_size", type=int, default=256) 746 | parser.add_argument("-m", "--max_epochs", type=int, default=10000) 747 | parser.add_argument("-n", "--num_batches_per_epoch", type=int, default=100) 748 | 
parser.add_argument("--limit_val_batches", type=int) 749 | parser.add_argument("--early_stopping_patience", default=50) 750 | parser.add_argument("--dropout", type=float, default=0.0) 751 | 752 | # Evaluation arguments 753 | parser.add_argument("--num_parallel_samples", type=int, default=100) 754 | parser.add_argument("--num_samples", type=int, default=100) 755 | parser.add_argument("--num_workers", type=int, default=1) 756 | 757 | # GPU ID 758 | parser.add_argument("--gpu", type=int, default=0) 759 | 760 | # Directory to save everything in 761 | parser.add_argument("-r", "--results_dir", type=str, required=True) 762 | 763 | # W&B 764 | parser.add_argument("-w", "--wandb_entity", type=str, default=None) 765 | parser.add_argument("--wandb_project", type=str, default="lag-llama-test") 766 | parser.add_argument("--wandb_tags", nargs="+") 767 | parser.add_argument( 768 | "--wandb_mode", type=str, default="online", choices=["offline", "online"] 769 | ) 770 | 771 | # Other arguments 772 | parser.add_argument( 773 | "--evaluate_only", action="store_true", help="Only evaluate, do not train" 774 | ) 775 | parser.add_argument( 776 | "--use_kv_cache", 777 | help="KV caching during infernce. 
Only for Lag-LLama.", 778 | action="store_true", 779 | default=True 780 | ) 781 | 782 | # SWA arguments 783 | parser.add_argument( 784 | "--swa", action="store_true", help="Using Stochastic Weight Averaging" 785 | ) 786 | parser.add_argument("--swa_lrs", type=float, default=1e-2) 787 | parser.add_argument("--swa_epoch_start", type=float, default=0.8) 788 | parser.add_argument("--annealing_epochs", type=int, default=10) 789 | parser.add_argument( 790 | "--annealing_strategy", type=str, default="cos", choices=["cos", "linear"] 791 | ) 792 | 793 | # Training/validation iterator type switching 794 | parser.add_argument("--use_single_instance_sampler", action="store_true", default=True) 795 | 796 | # Plot forecasts 797 | parser.add_argument("--plot_test_forecasts", action="store_true", default=True) 798 | 799 | # Search search_batch_size 800 | parser.add_argument("--search_batch_size", action="store_true", default=False) 801 | 802 | # Number of validation windows 803 | parser.add_argument("--num_validation_windows", type=int, default=14) 804 | 805 | # Training KWARGS 806 | parser.add_argument('--lr', type=float, default=1e-3) 807 | parser.add_argument('--weight_decay', type=float, default=1e-8) 808 | 809 | # Override arguments with a dictionary file with args 810 | parser.add_argument('--args_from_dict_path', type=str) 811 | 812 | # Evaluation utils 813 | parser.add_argument("--eval_prefix", type=str) 814 | 815 | # Checkpoints args 816 | parser.add_argument("--ckpt_path", type=str) 817 | parser.add_argument("--get_ckpt_path_from_experiment_name", type=str) 818 | 819 | # Single dataset setup: used typically for finetuning 820 | parser.add_argument("--single_dataset", type=str) 821 | parser.add_argument("--use_dataset_prediction_length", action="store_true", default=False) 822 | parser.add_argument("--single_dataset_last_k_percentage", type=float) 823 | 824 | # CosineAnnealingLR 825 | parser.add_argument("--use_cosine_annealing_lr", action="store_true", default=False) 826 
| parser.add_argument("--cosine_annealing_lr_t_max", type=int, default=10000) 827 | parser.add_argument("--cosine_annealing_lr_eta_min", type=float, default=1e-2) 828 | 829 | # Distribution output 830 | parser.add_argument('--distr_output', type=str, default="studentT", choices=["studentT"]) 831 | 832 | args = parser.parse_args() 833 | 834 | if args.args_from_dict_path: 835 | with open(args.args_from_dict_path, "r") as read_file: loaded_args = json.load(read_file) 836 | for key, value in loaded_args.items(): 837 | setattr(args, key, value) 838 | 839 | # print args for logging 840 | for arg in vars(args): 841 | print(arg, ":", getattr(args, arg)) 842 | 843 | train(args) 844 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | # This script can only be executed once you have trained a model. The experiment name used for the trained model should be specified in line 4 2 | 3 | CONFIGPATH="configs/lag_llama.json" 4 | PRETRAINING_EXP_NAME="pretraining_lag_llama" 5 | PERCENTAGE=100 # Change to lesser value to limit the history. Use 20, 40, 60, 80 to reproduce experiments in the paper. 6 | 7 | for FINETUNE_DATASET in "weather" "pedestrian_counts" "exchange_rate" "ett_m2" "platform_delay_minute" "requests_minute" "beijing_pm25" 8 | do 9 | EXP_NAME="${PRETRAINING_EXP_NAME}_finetune_on_${FINETUNE_DATASET}" 10 | 11 | # We reuse the same seeds as used for pretraining 12 | FILENAME="experiments/seeds/${PRETRAINING_EXP_NAME}" 13 | echo $PRETRAINING_EXP_NAME 14 | 15 | # Get the seeds 16 | if [ -f $FILENAME ]; then 17 | echo "${FILENAME} found. Reading seeds." 18 | SEEDS=() 19 | while read -r LINE; do 20 | SEEDS+=("$LINE") 21 | done < $FILENAME 22 | echo "Found ${#SEEDS[@]} seeds for finetuning." 23 | else 24 | echo "${FILENAME} does not exist. Cannot perform finetuning." 
25 | exit 0 26 | fi 27 | 28 | # Iterate through all training dataset 29 | for SEED in "${SEEDS[@]}" 30 | do 31 | EXPERIMENT_NAME="${EXP_NAME}_seed_${SEED}" 32 | 33 | python run.py \ 34 | -e $EXPERIMENT_NAME -d "datasets" --seed $SEED \ 35 | -r "experiments/results" \ 36 | --batch_size 512 -m 1000 -n 128 \ 37 | --wandb_entity "enter-wandb-entity" --wandb_project "enter-wandb-project" --wandb_tags "enter-wandb-tags-or-remove-this-argument" \ 38 | --num_workers 2 --args_from_dict_path $CONFIGPATH --search_batch_size \ 39 | --single_dataset $FINETUNE_DATASET \ 40 | --get_ckpt_path_from_experiment_name $PRETRAINING_EXP_NAME --lr 0.00001 --use_dataset_prediction_length --num_validation_windows 1 \ 41 | --single_dataset_last_k_percentage $PERCENTAGE 42 | done 43 | done -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | # PLEASE FOLLOW THE BELOW INSTRUCTIONS FIRST 2 | 3 | # 1. Install the requirements. It is recommend to use a new Anaconda environment with Python 3.10.8. Execute the below command (remove the #) 4 | # !pip install -r requirements.txt 5 | 6 | # 2. Please download https://drive.google.com/file/d/1JrDWMZyoPsc6d1wAAjgm3PosbGus-jCE/view?usp=sharing and use the below command to download the non-monash datasets (remove the #) 7 | # tar -xvzf nonmonash_datasets.tar.gz -C datasets 8 | 9 | # 3. Edit the Weights and Biases arguments on line 59 of this script 10 | 11 | mkdir -p experiments 12 | mkdir -p experiments/seeds 13 | mkdir -p experiments/results 14 | 15 | EXP_NAME="pretraining_lag_llama" 16 | FILENAME="experiments/seeds/${EXP_NAME}" 17 | CONFIGPATH="configs/lag_llama.json" 18 | 19 | echo $EXP_NAME 20 | 21 | # NUM_SEEDS used only if it is a new experiment 22 | NUM_SEEDS=1 23 | 24 | # Create seeds 25 | if [ -f $FILENAME ]; then 26 | echo "${FILENAME} already exists." 
27 | 28 | SEEDS=() 29 | while read -r LINE; do 30 | SEEDS+=("$LINE") 31 | done < $FILENAME 32 | echo "Found ${#SEEDS[@]} seeds for training." 33 | else 34 | # Write seeds 35 | echo "${FILENAME} created. Writing seeds." 36 | touch $FILENAME 37 | for (( i = 0; i < $NUM_SEEDS; i++ )) 38 | do 39 | SEED=$((RANDOM + 1)) 40 | echo $SEED >> $FILENAME 41 | done 42 | 43 | # Read them 44 | SEEDS=() 45 | while read -r LINE; do 46 | SEEDS+=("$LINE") 47 | done < $FILENAME 48 | fi 49 | 50 | # Train 51 | for SEED in "${SEEDS[@]}" 52 | do 53 | EXPERIMENT_NAME="${EXP_NAME}_seed_${SEED}" 54 | 55 | python run.py \ 56 | -e $EXP_NAME -d "datasets" --seed $SEED \ 57 | -r "experiments/results" \ 58 | --batch_size 512 -m 1000 -n 128 \ 59 | --wandb_entity "enter-wandb-entity" --wandb_project "enter-wandb-project" --wandb_tags "enter-wandb-tags-or-remove-this-argument" \ 60 | --all_datasets "australian_electricity_demand" "electricity_hourly" "london_smart_meters_without_missing" "solar_10_minutes" "wind_farms_without_missing" "pedestrian_counts" "uber_tlc_hourly" "traffic" "kdd_cup_2018_without_missing" "saugeenday" "sunspot_without_missing" "exchange_rate" "cpu_limit_minute" "cpu_usage_minute" "function_delay_minute" "instances_minute" "memory_limit_minute" "memory_usage_minute" "platform_delay_minute" "requests_minute" "ett_h1" "ett_h2" "ett_m1" "ett_m2" "beijing_pm25" "AirQualityUCI" "beijing_multisite" "weather" \ 61 | --test_datasets "weather" "pedestrian_counts" "exchange_rate" "ett_m2" "platform_delay_minute" "requests_minute" "beijing_pm25" \ 62 | --num_workers 2 --args_from_dict_path $CONFIGPATH --search_batch_size \ 63 | --lr 0.0001 64 | done 65 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Arjun Ashok 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in 
compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import random 17 | from itertools import islice 18 | 19 | import matplotlib.lines as mlines 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | import torch 23 | 24 | 25 | def set_seed(seed, deterministic=True): 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | if torch.cuda.is_available(): 30 | torch.cuda.manual_seed(seed) 31 | torch.cuda.manual_seed_all(seed) 32 | torch.backends.cudnn.deterministic = deterministic 33 | torch.backends.cudnn.benchmark = False 34 | os.environ["PYTHONHASHSEED"] = str(seed) 35 | 36 | 37 | def print_gpu_stats(): 38 | # Print GPU stats 39 | device = torch.cuda.current_device() 40 | memory_stats = torch.cuda.memory_stats(device=device) 41 | t = torch.cuda.get_device_properties(0).total_memory / (1024**3) 42 | allocated_memory_gb = memory_stats["allocated_bytes.all.current"] / (1024**3) 43 | print(f"Total Memory: {t:.2f} GB") 44 | print(f"Allocated Memory: {allocated_memory_gb:.2f} GB") 45 | 46 | 47 | def plot_forecasts(forecasts, tss, prediction_length): 48 | plt.figure(figsize=(20, 15)) 49 | plt.rcParams.update({"font.size": 15}) 50 | 51 | # Create custom legend handles 52 | forecast_line = mlines.Line2D([], [], color="g", label="Forecast") 53 | target_line = mlines.Line2D([], [], color="blue", label="Target") 54 | 55 | for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9): 56 | ax = plt.subplot(3, 3, idx + 1) 57 | forecast.plot(color="g") 58 | # ax.plot(forecast, color='g', 
label="Forecast") 59 | # ts[-3 * dataset.metadata.prediction_length:][0].plot(label="target") 60 | ts[-3 * prediction_length :][0].plot(label="target", ax=ax) 61 | plt.xticks(rotation=60) 62 | ax.set_title(forecast.item_id) 63 | # ax.legend() # Add legend to each subplot 64 | ax.legend(handles=[forecast_line, target_line]) 65 | 66 | plt.gcf().tight_layout() 67 | return plt.gcf() 68 | --------------------------------------------------------------------------------