├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── VERSION
├── configs
└── lag_llama.json
├── data
├── augmentations
│ ├── __init__.py
│ ├── augmentations.py
│ ├── freq_mask.py
│ └── freq_mix.py
├── data_utils.py
├── dataset_list.py
└── read_new_dataset.py
├── gluon_utils
├── gluon_ts_distributions
│ └── implicit_quantile_network.py
└── scalers
│ └── robust_scaler.py
├── images
└── lagllama.webp
├── lag_llama
├── gluon
│ ├── __init__.py
│ ├── estimator.py
│ └── lightning_module.py
└── model
│ ├── __init__.py
│ └── module.py
├── pyproject.toml
├── requirements.txt
├── run.py
├── scripts
├── finetune.sh
└── pretrain.sh
└── utils
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # Custom files
163 |
164 | experiments
165 | scratch/*
166 | wandb/*
167 | stats
168 | hyperparameters
169 | data/datasets
170 | .DS_Store
171 | hyperparameters_lag_transformer
172 | notebooks
173 | datasets
174 | *ckpt
175 | *tar.gz
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include VERSION
2 | include requirements.txt
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting
2 |
3 | 
4 |
5 | Lag-Llama is the first open-source foundation model for time series forecasting!
6 |
7 | [[Tweet Thread](https://twitter.com/arjunashok37/status/1755261111233114165)]
8 |
9 | [[Model Weights](https://huggingface.co/time-series-foundation-models/Lag-Llama)] [[Colab Demo 1: Zero-Shot Forecasting](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?usp=sharing)] [[Colab Demo 2: (Preliminary Finetuning)](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing)]
10 |
11 | [[Paper](https://arxiv.org/abs/2310.08278)]
12 |
13 | [[Video](https://www.youtube.com/watch?v=Mf2FOzDPxck)]
14 | ____
15 |
16 | Updates:
17 | * **27-June-2024**: Fixed critical issues in the kv_cache implementation, improving forecast accuracy. The fixes include: resetting the self.y_cache flag globally, using causal attention correctly during kv_cache initialization, and adjusting rotary embeddings post-concatenation. Contribution by [@KelianM](https://github.com/KelianM).
18 | * **16-Apr-2024**: Released pretraining and finetuning scripts to replicate the experiments in the paper. See [Reproducing Experiments in the Paper](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#reproducing-experiments-in-the-paper) for details.
19 | * **9-Apr-2024**: We have released a 15-minute video 🎥 on Lag-Llama on [YouTube](https://www.youtube.com/watch?v=Mf2FOzDPxck).
20 | * **5-Apr-2024**: Added a [section](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?authuser=1#scrollTo=Mj9LXMpJ01d7&line=6&uniqifier=1) in Colab Demo 1 on the importance of tuning the context length for zero-shot forecasting. Added a [best practices section](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#best-practices) in the README; added recommendations for finetuning. These recommendations will be demonstrated with an example in [Colab Demo 2](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing) soon.
21 | * **4-Apr-2024**: We have updated our requirements file with new versions of certain packages. Please update/recreate your environments if you have previously used the code locally.
22 | * **7-Mar-2024**: We have released a preliminary [Colab Demo 2](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing) for finetuning. Please note this is a preliminary tutorial. We recommend taking a look at the best practices if you are finetuning the model or using it for benchmarking.
23 | * **17-Feb-2024**: We have released a new updated [Colab Demo 1](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?usp=sharing) for zero-shot forecasting that shows how one can load time series of different formats.
24 | * **7-Feb-2024**: We released Lag-Llama, with open-source model checkpoints and a Colab Demo for zero-shot forecasting.
25 |
26 | ____
27 |
28 | **Current Features**:
29 |
30 | 💫 Zero-shot forecasting on a dataset of any frequency for any prediction length, using Colab Demo 1.
31 |
32 | 💫 Finetuning on a dataset using [Colab Demo 2](https://colab.research.google.com/drive/1uvTmh-pe1zO5TeaaRVDdoEWJ5dFDI-pA?usp=sharing).
33 |
34 | 💫 Reproducing experiments in the paper using the released scripts. See [Reproducing Experiments in the Paper](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#reproducing-experiments-in-the-paper) for details.
35 |
36 | **Note**: Please see the [best practices section](https://github.com/time-series-foundation-models/lag-llama?tab=readme-ov-file#best-practices) when using the model for zero-shot prediction and finetuning.
37 |
38 | ____
39 |
40 | ## Reproducing Experiments in the Paper
41 |
42 | To replicate the pretraining setup used in the paper, please see [the pretraining script](scripts/pretrain.sh). Once a model is pretrained, instructions to finetune it with the setup in the paper can be found in [the finetuning script](scripts/finetune.sh).
43 |
44 |
45 | ## Best Practices
46 |
47 | Here are some general tips in using Lag-Llama.
48 |
49 |
50 | ### General Information
51 |
52 | * Lag-Llama is a **probabilistic** forecasting model trained to output a probability distribution for each timestep to be predicted. For your own specific use-case, we would recommend benchmarking the zero-shot performance of the model on your data first, and then finetuning if necessary. As we show in our paper, Lag-Llama has strong zero-shot capabilities, but performs best when finetuned. The more data you finetune on, the better. For specific tips on applying the model zero-shot or on finetuning, please refer to the sections below.
53 |
54 | #### Zero-Shot Forecasting
55 |
56 | * Importantly, we recommend trying different **context lengths** (starting from $32$ which it was trained on) and identifying what works best for your data. As we show in [this section of the zero-shot forecasting demo](https://colab.research.google.com/drive/1DRAzLUPxsd-0r8b-o4nlyFXrjw_ZajJJ?authuser=1#scrollTo=Mj9LXMpJ01d7&line=6&uniqifier=1), the model's zero-shot performance improves as the context length is increased, until a certain context length which may be specific to your data. Further, we recommend enabling RoPE scaling for the model to work well with context lengths larger than what it was trained on.
57 |
58 | #### Fine-Tuning
59 |
60 | If you are trying to **benchmark** the performance of the model under finetuning, or trying to obtain maximum performance from the model:
61 |
62 | * We recommend tuning two important hyperparameters for each dataset that you finetune on: the **context length** (suggested values: $32$, $64$, $128$, $256$, $512$, $1024$) and the **learning rate** (suggested values: $10^{-2}$, $5 * 10^{-3}$, $10^{-3}$, $1 * 10^{-4}$, $5 * 10^{-4}$).
63 | * We also highly recommend using a validation split of your dataset to early stop your model, with an early stopping patience of 50 epochs.
64 |
65 | ## Contact
66 |
67 | We are dedicated to ensuring the reproducibility of our results, and would be happy to help clarify questions about benchmarking our model or about the experiments in the paper.
68 | The quickest way to reach us would be by email. Please email **both**:
69 | 1. [Arjun Ashok](https://ashok-arjun.github.io/) - arjun [dot] ashok [at] servicenow [dot] com
70 | 2. [Kashif Rasul](https://scholar.google.de/citations?user=cfIrwmAAAAAJ&hl=en) - kashif [dot] rasul [at] gmail [dot] com
71 |
72 | If you have questions about the model usage (or) code (or) have specific errors (eg. using it with your own dataset), it would be best to create an issue in the GitHub repository.
73 |
74 | ## Citing this work
75 |
76 | Please use the following Bibtex entry to cite Lag-Llama.
77 |
78 | ```
79 | @misc{rasul2024lagllama,
80 | title={Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting},
81 | author={Kashif Rasul and Arjun Ashok and Andrew Robert Williams and Hena Ghonia and Rishika Bhagwatkar and Arian Khorasani and Mohammad Javad Darvishi Bayazi and George Adamopoulos and Roland Riachi and Nadhir Hassen and Marin Biloš and Sahil Garg and Anderson Schneider and Nicolas Chapados and Alexandre Drouin and Valentina Zantedeschi and Yuriy Nevmyvaka and Irina Rish},
82 | year={2024},
83 | eprint={2310.08278},
84 | archivePrefix={arXiv},
85 | primaryClass={cs.LG}
86 | }
87 | ```
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.1.0
--------------------------------------------------------------------------------
/configs/lag_llama.json:
--------------------------------------------------------------------------------
1 | {
2 | "use_single_instance_sampler": true,
3 | "stratified_sampling": "series",
4 | "data_normalization": "robust",
5 | "n_layer": 8,
6 | "n_head": 9,
7 | "n_embd_per_head": 16,
8 | "time_feat": true,
9 | "context_length": 32,
10 | "aug_prob": 0.5,
11 | "freq_mask_rate": 0.5,
12 | "freq_mixing_rate": 0.25,
13 | "weight_decay": 0.0,
14 | "dropout": 0.0
15 | }
--------------------------------------------------------------------------------
/data/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
--------------------------------------------------------------------------------
/data/augmentations/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Adapted from https://github.com/vafl/gluon-ts/blob/ts_embeddings/src/gluonts/nursery/ts_embeddings/pt_augmentation.py
16 |
17 | import numpy as np
18 | import torch
19 | import torch.nn as nn
20 | from scipy.interpolate import CubicSpline
21 |
22 |
class ApplyAugmentations(nn.Module):
    """Runs two (bsz, seq_len) tensors through one shared augmentation pass.

    The two inputs are concatenated along the time axis, augmented together by
    a single RandomApply pipeline (so both halves receive the same stochastic
    decision), and then split back into their original widths.
    """

    def __init__(self, transforms):
        super().__init__()
        # Wrap the transform list so the whole pipeline fires stochastically.
        self.transformation = RandomApply(transforms)

    def forward(self, ip1, ip2):
        left_width, right_width = ip1.size(1), ip2.size(1)
        joined = torch.concat((ip1, ip2), dim=1)
        augmented = self.transformation(joined)
        op1, op2 = torch.split(augmented, [left_width, right_width], dim=1)
        # Both halves must come back as 2-D (bsz, seq_len) tensors.
        assert len(op1.shape) == 2 and len(op2.shape) == 2
        return op1, op2
34 |
35 |
class RandomApply(nn.Module):
    """With probability `p`, threads a batch of series through every transform
    in `transforms`; otherwise returns the input untouched."""

    def __init__(self, transforms, p=0.5):
        super().__init__()
        # Registered as a ModuleList so parameters/devices track the parent.
        self.transforms = nn.ModuleList(transforms)
        # Probability (0..1) of applying the whole pipeline on a given call.
        self.p = p

    def forward(self, x):
        # One draw per call: skip the pipeline with probability 1 - p.
        if self.p < torch.rand(1):
            return x

        # The pipeline expects (bsz, seq_len) input.
        assert len(x.shape) == 2

        # Transforms operate on (bsz, seq_len, 1): add a trailing channel
        # axis, thread the tensor through each transform, then drop the axis.
        out = x.unsqueeze(-1)
        for transform in self.transforms:
            out = transform(out)
        return out.squeeze(-1)
69 |
70 |
class Jitter(nn.Module):
    """Additive-noise augmentation ("jittering") from
    https://arxiv.org/pdf/1706.00527.pdf: perturbs the input with i.i.d.
    zero-mean Gaussian noise."""

    def __init__(self, p, sigma=0.03):
        super().__init__()
        # Probability of applying the noise on any given call.
        self.p = p
        # Standard deviation of the Gaussian noise (controls its magnitude).
        self.sigma = sigma

    def forward(self, x):
        # Skip the augmentation with probability 1 - p.
        if self.p < torch.rand(1):
            return x

        # Zero-mean Gaussian perturbation, same shape and device as the input.
        jitter = torch.normal(mean=0.0, std=self.sigma, size=x.shape, device=x.device)
        return x + jitter
99 |
100 |
class Scaling(nn.Module):
    """Random-magnitude scaling from https://arxiv.org/pdf/1706.00527.pdf:
    multiplies each (batch, channel) pair by one Gaussian factor drawn around
    1.0, shared across the whole time axis."""

    def __init__(self, p, sigma=0.1):
        super().__init__()
        # Probability of applying the scaling on any given call.
        self.p = p
        # Std-dev of the scale factors, drawn from N(1, sigma^2).
        self.sigma = sigma

    def forward(self, x):
        # Skip the augmentation with probability 1 - p.
        if self.p < torch.rand(1):
            return x

        # One factor per (batch, channel); the singleton time dimension makes
        # it broadcast, so an entire series is scaled by the same factor.
        scale = torch.normal(
            mean=1.0, std=self.sigma, size=(x.shape[0], 1, x.shape[2]), device=x.device
        )
        return x * scale
132 |
133 |
class Rotation(nn.Module):
    """Randomly negates each (batch, channel) series and shuffles the channel
    order — a rotation-style augmentation for data whose orientation is not a
    defining characteristic."""

    def __init__(self, p):
        super().__init__()
        # Probability of applying the rotation on any given call.
        self.p = p

    def forward(self, x):
        # Skip the augmentation with probability 1 - p.
        if self.p < torch.rand(1):
            return x

        bsz, n_channels = x.shape[0], x.shape[2]

        # Fair coin per (batch, channel) slot: 0 means "negate", 1 means "keep".
        coin = torch.multinomial(
            torch.tensor([0.5, 0.5], dtype=x.dtype, device=x.device),
            num_samples=bsz * n_channels,
            replacement=True,
        )

        # Map the coin flips to sign factors of -1 / +1.
        ones = torch.ones((bsz * n_channels), device=x.device)
        signs = torch.where(coin == 0, -ones, ones)

        # Random permutation of the channel axis (numpy RNG, shuffled in place).
        channel_order = np.arange(n_channels)
        np.random.shuffle(channel_order)

        # Broadcast the signs over the time axis and apply them to the
        # channel-permuted input.
        return signs.reshape(bsz, 1, n_channels) * x[:, :, channel_order]
169 |
170 |
class Permutation(nn.Module):
    """Splits each series along time into several contiguous segments and
    shuffles the segment order. Point order is destroyed across (but not
    within) segments — useful when global ordering is not crucial."""

    def __init__(self, p, max_segments=5, seg_mode="equal"):
        super().__init__()
        self.p = p  # probability of applying the augmentation
        # Upper bound for the per-series segment-count draw.
        # NOTE(review): np.random.randint's high bound is exclusive, so the
        # drawn count is in [1, max_segments) — confirm that is intended.
        self.max_segments = max_segments
        self.seg_mode = seg_mode  # "equal": even split; "random": random cut points

    def forward(self, x):
        # Bernoulli gate on the whole batch.
        if torch.rand(1) > self.p:
            return x

        time_idx = np.arange(x.shape[1])

        # One segment count per series.
        seg_counts = np.random.randint(1, self.max_segments, size=(x.shape[0]))

        out = torch.zeros_like(x)
        for i, series in enumerate(x):
            n_segs = seg_counts[i]
            if n_segs <= 1:
                # A single segment: nothing to permute.
                out[i] = series
                continue
            if self.seg_mode == "random":
                # Random interior cut points, sorted so segments stay contiguous.
                cuts = np.random.choice(
                    x.shape[1] - 2, n_segs - 1, replace=False
                )
                cuts.sort()
                segments = np.split(time_idx, cuts)
            else:
                # Roughly equal-sized segments.
                segments = np.array_split(time_idx, n_segs)
            # Shuffle the segments and flatten back into one index vector.
            order = np.concatenate(np.random.permutation(segments)).ravel()
            out[i] = series[order]

        return out
224 |
225 |
class MagnitudeWarp(nn.Module):
    """Multiplies each series by a smooth random envelope: a cubic spline
    through `knot`+2 control points drawn from N(1, sigma), evaluated at
    every timestep. Magnitudes wobble smoothly; timing is untouched."""

    def __init__(self, p, sigma=0.2, knot=4):
        super().__init__()
        self.p = p  # probability of applying the augmentation
        self.sigma = sigma  # spread of the spline control points around 1.0
        self.knot = knot  # interior control points of the warping spline

    def forward(self, x):
        # Bernoulli gate on the whole batch.
        if torch.rand(1) > self.p:
            return x

        n_steps = x.shape[1]
        steps = np.arange(n_steps)

        # Control-point values, one set per (series, channel).
        anchors = np.random.normal(
            loc=1.0,
            scale=self.sigma,
            size=(x.shape[0], self.knot + 2, x.shape[2]),
        )

        # Control-point locations: evenly spaced over [0, T-1], one column per channel.
        anchor_steps = (
            np.ones((x.shape[2], 1))
            * (np.linspace(0, n_steps - 1.0, num=self.knot + 2))
        ).T

        out = torch.zeros_like(x)
        for i, series in enumerate(x):
            # One spline per channel, evaluated at every timestep -> (T, C) envelope.
            envelope = np.array(
                [
                    CubicSpline(anchor_steps[:, d], anchors[i, :, d])(steps)
                    for d in range(x.shape[2])
                ]
            ).T
            out[i] = series * torch.from_numpy(envelope).float().to(x.device)

        return out
282 |
283 |
class TimeWarp(nn.Module):
    """Warps the time axis of each series with a smooth random distortion.

    A cubic spline through randomly perturbed anchor points defines where
    each original timestep lands; the series is then re-sampled back onto
    the regular grid with linear interpolation. Useful when models should
    be robust to local speed-ups / slow-downs of the signal."""

    def __init__(self, p, sigma=0.2, knot=4):
        super().__init__()
        self.p = p  # probability of applying the augmentation
        self.sigma = sigma  # spread of the anchor perturbations around 1.0
        self.knot = knot  # interior anchor points of the warping spline

    def forward(self, x):
        # Bernoulli gate on the whole batch.
        if torch.rand(1) > self.p:
            return x

        n_steps = x.shape[1]
        steps = np.arange(n_steps)

        # Per-(series, channel) multiplicative perturbations of the anchors.
        anchors = np.random.normal(
            loc=1.0,
            scale=self.sigma,
            size=(x.shape[0], self.knot + 2, x.shape[2]),
        )

        # Anchor locations: evenly spaced over [0, T-1], one column per channel.
        anchor_steps = (
            np.ones((x.shape[2], 1))
            * (np.linspace(0, n_steps - 1.0, num=self.knot + 2))
        ).T

        out = torch.zeros_like(x)
        for i, series in enumerate(x):
            for d in range(x.shape[2]):
                # Where each original timestep lands after warping.
                warped_time = CubicSpline(
                    anchor_steps[:, d],
                    anchor_steps[:, d] * anchors[i, :, d],
                )(steps)
                # Rescale so the warped axis still ends at the last timestep.
                stretch = (n_steps - 1) / warped_time[-1]
                resampled = np.interp(
                    steps,
                    np.clip(stretch * warped_time, 0, n_steps - 1),
                    series[:, d].cpu().numpy(),
                ).T
                out[i, :, d] = torch.from_numpy(resampled).float().to(x.device)

        return out
344 |
345 |
class WindowSlice(nn.Module):
    """Cuts a random window covering `reduce_ratio` of each series and
    linearly stretches it back to the full length, so the model sees a
    zoomed-in view of local structure ('Time Series Data Augmentation for
    Deep Learning: A Survey',
    https://halshs.archives-ouvertes.fr/halshs-01357973/document)."""

    def __init__(self, p, reduce_ratio=0.9):
        super().__init__()
        self.p = p  # probability of applying the augmentation
        self.reduce_ratio = reduce_ratio  # slice length as a fraction of the series

    def forward(self, x):
        # Bernoulli gate on the whole batch.
        if torch.rand(1) > self.p:
            return x

        slice_len = np.ceil(self.reduce_ratio * x.shape[1]).astype(int)
        if slice_len >= x.shape[1]:
            # The slice would cover the whole series: nothing to do.
            return x

        # Independent random window position per series.
        starts = np.random.randint(
            low=0, high=x.shape[1] - slice_len, size=(x.shape[0])
        ).astype(int)
        ends = (starts + slice_len).astype(int)

        out = torch.zeros_like(x)
        for i, series in enumerate(x):
            for d in range(x.shape[2]):
                # Stretch the slice across the original length.
                stretched = np.interp(
                    np.linspace(0, slice_len, num=x.shape[1]),
                    np.arange(slice_len),
                    series[starts[i] : ends[i], d].cpu().numpy(),
                ).T
                out[i, :, d] = torch.from_numpy(stretched).float().to(x.device)
        return out
391 |
392 |
class WindowWarp(nn.Module):
    """Picks a random time window in each series and plays it back faster or
    slower (scale drawn from `scales`), then re-samples the whole series
    back to its original length — simulating local rate variations
    ('Time Series Data Augmentation for Deep Learning: A Survey',
    https://halshs.archives-ouvertes.fr/halshs-01357973/document)."""

    def __init__(self, p, window_ratio=0.1, scales=[0.5, 2.0]):
        super().__init__()
        self.p = p  # probability of applying the augmentation
        self.window_ratio = window_ratio  # window length as a fraction of the series
        self.scales = scales  # candidate playback-speed factors

    def forward(self, x):
        # Bernoulli gate on the whole batch.
        if torch.rand(1) > self.p:
            return x

        n_steps = x.shape[1]

        # One playback speed per series.
        speeds = np.random.choice(self.scales, x.shape[0])

        win_len = np.ceil(self.window_ratio * n_steps).astype(int)
        win_idx = np.arange(win_len)

        # Window position per series, keeping a step of margin at both ends.
        win_starts = np.random.randint(
            low=1, high=n_steps - win_len - 1, size=(x.shape[0])
        ).astype(int)
        win_ends = (win_starts + win_len).astype(int)

        out = torch.zeros_like(x)
        for i, series in enumerate(x):
            for d in range(x.shape[2]):
                # Split the series around the chosen window.
                head = series[: win_starts[i], d].cpu().numpy()
                # Re-sample the window at the chosen speed.
                warped_win = np.interp(
                    np.linspace(
                        0,
                        win_len - 1,
                        num=int(win_len * speeds[i]),
                    ),
                    win_idx,
                    series[win_starts[i] : win_ends[i], d].cpu().numpy(),
                )
                tail = series[win_ends[i] :, d].cpu().numpy()

                # Stitch back together and stretch to the original length.
                stitched = np.concatenate((head, warped_win, tail))
                resampled = np.interp(
                    np.arange(n_steps),
                    np.linspace(0, n_steps - 1.0, num=stitched.size),
                    stitched,
                ).T

                out[i, :, d] = torch.from_numpy(resampled).float().to(x.device)
        return out
458 |
--------------------------------------------------------------------------------
/data/augmentations/freq_mask.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 |
@torch.no_grad()
def freq_mask(x, y, rate=0.1, dim=1):
    """Frequency-domain masking augmentation.

    Concatenates past targets `x` and future targets `y` along `dim`, zeroes
    out a random subset of Fourier coefficients (each bin independently with
    probability `rate`), transforms back, and returns the two pieces with
    their original lengths.

    Fix: `irfft` is now given the exact output length (`n=x_len + y_len`).
    `irfft` cannot infer the length of an odd-length signal from its rfft,
    and the previous workaround padded the result by duplicating the first
    timestep of `x` along a hardcoded axis 1 (wrong for `dim != 1`).
    """
    x_len = x.shape[dim]
    y_len = y.shape[dim]

    # Mask the full (past + future) series so the augmentation is consistent
    # across the history/forecast boundary.
    xy = torch.cat([x, y], dim=dim)

    # Into the frequency domain.
    xy_f = torch.fft.rfft(xy, dim=dim)

    # Independent Bernoulli(rate) mask over frequency bins.
    m = torch.rand_like(xy_f, dtype=xy.dtype) < rate

    # Zero the selected bins in both real and imaginary parts.
    freal = xy_f.real.masked_fill(m, 0)
    fimag = xy_f.imag.masked_fill(m, 0)
    xy_f = torch.complex(freal, fimag)

    # Back to the time domain; n pins the output length, which irfft cannot
    # infer by itself when x_len + y_len is odd.
    xy = torch.fft.irfft(xy_f, n=x_len + y_len, dim=dim)

    # Hand back past/future with their original lengths.
    return torch.split(xy, [x_len, y_len], dim=dim)
55 |
--------------------------------------------------------------------------------
/data/augmentations/freq_mix.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 |
18 |
@torch.no_grad()
def freq_mix(x, y, rate=0.1, dim=1):
    """Frequency-domain mixing augmentation.

    Concatenates past targets `x` and future targets `y` along `dim`,
    replaces a random subset of Fourier coefficients (each bin with
    probability `rate`) with the corresponding coefficients of a randomly
    shuffled batch, transforms back, and returns the two pieces.

    Fix: `irfft` is now given the exact output length (`n=x_len + y_len`).
    `irfft` cannot infer the length of an odd-length signal from its rfft,
    and the previous workaround padded the result by duplicating the first
    timestep of `x` along a hardcoded axis 1 (wrong for `dim != 1`).
    """
    x_len = x.shape[dim]
    y_len = y.shape[dim]

    # Mix the full (past + future) series so the augmentation is consistent
    # across the history/forecast boundary.
    xy = torch.cat([x, y], dim=dim)
    xy_f = torch.fft.rfft(xy, dim=dim)

    # Independent Bernoulli(rate) selection over frequency bins.
    m = torch.rand_like(xy_f, dtype=xy.dtype) < rate

    # Exempt some bins from mixing.
    # NOTE(review): `index` is in sorted (rank) order while `m` is in bin
    # order, so `index > 2` does not literally exclude the top-3 amplitude
    # bins at their own positions — behavior kept as-is; confirm intent.
    amp = abs(xy_f)
    _, index = amp.sort(dim=dim, descending=True)
    dominant_mask = index > 2
    m = torch.bitwise_and(m, dominant_mask)

    # Zero the selected bins in the original batch.
    freal = xy_f.real.masked_fill(m, 0)
    fimag = xy_f.imag.masked_fill(m, 0)

    # Donor batch: the same series, shuffled along the batch axis.
    b_idx = np.arange(x.shape[0])
    np.random.shuffle(b_idx)
    x2, y2 = x[b_idx], y[b_idx]
    xy2 = torch.cat([x2, y2], dim=dim)
    xy2_f = torch.fft.rfft(xy2, dim=dim)

    # Take exactly the complementary bins from the donor batch.
    m = torch.bitwise_not(m)
    freal2 = xy2_f.real.masked_fill(m, 0)
    fimag2 = xy2_f.imag.masked_fill(m, 0)

    # Combine own and donor frequency content.
    freal += freal2
    fimag += fimag2

    # Back to the time domain; n pins the output length, which irfft cannot
    # infer by itself when x_len + y_len is odd.
    xy_f = torch.complex(freal, fimag)
    xy = torch.fft.irfft(xy_f, n=x_len + y_len, dim=dim)

    # Hand back past/future with their original lengths.
    return torch.split(xy, [x_len, y_len], dim=dim)
78 |
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import random
17 | import warnings
18 | import json
19 | import os
20 | from pathlib import Path
21 | warnings.simplefilter(action="ignore", category=FutureWarning)
22 | warnings.simplefilter(action="ignore", category=UserWarning)
23 | from pathlib import Path
24 |
25 | import numpy as np
26 | import pandas as pd
27 | from tqdm import tqdm
28 | from gluonts.dataset.common import ListDataset
29 | from gluonts.dataset.repository.datasets import get_dataset
30 | from gluonts.transform import InstanceSampler
31 | from pandas.tseries.frequencies import to_offset
32 |
33 | from data.read_new_dataset import get_ett_dataset, create_train_dataset_without_last_k_timesteps, TrainDatasets, MetaData
34 |
class CombinedDatasetIterator:
    """Iterator that, at every step, picks one of several member iterators at
    random (according to `weights`) and yields its next element."""

    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(ds) for ds in datasets]
        self._weights = weights
        # Dedicated RNG so sampling is reproducible given `seed`.
        self._rng = random.Random(seed)

    def __next__(self):
        (chosen,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(chosen)


class CombinedDataset:
    """Lazily mixes several datasets: iterating yields items drawn from the
    member datasets with probabilities `weights` (uniform when omitted)."""

    def __init__(self, datasets, seed=None, weights=None):
        self._seed = seed
        self._datasets = datasets
        # Default to a uniform mixture over the member datasets.
        if weights is None:
            n_datasets = len(datasets)
            weights = [1 / n_datasets] * n_datasets
        self._weights = weights

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)

    def __len__(self):
        return sum(len(ds) for ds in self._datasets)
60 |
61 |
class SingleInstanceSampler(InstanceSampler):
    """
    Draw exactly one valid window position per time series, uniformly at
    random. This fixes the bias in ExpectedNumInstanceSampler, which samples
    series at varying frequency depending on their length and on when they
    were sampled.
    """

    def __call__(self, ts: np.ndarray) -> np.ndarray:
        # Valid window bounds come from the InstanceSampler base class.
        lo, hi = self._get_bounds(ts)
        n_valid = hi - lo + 1
        if n_valid <= 0:
            # No admissible window in this series.
            return np.array([], dtype=int)
        return lo + np.random.randint(n_valid, size=1)
78 |
79 |
80 | def _count_timesteps(
81 | left: pd.Timestamp, right: pd.Timestamp, delta: pd.DateOffset
82 | ) -> int:
83 | """
84 | Count how many timesteps there are between left and right, according to the given timesteps delta.
85 | If the number if not integer, round down.
86 | """
87 | # This is due to GluonTS replacing Timestamp by Period for version 0.10.0.
88 | # Original code was tested on version 0.9.4
89 | if type(left) == pd.Period:
90 | left = left.to_timestamp()
91 | if type(right) == pd.Period:
92 | right = right.to_timestamp()
93 | assert (
94 | right >= left
95 | ), f"Case where left ({left}) is after right ({right}) is not implemented in _count_timesteps()."
96 | try:
97 | return (right - left) // delta
98 | except TypeError:
99 | # For MonthEnd offsets, the division does not work, so we count months one by one.
100 | for i in range(10000):
101 | if left + (i + 1) * delta > right:
102 | return i
103 | else:
104 | raise RuntimeError(
105 | f"Too large difference between both timestamps ({left} and {right}) for _count_timesteps()."
106 | )
107 |
108 | from pathlib import Path
109 | from gluonts.dataset.common import ListDataset
110 | from gluonts.dataset.repository.datasets import get_dataset
111 |
def create_train_dataset_last_k_percentage(
    raw_train_dataset,
    freq,
    k=100
):
    """Return a ListDataset that keeps only the trailing `k` percent of every
    series' target in `raw_train_dataset` (k=100 keeps everything)."""
    truncated = []
    for series in raw_train_dataset:
        entry = series.copy()
        n_keep = int(len(entry["target"]) * k / 100)
        # Keep only the last `n_keep` values of the target.
        entry["target"] = entry["target"][len(entry["target"]) - n_keep:]
        truncated.append(entry)

    return ListDataset(truncated, freq=freq)
129 |
def create_train_and_val_datasets_with_dates(
    name,
    dataset_path,
    data_id,
    history_length,
    prediction_length=None,
    num_val_windows=None,
    val_start_date=None,
    train_start_date=None,
    freq=None,
    last_k_percentage=None
):
    """
    Build train and validation ListDatasets for the named dataset.

    Train start date is assumed to be the start of the series if not provided.
    Freq, if not given, is inferred from the data.
    We can use ListDataset to just group multiple time series - https://github.com/awslabs/gluonts/issues/695

    Exactly one of `num_val_windows` (number of trailing timesteps reserved
    for validation) and `val_start_date` (explicit split date) must be given.

    Returns (train_data, val_data, total_train_points, total_val_points,
    total_val_windows, max_train_end_date, total_points).
    """

    # Load the raw dataset: custom loaders for the ETT, Huawei-cloud and
    # air-quality data, the GluonTS repository for everything else.
    if name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"):
        path = os.path.join(dataset_path, "ett_datasets")
        raw_dataset = get_ett_dataset(name, path)
    elif name in ("cpu_limit_minute", "cpu_usage_minute", \
        "function_delay_minute", "instances_minute", \
        "memory_limit_minute", "memory_usage_minute", \
        "platform_delay_minute", "requests_minute"):
        path = os.path.join(dataset_path, "huawei/" + name + ".json")
        with open(path, "r") as f: data = json.load(f)
        metadata = MetaData(**data["metadata"])
        # Entries whose first target value is a string are corrupt; drop them.
        train_data = [x for x in data["train"] if type(x["target"][0]) != str]
        test_data = [x for x in data["test"] if type(x["target"][0]) != str]
        train_ds = ListDataset(train_data, freq=metadata.freq)
        test_ds = ListDataset(test_data, freq=metadata.freq)
        raw_dataset = TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)
    elif name in ("beijing_pm25", "AirQualityUCI", "beijing_multisite"):
        path = os.path.join(dataset_path, "air_quality/" + name + ".json")
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        train_test_data = [x for x in data["data"] if type(x["target"][0]) != str]
        full_dataset = ListDataset(train_test_data, freq=metadata.freq)
        # Hold the last 24 timesteps out of the training split.
        train_ds = create_train_dataset_without_last_k_timesteps(full_dataset, freq=metadata.freq, k=24)
        raw_dataset = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset)
    else:
        raw_dataset = get_dataset(name, path=Path(dataset_path))

    # Fall back to the dataset's own metadata where arguments were omitted.
    if prediction_length is None:
        prediction_length = raw_dataset.metadata.prediction_length
    if freq is None:
        freq = raw_dataset.metadata.freq
    timestep_delta = pd.tseries.frequencies.to_offset(freq)
    raw_train_dataset = raw_dataset.train

    if not num_val_windows and not val_start_date:
        raise Exception("Either num_val_windows or val_start_date must be provided")
    if num_val_windows and val_start_date:
        # Fixed: this branch previously raised with the same message as the
        # missing-argument case, which misreported the actual problem.
        raise Exception("Only one of num_val_windows and val_start_date must be provided")

    max_train_end_date = None

    # Get training data
    total_train_points = 0
    train_data = []
    for i, series in enumerate(raw_train_dataset):
        s_train = series.copy()
        if val_start_date is not None:
            # Train on everything before the explicit validation start date.
            train_end_index = _count_timesteps(
                series["start"] if not train_start_date else train_start_date,
                val_start_date,
                timestep_delta,
            )
        else:
            # Reserve the last num_val_windows timesteps for validation.
            train_end_index = len(series["target"]) - num_val_windows
        # Compute train_start_index based on last_k_percentage
        if last_k_percentage:
            number_of_values = int(len(s_train["target"]) * last_k_percentage / 100)
            train_start_index = train_end_index - number_of_values
        else:
            train_start_index = 0
        s_train["target"] = series["target"][train_start_index:train_end_index]
        s_train["item_id"] = i
        s_train["data_id"] = data_id
        train_data.append(s_train)
        total_train_points += len(s_train["target"])

        # Calculate the end date
        end_date = s_train["start"] + to_offset(freq) * (len(s_train["target"]) - 1)
        if max_train_end_date is None or end_date > max_train_end_date:
            max_train_end_date = end_date

    train_data = ListDataset(train_data, freq=freq)

    # Get validation data
    total_val_points = 0
    total_val_windows = 0
    val_data = []
    for i, series in enumerate(raw_train_dataset):
        s_val = series.copy()
        if val_start_date is not None:
            train_end_index = _count_timesteps(
                series["start"], val_start_date, timestep_delta
            )
        else:
            train_end_index = len(series["target"]) - num_val_windows
        # Validation windows also need (history + prediction) context before them.
        val_start_index = train_end_index - prediction_length - history_length
        s_val["start"] = series["start"] + val_start_index * timestep_delta
        s_val["target"] = series["target"][val_start_index:]
        s_val["item_id"] = i
        s_val["data_id"] = data_id
        val_data.append(s_val)
        total_val_points += len(s_val["target"])
        total_val_windows += len(s_val["target"]) - prediction_length - history_length
    val_data = ListDataset(val_data, freq=freq)

    # Total unique points: the context prepended to each validation series is
    # shared with training, so remove it from the sum once per series.
    total_points = (
        total_train_points
        + total_val_points
        - (len(raw_train_dataset) * (prediction_length + history_length))
    )

    return (
        train_data,
        val_data,
        total_train_points,
        total_val_points,
        total_val_windows,
        max_train_end_date,
        total_points,
    )
258 |
259 |
def create_test_dataset(
    name, dataset_path, history_length, freq=None, data_id=None
):
    """
    Build a test ListDataset for the named dataset, trimming every series to
    its last `history_length + prediction_length` timesteps.

    For now, only one window per series is used.
    make_evaluation_predictions automatically only predicts for the last "prediction_length" timesteps
    NOTE / TODO: For datasets where the test set has more series (possibly due to more timestamps), \
    we should check if we only use the last N series where N = series per single timestamp, or if we should do something else.

    Returns (ListDataset, prediction_length, total number of target points).
    """

    # Load the raw dataset: custom loaders for the ETT, Huawei-cloud and
    # air-quality data, the GluonTS repository for everything else.
    if name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"):
        path = os.path.join(dataset_path, "ett_datasets")
        dataset = get_ett_dataset(name, path)
    elif name in ("cpu_limit_minute", "cpu_usage_minute", \
        "function_delay_minute", "instances_minute", \
        "memory_limit_minute", "memory_usage_minute", \
        "platform_delay_minute", "requests_minute"):
        path = os.path.join(dataset_path, "huawei/" + name + ".json")
        with open(path, "r") as f: data = json.load(f)
        metadata = MetaData(**data["metadata"])
        # Entries whose first target value is a string are corrupt; drop them.
        train_data = [x for x in data["train"] if type(x["target"][0]) != str]
        test_data = [x for x in data["test"] if type(x["target"][0]) != str]
        train_ds = ListDataset(train_data, freq=metadata.freq)
        test_ds = ListDataset(test_data, freq=metadata.freq)
        dataset = TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)
    elif name in ("beijing_pm25", "AirQualityUCI", "beijing_multisite"):
        path = os.path.join(dataset_path, "air_quality/" + name + ".json")
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        train_test_data = [x for x in data["data"] if type(x["target"][0]) != str]
        full_dataset = ListDataset(train_test_data, freq=metadata.freq)
        # Hold the last 24 timesteps out of the training split.
        train_ds = create_train_dataset_without_last_k_timesteps(full_dataset, freq=metadata.freq, k=24)
        dataset = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset)
    else:
        dataset = get_dataset(name, path=Path(dataset_path))

    if freq is None:
        freq = dataset.metadata.freq
    prediction_length = dataset.metadata.prediction_length
    data = []
    total_points = 0
    for i, series in enumerate(dataset.test):
        # Keep only the last (history + prediction) timesteps of each series.
        offset = len(series["target"]) - (history_length + prediction_length)
        if offset > 0:
            target = series["target"][-(history_length + prediction_length) :]
            data.append(
                {
                    "target": target,
                    # Shift the start forward by the number of dropped steps.
                    "start": series["start"] + offset,
                    "item_id": i,
                    "data_id": data_id,
                }
            )
        else:
            # Series shorter than one full window: keep it unchanged.
            series_copy = copy.deepcopy(series)
            series_copy["item_id"] = i
            series_copy["data_id"] = data_id
            data.append(series_copy)
        total_points += len(data[-1]["target"])
    return ListDataset(data, freq=freq), prediction_length, total_points
--------------------------------------------------------------------------------
/data/dataset_list.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
# Names of every dataset the data-loading utilities accept: GluonTS repository
# datasets plus the custom ETT ("ett_*"), Huawei-cloud ("*_minute") and
# air-quality sets that data_utils.py loads through dedicated branches.
ALL_DATASETS = ["australian_electricity_demand", "electricity_hourly", "london_smart_meters_without_missing", "solar_10_minutes", "wind_farms_without_missing", "pedestrian_counts", "uber_tlc_hourly", "traffic", "kdd_cup_2018_without_missing", "saugeenday", "sunspot_without_missing", "exchange_rate", "cpu_limit_minute", "cpu_usage_minute", "function_delay_minute", "instances_minute", "memory_limit_minute", "memory_usage_minute", "platform_delay_minute", "requests_minute", "ett_h1", "ett_h2", "ett_m1", "ett_m2", "beijing_pm25", "AirQualityUCI", "beijing_multisite"]
--------------------------------------------------------------------------------
/data/read_new_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | warnings.simplefilter(action="ignore", category=FutureWarning)
17 | warnings.simplefilter(action="ignore", category=UserWarning)
18 |
19 | import gzip, json
20 | from gluonts.dataset.common import ListDataset, TrainDatasets, MetaData
21 | from pathlib import Path
22 | from gluonts.dataset.repository.datasets import get_dataset
23 | import os
24 |
def create_train_dataset_without_last_k_timesteps(
    raw_train_dataset,
    freq,
    k=0,
):
    """Build a training ``ListDataset`` with the last ``k`` timesteps of
    every series removed.

    Parameters
    ----------
    raw_train_dataset
        Iterable of GluonTS data entries (dicts with at least a "target").
    freq
        Pandas frequency string for the resulting ``ListDataset``.
    k
        Number of trailing timesteps to drop from each series (default 0,
        i.e. keep the series unchanged).
    """
    train_data = []
    for series in raw_train_dataset:
        truncated = series.copy()
        # Slice with len(...) - k rather than [:-k] so that k == 0 keeps
        # the full series instead of producing an empty one.
        truncated["target"] = truncated["target"][: len(truncated["target"]) - k]
        train_data.append(truncated)
    return ListDataset(train_data, freq=freq)
37 |
def load_jsonl_gzip_file(file_path):
    """Read a gzip-compressed JSON-lines file and return its records as a list."""
    with gzip.open(file_path, "rt") as stream:
        # One JSON document per line; decode each line eagerly.
        return list(map(json.loads, stream))
41 |
def get_ett_dataset(dataset_name, path):
    """Load an ETT dataset from disk as a GluonTS ``TrainDatasets``.

    The on-disk ``test`` split contains the full series; the train split is
    derived from it by dropping the final 24 timesteps of every series.

    Parameters
    ----------
    dataset_name
        Directory name of the dataset (e.g. "ett_h1").
    path
        Root directory containing the dataset directory.
    """
    dataset_path = Path(path) / dataset_name
    metadata_path = dataset_path / "metadata.json"
    with open(metadata_path, "r") as f:
        metadata_dict = json.load(f)
    metadata = MetaData(**metadata_dict)
    # Only the test split is read from disk (the previously computed
    # train-split path was never used); train data is derived from test.
    test_data_path = dataset_path / "test" / "data.json.gz"
    test_data = load_jsonl_gzip_file(test_data_path)
    test_ds = ListDataset(test_data, freq=metadata.freq)
    train_ds = create_train_dataset_without_last_k_timesteps(
        test_ds, freq=metadata.freq, k=24
    )
    return TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)
57 |
if __name__ == "__main__":
    dataset_name = "ett_h1"

    if dataset_name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"):
        path = "data/datasets/ett_datasets"
        ds = get_ett_dataset(dataset_name, path)
    elif dataset_name in (
        "cpu_limit_minute", "cpu_usage_minute",
        "function_delay_minute", "instances_minute",
        "memory_limit_minute", "memory_usage_minute",
        "platform_delay_minute", "requests_minute",
    ):
        path = "data/datasets/huawei/" + dataset_name + ".json"
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        # Drop series whose targets were serialized as strings (bad records).
        train_data = [x for x in data["train"] if not isinstance(x["target"][0], str)]
        test_data = [x for x in data["test"] if not isinstance(x["target"][0], str)]
        train_ds = ListDataset(train_data, freq=metadata.freq)
        test_ds = ListDataset(test_data, freq=metadata.freq)
        ds = TrainDatasets(metadata=metadata, train=train_ds, test=test_ds)
    elif dataset_name in ("beijing_pm25", "AirQualityUCI", "beijing_multisite"):
        path = "data/datasets/air_quality/" + dataset_name + ".json"
        with open(path, "r") as f:
            data = json.load(f)
        metadata = MetaData(**data["metadata"])
        train_test_data = [x for x in data["data"] if not isinstance(x["target"][0], str)]
        full_dataset = ListDataset(train_test_data, freq=metadata.freq)
        # BUG FIX: previously this truncated `test_ds` (a stale name from the
        # huawei branch, a NameError when that branch did not run) instead of
        # this dataset's own series.
        train_ds = create_train_dataset_without_last_k_timesteps(
            full_dataset, freq=metadata.freq, k=24
        )
        ds = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset)
--------------------------------------------------------------------------------
/gluon_utils/gluon_ts_distributions/implicit_quantile_network.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License").
4 | # You may not use this file except in compliance with the License.
5 | # A copy of the License is located at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # or in the "license" file accompanying this file. This file is distributed
10 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11 | # express or implied. See the License for the specific language governing
12 | # permissions and limitations under the License.
13 |
14 | from functools import partial
15 | from typing import Callable, Dict, Optional, Tuple
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | from torch.distributions import Beta, Distribution, constraints
21 |
22 | from gluonts.core.component import validated
23 | from gluonts.torch.distributions import DistributionOutput
24 | from gluonts.torch.modules.lambda_layer import LambdaLayer
25 |
26 |
27 | class QuantileLayer(nn.Module):
28 | r"""
29 | Implicit Quantile Layer from the paper ``IQN for Distributional
30 | Reinforcement Learning`` (https://arxiv.org/abs/1806.06923) by
31 | Dabney et al. 2018.
32 | """
33 |
34 | def __init__(self, num_output: int, cos_embedding_dim: int = 128):
35 | super().__init__()
36 |
37 | self.output_layer = nn.Sequential(
38 | nn.Linear(cos_embedding_dim, cos_embedding_dim),
39 | nn.PReLU(),
40 | nn.Linear(cos_embedding_dim, num_output),
41 | )
42 |
43 | self.register_buffer("integers", torch.arange(0, cos_embedding_dim))
44 |
45 | def forward(self, tau: torch.Tensor) -> torch.Tensor: # tau: [B, T]
46 | cos_emb_tau = torch.cos(tau.unsqueeze(-1) * self.integers * torch.pi)
47 | return self.output_layer(cos_emb_tau)
48 |
49 |
class ImplicitQuantileModule(nn.Module):
    r"""
    Implicit Quantile Network from the paper ``IQN for Distributional
    Reinforcement Learning`` (https://arxiv.org/abs/1806.06923) by
    Dabney et al. 2018.

    Samples quantile levels (taus), embeds them, modulates the input
    features with the embedding, and projects to the distribution arguments.
    """

    def __init__(
        self,
        in_features: int,
        args_dim: Dict[str, int],
        domain_map: Callable[..., Tuple[torch.Tensor]],
        concentration1: float = 1.0,
        concentration0: float = 1.0,
        output_domain_map=None,
        cos_embedding_dim: int = 64,
    ):
        super().__init__()
        self.output_domain_map = output_domain_map
        self.domain_map = domain_map
        # Quantile levels are drawn from Beta(concentration1, concentration0)
        # during training.
        self.beta = Beta(concentration1=concentration1, concentration0=concentration0)

        # NOTE: submodule creation order is kept stable so that parameter
        # initialization consumes the RNG identically across versions.
        self.quantile_layer = QuantileLayer(
            in_features, cos_embedding_dim=cos_embedding_dim
        )
        self.output_layer = nn.Sequential(
            nn.Linear(in_features, in_features), nn.PReLU()
        )

        self.proj = nn.ModuleList(
            [nn.Linear(in_features, dim) for dim in args_dim.values()]
        )

    def forward(self, inputs: torch.Tensor):
        # Training draws taus from the Beta prior; evaluation uses Uniform(0, 1).
        if self.training:
            taus = self.beta.sample(sample_shape=inputs.shape[:-1]).to(inputs.device)
        else:
            taus = torch.rand(size=inputs.shape[:-1], device=inputs.device)

        # Modulate the input features with the embedded quantile levels.
        modulated = inputs * (1.0 + self.quantile_layer(taus))
        hidden = self.output_layer(modulated)

        projected = [head(hidden).squeeze(-1) for head in self.proj]
        if self.output_domain_map is not None:
            projected = [self.output_domain_map(p) for p in projected]
        return (*self.domain_map(*projected), taus)
97 |
98 |
class ImplicitQuantileNetwork(Distribution):
    r"""
    Distribution class for the Implicit Quantile from which
    we can sample or calculate the quantile loss.

    Parameters
    ----------
    outputs
        Outputs from the Implicit Quantile Network.
    taus
        Tensor random numbers from the Beta or Uniform distribution for the
        corresponding outputs.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, outputs: torch.Tensor, taus: torch.Tensor, validate_args=None):
        self.taus = taus
        self.outputs = outputs

        super().__init__(batch_shape=outputs.shape, validate_args=validate_args)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        # The network emits quantile values directly, so "sampling" simply
        # returns those values.
        return self.outputs

    def quantile_loss(self, value: torch.Tensor) -> torch.Tensor:
        # Pinball loss: under-predictions are weighted by tau and
        # over-predictions by (1 - tau).
        over_indicator = (value < self.outputs).float()
        return (self.taus - over_indicator) * (value - self.outputs)
129 |
130 |
class ImplicitQuantileNetworkOutput(DistributionOutput):
    r"""
    DistributionOutput class for the IQN from the paper
    ``Probabilistic Time Series Forecasting with Implicit Quantile Networks``
    (https://arxiv.org/abs/2107.03743) by Gouttes et al. 2021.

    Parameters
    ----------
    output_domain
        Optional domain mapping of the output. Can be "positive", "unit"
        or None.
    concentration1
        Alpha parameter of the Beta distribution when sampling the taus
        during training.
    concentration0
        Beta parameter of the Beta distribution when sampling the taus
        during training.
    cos_embedding_dim
        The embedding dimension for the taus embedding layer of IQN.
        Default is 64.
    """

    distr_cls = ImplicitQuantileNetwork
    args_dim = {"quantile_function": 1}

    @validated()
    def __init__(
        self,
        output_domain: Optional[str] = None,
        concentration1: float = 1.0,
        concentration0: float = 1.0,
        cos_embedding_dim: int = 64,
    ) -> None:
        super().__init__()

        self.concentration1 = concentration1
        self.concentration0 = concentration0
        self.cos_embedding_dim = cos_embedding_dim

        # Map the optional output-domain name to its activation; any other
        # value (including None) leaves the outputs unconstrained.
        self.output_domain_map = {
            "positive": F.softplus,
            "unit": partial(F.softmax, dim=-1),
        }.get(output_domain)

    def get_args_proj(self, in_features: int) -> nn.Module:
        return ImplicitQuantileModule(
            in_features=in_features,
            args_dim=self.args_dim,
            output_domain_map=self.output_domain_map,
            domain_map=LambdaLayer(self.domain_map),
            concentration1=self.concentration1,
            concentration0=self.concentration0,
            cos_embedding_dim=self.cos_embedding_dim,
        )

    @classmethod
    def domain_map(cls, *args):
        # Identity mapping: the IQN head emits quantile values directly.
        return args

    def distribution(self, distr_args, loc=0, scale=None) -> ImplicitQuantileNetwork:
        (outputs, taus) = distr_args

        # Undo the input scaler on the emitted quantile values:
        # rescale first, then shift.
        if scale is not None:
            outputs = outputs * scale
        if loc is not None:
            outputs = outputs + loc
        return self.distr_cls(outputs=outputs, taus=taus)

    @property
    def event_shape(self):
        # Univariate output.
        return ()

    def loss(
        self,
        target: torch.Tensor,
        distr_args: Tuple[torch.Tensor, ...],
        loc: Optional[torch.Tensor] = None,
        scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        predicted = self.distribution(distr_args, loc=loc, scale=scale)
        return predicted.quantile_loss(target)


iqn = ImplicitQuantileNetworkOutput()
219 |
--------------------------------------------------------------------------------
/gluon_utils/scalers/robust_scaler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import torch
18 | from gluonts.core.component import validated
19 | from gluonts.torch.scaler import Scaler
20 |
21 |
class RobustScaler(Scaler):
    """
    Computes a scaling factor by removing the median and scaling by the
    interquartile range (IQR).

    Parameters
    ----------
    dim
        dimension along which to compute the scale
    keepdim
        controls whether to retain dimension ``dim`` (of length 1) in the
        scale tensor, or suppress it.
    minimum_scale
        minimum possible scale that is used for any item.
    """

    @validated()
    def __init__(
        self,
        dim: int = -1,
        keepdim: bool = False,
        minimum_scale: float = 1e-10,
    ) -> None:
        self.dim = dim
        self.keepdim = keepdim
        self.minimum_scale = minimum_scale

    def __call__(
        self, data: torch.Tensor, weights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert (
            data.shape == weights.shape
        ), "data and observed_indicator must have same shape"

        with torch.no_grad():
            # Mask unobserved entries with NaN so the nan-aware statistics
            # below ignore them.
            masked = torch.where(weights == 1, data, torch.nan)

            median = torch.nanmedian(masked, dim=self.dim, keepdim=True).values
            lower_quartile = torch.nanquantile(masked, 0.25, dim=self.dim, keepdim=True)
            upper_quartile = torch.nanquantile(masked, 0.75, dim=self.dim, keepdim=True)
            spread = upper_quartile - lower_quartile

            # If a slice has no observed values, the nan-statistics are NaN:
            # fall back to loc = 0 and scale = 1 for that slice.
            loc = torch.where(torch.isnan(median), torch.zeros_like(median), median)
            scale = torch.where(torch.isnan(spread), torch.ones_like(spread), spread)
            scale = torch.maximum(scale, torch.full_like(spread, self.minimum_scale))

        # Computed outside no_grad so gradients flow through `data`.
        scaled_data = (data - loc) / scale

        if not self.keepdim:
            loc = torch.squeeze(loc, dim=self.dim)
            scale = torch.squeeze(scale, dim=self.dim)

        # Sanity checks: the fallbacks above must have removed every NaN,
        # and minimum_scale guarantees a strictly positive scale.
        assert not torch.any(torch.isnan(scaled_data))
        assert not torch.any(torch.isnan(loc))
        assert not torch.any(torch.isnan(scale))
        assert not torch.any(scale == 0)

        return scaled_data, loc, scale
82 |
--------------------------------------------------------------------------------
/images/lagllama.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/time-series-foundation-models/lag-llama/013ff0c786c6fee677a360b76df95685c9dac25d/images/lagllama.webp
--------------------------------------------------------------------------------
/lag_llama/gluon/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
--------------------------------------------------------------------------------
/lag_llama/gluon/estimator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, Iterable, Optional
16 |
17 | import pytorch_lightning as pl
18 | import torch
19 |
20 | from gluonts.core.component import validated
21 | from gluonts.dataset.common import Dataset
22 | from gluonts.dataset.field_names import FieldName
23 | from gluonts.dataset.loader import as_stacked_batches
24 | from gluonts.dataset.stat import calculate_dataset_statistics
25 | from gluonts.itertools import Cyclic
26 | from gluonts.time_feature import (
27 | get_lags_for_frequency,
28 | time_features_from_frequency_str,
29 | )
30 | from gluonts.torch.distributions import StudentTOutput, NegativeBinomialOutput
31 | from gluonts.torch.model.estimator import PyTorchLightningEstimator
32 | from gluonts.torch.model.predictor import PyTorchPredictor
33 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
34 | from gluonts.transform import (
35 | AddObservedValuesIndicator,
36 | AddTimeFeatures,
37 | Chain,
38 | DummyValueImputation,
39 | ExpectedNumInstanceSampler,
40 | InstanceSampler,
41 | InstanceSplitter,
42 | TestSplitSampler,
43 | Transformation,
44 | ValidationSplitSampler,
45 | )
46 |
47 | from gluon_utils.gluon_ts_distributions.implicit_quantile_network import (
48 | ImplicitQuantileNetworkOutput,
49 | )
50 | from lag_llama.gluon.lightning_module import LagLlamaLightningModule
51 |
# Tensor fields fed to the network at prediction time.
PREDICTION_INPUT_NAMES = [
    "past_target",
    "past_observed_values",
]
# Training additionally requires the ground-truth future window.
TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [
    "future_target",
    "future_observed_values",
]
60 |
61 |
class LagLlamaEstimator(PyTorchLightningEstimator):
    """
    An estimator for training the Lag-Llama model for forecasting.

    This class wraps the Lag-Llama model into a
    ``LagLlamaLightningModule`` for training purposes: training is
    performed using PyTorch Lightning's ``pl.Trainer`` class.

    Parameters
    ----------
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs (default: ``10 * prediction_length``).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    distr_output
        Name of the distribution used to evaluate observations and sample
        predictions. One of ``"studentT"``, ``"neg_bin"`` or ``"iqn"``
        (default: ``"studentT"``).
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    batch_size
        The size of the batches to be used for training (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction.
    train_sampler
        Controls the sampling of windows during training.
    validation_sampler
        Controls the sampling of windows during validation.
    use_single_pass_sampling
        If True, use a single forward pass and sample N times from the saved
        distribution, much more efficient.
        If False, perform N forward passes and maintain N parallel prediction
        paths, this is true probabilistic forecasting.
        (default: False)
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        context_length: Optional[int] = None,
        input_size: int = 1,
        n_layer: int = 1,
        n_embd_per_head: int = 32,
        n_head: int = 4,
        max_context_length: int = 2048,
        rope_scaling=None,
        scaling: Optional[str] = "mean",
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        # Augmentations arguments
        aug_prob: float = 0.1,
        freq_mask_rate: float = 0.1,
        freq_mixing_rate: float = 0.1,
        jitter_prob: float = 0.0,
        jitter_sigma: float = 0.03,
        scaling_prob: float = 0.0,
        scaling_sigma: float = 0.1,
        rotation_prob: float = 0.0,
        permutation_prob: float = 0.0,
        permutation_max_segments: int = 5,
        permutation_seg_mode: str = "equal",
        magnitude_warp_prob: float = 0.0,
        magnitude_warp_sigma: float = 0.2,
        magnitude_warp_knot: int = 4,
        time_warp_prob: float = 0.0,
        time_warp_sigma: float = 0.2,
        time_warp_knot: int = 4,
        window_slice_prob: float = 0.0,
        window_slice_reduce_ratio: float = 0.9,
        window_warp_prob: float = 0.0,
        window_warp_window_ratio: float = 0.1,
        window_warp_scales: Optional[list] = None,
        # Continuing model arguments
        distr_output: str = "studentT",
        loss: DistributionLoss = NegativeLogLikelihood(),
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
        time_feat: bool = False,
        dropout: float = 0.0,
        lags_seq: Optional[list] = None,
        data_id_to_name_map: Optional[dict] = None,
        use_cosine_annealing_lr: bool = False,
        cosine_annealing_lr_args: Optional[dict] = None,
        track_loss_per_series: bool = False,
        ckpt_path: Optional[str] = None,
        nonnegative_pred_samples: bool = False,
        use_single_pass_sampling: bool = False,
        device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    ) -> None:
        default_trainer_kwargs = {"max_epochs": 100}
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        # Replace None sentinels with fresh per-instance defaults. The
        # previous signature used mutable default arguments ([], {}), which
        # are shared across calls and instances — a classic Python pitfall.
        if lags_seq is None:
            lags_seq = ["Q", "M", "W", "D", "H", "T", "S"]
        if window_warp_scales is None:
            window_warp_scales = [0.5, 2.0]
        if data_id_to_name_map is None:
            data_id_to_name_map = {}
        if cosine_annealing_lr_args is None:
            cosine_annealing_lr_args = {}

        self.scaling = scaling
        self.input_size = input_size
        self.prediction_length = prediction_length
        self.context_length = context_length
        self.max_context_length = max_context_length

        # Expand the frequency names into concrete lag indices, dedupe,
        # and shift to 0-based indexing.
        lag_indices = []
        for freq in lags_seq:
            lag_indices.extend(
                get_lags_for_frequency(freq_str=freq, num_default_lags=1)
            )

        if len(lag_indices):
            self.lags_seq = sorted(set(lag_indices))
            self.lags_seq = [lag_index - 1 for lag_index in self.lags_seq]
        else:
            self.lags_seq = []

        self.n_head = n_head
        self.n_layer = n_layer
        self.n_embd_per_head = n_embd_per_head
        self.rope_scaling = rope_scaling

        self.lr = lr
        self.weight_decay = weight_decay
        # Resolve the distribution-output name; an unknown name silently
        # falls through and is stored as-is (assumed to already be a
        # DistributionOutput instance).
        if distr_output == "studentT":
            distr_output = StudentTOutput()
        elif distr_output == "neg_bin":
            distr_output = NegativeBinomialOutput()
        elif distr_output == "iqn":
            distr_output = ImplicitQuantileNetworkOutput()
        self.distr_output = distr_output
        self.num_parallel_samples = num_parallel_samples
        self.loss = loss
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch
        self.nonnegative_pred_samples = nonnegative_pred_samples
        # NOTE(review): use_single_pass_sampling is stored but not forwarded
        # to the lightning module here — confirm where it is consumed.
        self.use_single_pass_sampling = use_single_pass_sampling

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0,
            min_future=prediction_length,
            min_instances=1,
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

        self.aug_prob = aug_prob
        self.freq_mask_rate = freq_mask_rate
        self.freq_mixing_rate = freq_mixing_rate
        self.jitter_prob = jitter_prob
        self.jitter_sigma = jitter_sigma
        self.scaling_prob = scaling_prob
        self.scaling_sigma = scaling_sigma
        self.rotation_prob = rotation_prob
        self.permutation_prob = permutation_prob
        self.permutation_max_segments = permutation_max_segments
        self.permutation_seg_mode = permutation_seg_mode
        self.magnitude_warp_prob = magnitude_warp_prob
        self.magnitude_warp_sigma = magnitude_warp_sigma
        self.magnitude_warp_knot = magnitude_warp_knot
        self.time_warp_prob = time_warp_prob
        self.time_warp_sigma = time_warp_sigma
        self.time_warp_knot = time_warp_knot
        self.window_slice_prob = window_slice_prob
        self.window_slice_reduce_ratio = window_slice_reduce_ratio
        self.window_warp_prob = window_warp_prob
        self.window_warp_window_ratio = window_warp_window_ratio
        self.window_warp_scales = window_warp_scales
        self.track_loss_per_series = track_loss_per_series

        self.time_feat = time_feat
        self.dropout = dropout
        self.data_id_to_name_map = data_id_to_name_map
        self.ckpt_path = ckpt_path

        self.use_cosine_annealing_lr = use_cosine_annealing_lr
        self.cosine_annealing_lr_args = cosine_annealing_lr_args
        self.device = device

    @classmethod
    def derive_auto_fields(cls, train_iter):
        """Derive dataset-dependent fields from the training data."""
        stats = calculate_dataset_statistics(train_iter)

        return {
            "num_feat_dynamic_real": stats.num_feat_dynamic_real,
            "num_feat_static_cat": len(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def _time_feat_fields(self) -> list:
        """Extra tensor fields that exist only when time features are on."""
        return ["past_time_feat", "future_time_feat"] if self.time_feat else []

    def create_transformation(self) -> Transformation:
        """Build the feature/imputation transformation chain."""
        transforms = []
        if self.time_feat:
            # NOTE(review): second-level ("S") time features are hard-coded
            # regardless of the data frequency — confirm this is intended.
            transforms.append(
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=time_features_from_frequency_str("S"),
                    pred_length=self.prediction_length,
                )
            )
        transforms.append(
            AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
                imputation_method=DummyValueImputation(0.0),
            )
        )
        return Chain(transforms)

    def create_lightning_module(self, use_kv_cache: bool = False) -> pl.LightningModule:
        """Create the Lightning module, loading from ``ckpt_path`` if set."""
        model_kwargs = {
            "input_size": self.input_size,
            "context_length": self.context_length,
            "max_context_length": self.max_context_length,
            "lags_seq": self.lags_seq,
            "n_layer": self.n_layer,
            "n_embd_per_head": self.n_embd_per_head,
            "n_head": self.n_head,
            "scaling": self.scaling,
            "distr_output": self.distr_output,
            "num_parallel_samples": self.num_parallel_samples,
            "rope_scaling": self.rope_scaling,
            "time_feat": self.time_feat,
            "dropout": self.dropout,
        }
        # Shared keyword arguments for both construction paths; previously
        # this list was duplicated verbatim, inviting drift between the two.
        module_kwargs = dict(
            loss=self.loss,
            lr=self.lr,
            weight_decay=self.weight_decay,
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            model_kwargs=model_kwargs,
            # Augmentations
            aug_prob=self.aug_prob,
            freq_mask_rate=self.freq_mask_rate,
            freq_mixing_rate=self.freq_mixing_rate,
            jitter_prob=self.jitter_prob,
            jitter_sigma=self.jitter_sigma,
            scaling_prob=self.scaling_prob,
            scaling_sigma=self.scaling_sigma,
            rotation_prob=self.rotation_prob,
            permutation_prob=self.permutation_prob,
            permutation_max_segments=self.permutation_max_segments,
            permutation_seg_mode=self.permutation_seg_mode,
            magnitude_warp_prob=self.magnitude_warp_prob,
            magnitude_warp_sigma=self.magnitude_warp_sigma,
            magnitude_warp_knot=self.magnitude_warp_knot,
            time_warp_prob=self.time_warp_prob,
            time_warp_sigma=self.time_warp_sigma,
            time_warp_knot=self.time_warp_knot,
            window_slice_prob=self.window_slice_prob,
            window_slice_reduce_ratio=self.window_slice_reduce_ratio,
            window_warp_prob=self.window_warp_prob,
            window_warp_window_ratio=self.window_warp_window_ratio,
            window_warp_scales=self.window_warp_scales,
            use_kv_cache=use_kv_cache,
            data_id_to_name_map=self.data_id_to_name_map,
            use_cosine_annealing_lr=self.use_cosine_annealing_lr,
            cosine_annealing_lr_args=self.cosine_annealing_lr_args,
            track_loss_per_series=self.track_loss_per_series,
            nonnegative_pred_samples=self.nonnegative_pred_samples,
        )
        if self.ckpt_path is not None:
            return LagLlamaLightningModule.load_from_checkpoint(
                checkpoint_path=self.ckpt_path,
                map_location=self.device,
                strict=False,
                **module_kwargs,
            )
        return LagLlamaLightningModule(**module_kwargs)

    def _create_instance_splitter(self, module: LagLlamaLightningModule, mode: str):
        """Build the window splitter for the given stage."""
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            # The past window must also cover the largest lag index.
            past_length=self.context_length + max(self.lags_seq),
            future_length=self.prediction_length,
            time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES]
            if self.time_feat
            else [FieldName.OBSERVED_VALUES],
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: LagLlamaLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """Build an infinite, shuffled training batch iterator."""
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES + self._time_feat_fields(),
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: LagLlamaLightningModule,
        **kwargs,
    ) -> Iterable:
        """Build the validation batch iterator (single pass, no shuffling)."""
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES + self._time_feat_fields(),
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
        """Wrap the trained module in a GluonTS predictor."""
        prediction_splitter = self._create_instance_splitter(module, "test")
        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES + self._time_feat_fields(),
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
491 |
--------------------------------------------------------------------------------
/lag_llama/gluon/lightning_module.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import random
16 |
17 | import numpy as np
18 |
19 | from lightning import LightningModule
20 | import torch
21 | import torch.nn.functional as F
22 |
23 | from gluonts.core.component import validated
24 | from gluonts.itertools import prod
25 | from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
26 | from gluonts.torch.util import repeat_along_dim, take_last
27 |
28 | from data.augmentations.freq_mask import freq_mask
29 | from data.augmentations.freq_mix import freq_mix
30 | from data.augmentations.augmentations import (
31 | ApplyAugmentations,
32 | Jitter,
33 | MagnitudeWarp,
34 | Permutation,
35 | Rotation,
36 | Scaling,
37 | TimeWarp,
38 | WindowSlice,
39 | WindowWarp,
40 | )
41 | from gluon_utils.gluon_ts_distributions.implicit_quantile_network import (
42 | ImplicitQuantileNetworkOutput,
43 | )
44 | from lag_llama.model.module import LagLlamaModel
45 |
46 |
class LagLlamaLightningModule(LightningModule):
    """
    A ``pl.LightningModule`` class that can be used to train a
    ``LagLlamaLightningModule`` with PyTorch Lightning.

    This is a thin layer around a (wrapped) ``LagLlamaLightningModule`` object,
    that exposes the methods to evaluate training and validation loss.

    Parameters
    ----------
    model
        ``LagLlamaLightningModule`` to be trained.
    loss
        Loss function to be used for training,
        default: ``NegativeLogLikelihood()``.
    lr
        Learning rate, default: ``1e-3``.
    weight_decay
        Weight decay regularization parameter, default: ``1e-8``.
    """

    @validated()
    def __init__(
        self,
        model_kwargs: dict,
        context_length: int,
        prediction_length: int,
        loss: DistributionLoss = NegativeLogLikelihood(),
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        aug_prob: float = 0.1,
        freq_mask_rate: float = 0.1,
        freq_mixing_rate: float = 0.1,
        jitter_prob: float = 0.0,
        jitter_sigma: float = 0.03,
        scaling_prob: float = 0.0,
        scaling_sigma: float = 0.1,
        rotation_prob: float = 0.0,
        permutation_prob: float = 0.0,
        permutation_max_segments: int = 5,
        permutation_seg_mode: str = "equal",
        magnitude_warp_prob: float = 0.0,
        magnitude_warp_sigma: float = 0.2,
        magnitude_warp_knot: int = 4,
        time_warp_prob: float = 0.0,
        time_warp_sigma: float = 0.2,
        time_warp_knot: int = 4,
        window_slice_prob: float = 0.0,
        window_slice_reduce_ratio: float = 0.9,
        window_warp_prob: float = 0.0,
        window_warp_window_ratio: float = 0.1,
        # NOTE(review): mutable defaults ([], {}) below are a Python
        # anti-pattern; presumably @validated() snapshots them per call —
        # confirm before relying on it.
        window_warp_scales: list = [0.5, 2.0],
        data_id_to_name_map: dict = {},
        use_cosine_annealing_lr: bool = False,
        cosine_annealing_lr_args: dict = {},
        track_loss_per_series: bool = False,
        nonnegative_pred_samples: bool = False,
        use_kv_cache: bool = True,
        use_single_pass_sampling: bool = False,
    ):
        super().__init__()
        # Persist all constructor args so checkpoints can rebuild the module.
        self.save_hyperparameters()
        self.context_length = self.hparams.context_length
        self.prediction_length = self.hparams.prediction_length
        self.model = LagLlamaModel(**self.hparams.model_kwargs)
        self.loss = self.hparams.loss
        self.lr = self.hparams.lr
        self.weight_decay = self.hparams.weight_decay
        self.aug_prob = self.hparams.aug_prob
        self.freq_mask_rate = self.hparams.freq_mask_rate
        self.freq_mixing_rate = self.hparams.freq_mixing_rate
        self.jitter_prob = self.hparams.jitter_prob
        self.jitter_sigma = self.hparams.jitter_sigma
        self.scaling_prob = self.hparams.scaling_prob
        self.scaling_sigma = self.hparams.scaling_sigma
        self.rotation_prob = self.hparams.rotation_prob
        self.permutation_prob = self.hparams.permutation_prob
        self.permutation_max_segments = self.hparams.permutation_max_segments
        self.permutation_seg_mode = self.hparams.permutation_seg_mode
        self.magnitude_warp_prob = self.hparams.magnitude_warp_prob
        self.magnitude_warp_sigma = self.hparams.magnitude_warp_sigma
        self.magnitude_warp_knot = self.hparams.magnitude_warp_knot
        self.time_warp_prob = self.hparams.time_warp_prob
        self.time_warp_sigma = self.hparams.time_warp_sigma
        self.time_warp_knot = self.hparams.time_warp_knot
        self.window_slice_prob = self.hparams.window_slice_prob
        self.window_slice_reduce_ratio = self.hparams.window_slice_reduce_ratio
        self.window_warp_prob = self.hparams.window_warp_prob
        self.window_warp_window_ratio = self.hparams.window_warp_window_ratio
        self.window_warp_scales = self.hparams.window_warp_scales
        self.data_id_to_name_map = self.hparams.data_id_to_name_map
        self.use_cosine_annealing_lr = self.hparams.use_cosine_annealing_lr
        self.cosine_annealing_lr_args = self.hparams.cosine_annealing_lr_args
        self.track_loss_per_series = self.hparams.track_loss_per_series
        self.nonnegative_pred_samples = self.hparams.nonnegative_pred_samples

        self.time_feat = self.hparams.model_kwargs["time_feat"]
        # data_id based
        self.train_loss_dict = {}
        self.val_loss_dict = {}
        # item_id based - to be used only in single-dataset mode
        self.train_loss_dict_per_series = {}
        self.val_loss_dict_per_series = {}
        self.use_kv_cache = use_kv_cache
        self.use_single_pass_sampling = use_single_pass_sampling
        # Build the list of time-series augmentations; each entry below is
        # only instantiated when its probability is > 0.
        self.transforms = []
        aug_probs = dict(
            Jitter=dict(prob=self.jitter_prob, sigma=self.jitter_sigma),
            Scaling=dict(prob=self.scaling_prob, sigma=self.scaling_sigma),
            Rotation=dict(prob=self.rotation_prob),
            Permutation=dict(
                prob=self.permutation_prob,
                max_segments=self.permutation_max_segments,
                seg_mode=self.permutation_seg_mode,
            ),
            MagnitudeWarp=dict(
                prob=self.magnitude_warp_prob,
                sigma=self.magnitude_warp_sigma,
                knot=self.magnitude_warp_knot,
            ),
            TimeWarp=dict(
                prob=self.time_warp_prob,
                sigma=self.time_warp_sigma,
                knot=self.time_warp_knot,
            ),
            WindowSlice=dict(
                prob=self.window_slice_prob, reduce_ratio=self.window_slice_reduce_ratio
            ),
            WindowWarp=dict(
                prob=self.window_warp_prob,
                window_ratio=self.window_warp_window_ratio,
                warp_slices=self.window_warp_scales,
            ),
        )
        for aug, params in aug_probs.items():
            if params["prob"] > 0:
                if aug == "Jitter":
                    self.transforms.append(Jitter(params["prob"], params["sigma"]))
                elif aug == "Scaling":
                    self.transforms.append(Scaling(params["prob"], params["sigma"]))
                elif aug == "Rotation":
                    self.transforms.append(Rotation(params["prob"]))
                elif aug == "Permutation":
                    self.transforms.append(
                        Permutation(
                            params["prob"], params["max_segments"], params["seg_mode"]
                        )
                    )
                elif aug == "MagnitudeWarp":
                    self.transforms.append(
                        MagnitudeWarp(params["prob"], params["sigma"], params["knot"])
                    )
                elif aug == "TimeWarp":
                    self.transforms.append(
                        TimeWarp(params["prob"], params["sigma"], params["knot"])
                    )
                elif aug == "WindowSlice":
                    self.transforms.append(
                        WindowSlice(params["prob"], params["reduce_ratio"])
                    )
                elif aug == "WindowWarp":
                    self.transforms.append(
                        WindowWarp(
                            params["prob"],
                            params["window_ratio"],
                            params["warp_slices"],
                        )
                    )

        self.augmentations = ApplyAugmentations(self.transforms)

    # greedy prediction
    def forward(self, *args, **kwargs):
        """Autoregressively generate ``prediction_length`` steps of samples.

        Returns a tensor of shape
        (bsz, num_parallel_samples, prediction_length) + event_shape.
        Two decoding modes: single-pass (one forward per step, the
        distribution's mean extends the context) or the default parallel
        sampling (context duplicated num_parallel_samples times, each
        trajectory extended with its own sample).
        """
        past_target = kwargs[
            "past_target"
        ]  # (bsz, model.context_length+max(model.lags_seq))
        past_observed_values = kwargs[
            "past_observed_values"
        ]  # (bsz, model.context_length+max(model.lags_seq))
        if self.time_feat:
            past_time_feat = kwargs["past_time_feat"]
            future_time_feat = kwargs["future_time_feat"]

        use_single_pass_sampling = self.use_single_pass_sampling

        future_samples = []

        if use_single_pass_sampling:
            # Single-pass sampling mode: Single forward pass per step, save distribution parameters, sample `num_parallel_samples` times, add mean to context.
            for t in range(self.prediction_length):
                params, loc, scale = self.model(
                    *args,
                    past_time_feat=past_time_feat if self.time_feat else None,
                    future_time_feat=future_time_feat[..., : t + 1, :] if self.time_feat else None,
                    past_target=past_target,
                    past_observed_values=past_observed_values,
                    use_kv_cache=self.use_kv_cache,
                )

                sliced_params = [
                    p[:, -1:] for p in params
                ]  # Take the last timestep predicted. Each tensor is of shape (#bsz, 1)
                # Singular distribution is used for getting the greedy prediction (mean)
                distr = self.model.distr_output.distribution(sliced_params, loc, scale)
                greedy_prediction = distr.mean  # (#bsz, 1)

                repeated_sliced_params = [
                    p[:, -1:].repeat_interleave(
                        self.model.num_parallel_samples, 0
                    ) for p in params
                ]  # Take the last timestep predicted and repeat for number of samples. Each tensor is of shape (#bsz*#parallel_samples, 1)
                repeated_loc = loc.repeat_interleave(
                    self.model.num_parallel_samples, 0
                )
                repeated_scale = scale.repeat_interleave(
                    self.model.num_parallel_samples, 0
                )
                # Repeated distribution is used for getting the parallel samples
                # (distr.sample([self.model.num_parallel_samples]) seems to give terrible results)
                repeated_distr = self.model.distr_output.distribution(repeated_sliced_params, repeated_loc, repeated_scale)
                sample = repeated_distr.sample()  # (#bsz*#parallel_samples, 1)
                if self.nonnegative_pred_samples:
                    sample = F.relu(sample)
                future_samples.append(sample)

                past_target = torch.cat((past_target, greedy_prediction), dim=1)
                past_observed_values = torch.cat(
                    (past_observed_values, torch.ones_like(greedy_prediction)), dim=1
                )
        else:
            # Original probabilistic forecasting: Duplicate input, `num_parallel_samples` forward passes per step, sample each distribution once, add samples to context.
            repeated_past_target = past_target.repeat_interleave(self.model.num_parallel_samples, 0)
            repeated_past_observed_values = past_observed_values.repeat_interleave(self.model.num_parallel_samples, 0)
            if self.time_feat:
                repeated_past_time_feat = past_time_feat.repeat_interleave(self.model.num_parallel_samples, 0)
                repeated_future_time_feat = future_time_feat.repeat_interleave(self.model.num_parallel_samples, 0)

            for t in range(self.prediction_length):
                if self.time_feat:
                    params, loc, scale = self.model(
                        *args,
                        past_time_feat=repeated_past_time_feat,
                        future_time_feat=repeated_future_time_feat[..., : t + 1, :],
                        past_target=repeated_past_target,
                        past_observed_values=repeated_past_observed_values,
                        use_kv_cache=self.use_kv_cache,
                    )
                else:
                    params, loc, scale = self.model(
                        *args,
                        past_time_feat=None,
                        future_time_feat=None,
                        past_target=repeated_past_target,
                        past_observed_values=repeated_past_observed_values,
                        use_kv_cache=self.use_kv_cache,
                    )

                sliced_params = [p[:, -1:] for p in params]
                distr = self.model.distr_output.distribution(sliced_params, loc, scale)
                sample = distr.sample()
                if self.nonnegative_pred_samples:
                    sample = F.relu(sample)
                future_samples.append(sample)

                repeated_past_target = torch.cat((repeated_past_target, sample), dim=1)
                repeated_past_observed_values = torch.cat(
                    (repeated_past_observed_values, torch.ones_like(sample)), dim=1
                )

        # Invalidate the KV cache between prediction calls.
        self.model.reset_cache()

        concat_future_samples = torch.cat(future_samples, dim=-1)
        return concat_future_samples.reshape(
            (-1, self.model.num_parallel_samples, self.prediction_length)
            + self.model.distr_output.event_shape,
        )


    # train
    def _compute_loss(self, batch, do_not_average=False, return_observed_values=False):
        """Compute the (masked) distribution loss over context + horizon.

        The loss is evaluated on all predictable positions: the context
        except its first value, plus the full prediction window. Missing
        values are masked out via the observed-values indicators; with
        averaging enabled the sum is normalized by the observed count
        (clamped to >= 1 to avoid division by zero).
        """
        past_target = batch[
            "past_target"
        ]  # (bsz, model.context_length+max(model.lags_seq))
        past_observed_values = batch[
            "past_observed_values"
        ]  # (bsz, model.context_length+max(model.lags_seq)) with 0s or 1s indicating available (1s) or missing (0s)
        future_target = batch["future_target"]  # (bsz, model.prediction_length)
        future_observed_values = batch[
            "future_observed_values"
        ]  # (bsz, model.prediction_length) with 0s or 1s indicating available (1s) or missing (0s)
        if self.time_feat:
            past_time_feat = batch["past_time_feat"]
            future_time_feat = batch["future_time_feat"]
        else:
            past_time_feat = None
            future_time_feat = None

        extra_dims = len(future_target.shape) - len(past_target.shape)  # usually 0
        extra_shape = future_target.shape[:extra_dims]  # shape remains the same

        repeats = prod(extra_shape)  # usually 1
        past_target = repeat_along_dim(
            past_target, 0, repeats
        )  # (bsz, model.context_length+max(model.lags_seq))
        past_observed_values = repeat_along_dim(
            past_observed_values, 0, repeats
        )  # (bsz, model.context_length+max(model.lags_seq))

        future_target_reshaped = future_target.reshape(
            -1,
            *future_target.shape[extra_dims + 1 :],
        )  # (bsz, model.prediction_length)
        future_observed_reshaped = future_observed_values.reshape(
            -1,
            *future_observed_values.shape[extra_dims + 1 :],
        )  # (bsz, model.prediction_length)

        distr_args, loc, scale = self.model(
            past_target=past_target,
            past_observed_values=past_observed_values,
            past_time_feat=past_time_feat,
            future_time_feat=future_time_feat,
            future_target=future_target_reshaped,
        )  # distr_args is a tuple with two tensors of shape (bsz, context_length+pred_len-1)
        context_target = take_last(
            past_target, dim=-1, num=self.context_length - 1
        )  # (bsz, context_length-1) # Basically removes the first value since it cannot be predicted
        target = torch.cat(
            (context_target, future_target_reshaped),
            dim=1,
        )  # (bsz, context_length-1+pred_len) # values that can be predicted
        context_observed = take_last(
            past_observed_values, dim=-1, num=self.context_length - 1
        )  # same as context_target, but for observed_values tensor
        observed_values = torch.cat(
            (context_observed, future_observed_reshaped), dim=1
        )  # same as target but for observed_values tensor

        # NOTE(review): `type(...) ==` rather than isinstance — intentional?
        # An IQN subclass would silently fall into the generic branch.
        if type(self.model.distr_output) == ImplicitQuantileNetworkOutput:
            if not do_not_average:
                loss = (
                    self.model.distr_output.loss(target, distr_args, loc, scale)
                    * observed_values
                ).sum() / observed_values.sum().clamp_min(1.0)
            else:
                loss = (
                    self.model.distr_output.loss(target, distr_args, loc, scale)
                    * observed_values
                )
        else:
            distr = self.model.distr_output.distribution(
                distr_args, loc=loc, scale=scale
            )  # an object representing a distribution with the specified parameters. We need this to compute the NLL loss.
            if not do_not_average:
                loss = (
                    self.loss(distr, target) * observed_values
                ).sum() / observed_values.sum().clamp_min(1.0)
            else:
                loss = self.loss(distr, target) * observed_values

        if not return_observed_values:
            return loss
        else:
            return loss, observed_values

    def training_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute training step.
        """
        # With probability `aug_prob`, apply frequency-domain and/or
        # time-domain augmentations to the batch in place.
        if random.random() < self.aug_prob:
            # Freq mix and Freq mask have separate functions
            if self.freq_mask_rate > 0:
                batch["past_target"], batch["future_target"] = freq_mask(
                    batch["past_target"],
                    batch["future_target"],
                    rate=self.freq_mask_rate,
                )
            # NOTE(review): truthiness check here vs `> 0` above — a negative
            # freq_mixing_rate would still trigger; confirm intended.
            if self.freq_mixing_rate:
                batch["past_target"], batch["future_target"] = freq_mix(
                    batch["past_target"],
                    batch["future_target"],
                    rate=self.freq_mixing_rate,
                )
            # Other augmentation
            if len(self.transforms):
                batch["past_target"], batch["future_target"] = self.augmentations(
                    batch["past_target"], batch["future_target"]
                )

        train_loss_per_sample, observed_values = self._compute_loss(
            batch, do_not_average=True, return_observed_values=True
        )

        train_loss_avg = train_loss_per_sample.sum() / observed_values.sum().clamp_min(
            1.0
        )
        self.log(
            "train_loss", train_loss_avg, on_epoch=True, on_step=False, prog_bar=False
        )
        return train_loss_avg

    def on_train_epoch_end(self):
        """Log per-dataset (and optionally per-series) average train losses,
        then reset the accumulators.

        NOTE(review): nothing in this file populates `train_loss_dict`;
        presumably a callback or subclass fills it — verify.
        """
        # Log all losses
        for key, value in self.train_loss_dict.items():
            loss_avg = np.mean(value)
            self.log(
                f"train_loss_avg_per_train_dataset/{self.data_id_to_name_map[key]}",
                loss_avg,
                on_epoch=True,
                on_step=False,
                prog_bar=False,
            )

        if self.track_loss_per_series:
            # Log all losses
            for key, value in self.train_loss_dict_per_series.items():
                loss_avg = np.mean(value)
                self.log(
                    f"train_loss_avg_per_train_series/{key}",
                    loss_avg,
                    on_epoch=True,
                    on_step=False,
                    prog_bar=False,
                )

        # Reset loss_dict
        self.train_loss_dict = {}
        self.train_loss_dict_per_series = {}

    def validation_step(self, batch, batch_idx: int):  # type: ignore
        """
        Execute validation step.
        """
        val_loss_per_sample, observed_values = self._compute_loss(
            batch, do_not_average=True, return_observed_values=True
        )

        val_loss_avg = val_loss_per_sample.sum() / observed_values.sum().clamp_min(1.0)
        self.log("val_loss", val_loss_avg, on_epoch=True, on_step=False, prog_bar=False)
        return val_loss_avg

    def on_validation_epoch_end(self):
        """Log per-dataset (train vs. test split by sign of the data id) and
        optionally per-series validation losses, then reset accumulators."""
        # Log all losses
        for key, value in self.val_loss_dict.items():
            loss_avg = np.mean(value)
            # Negative keys denote held-out (test) datasets.
            if key >= 0:
                self.log(
                    f"val_loss_avg_per_train_dataset/{self.data_id_to_name_map[key]}",
                    loss_avg,
                    on_epoch=True,
                    on_step=False,
                    prog_bar=False,
                )
            else:
                self.log(
                    f"val_loss_avg_per_test_dataset/{self.data_id_to_name_map[key]}",
                    loss_avg,
                    on_epoch=True,
                    on_step=False,
                    prog_bar=False,
                )

        if self.track_loss_per_series:
            # Log all losses
            for key, value in self.val_loss_dict_per_series.items():
                loss_avg = np.mean(value)
                self.log(
                    f"val_loss_avg_per_train_series/{key}",
                    loss_avg,
                    on_epoch=True,
                    on_step=False,
                    prog_bar=False,
                )

        # Reset loss_dict
        self.val_loss_dict = {}
        self.val_loss_dict_per_series = {}

    def configure_optimizers(self):
        """
        Returns the optimizer to use.
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        if self.use_cosine_annealing_lr:
            # NOTE(review): `verbose=` is deprecated in newer torch releases
            # and will warn/fail there — confirm the pinned torch version.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, **self.cosine_annealing_lr_args, verbose=True
            )
            return {"optimizer": optimizer, "lr_scheduler": scheduler}
        else:
            return optimizer
539 |
--------------------------------------------------------------------------------
/lag_llama/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
--------------------------------------------------------------------------------
/lag_llama/model/module.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | from dataclasses import dataclass
17 | from typing import List, Optional
18 |
19 | import torch
20 | from torch import nn
21 | from torch.nn import functional as F
22 |
23 | from gluonts.torch.distributions import DistributionOutput
24 | from gluonts.torch.scaler import MeanScaler, NOPScaler, StdScaler
25 | from gluonts.torch.util import lagged_sequence_values, unsqueeze_expand
26 |
27 | from gluon_utils.scalers.robust_scaler import RobustScaler
28 |
29 |
@dataclass
class LTSMConfig:
    """Hyperparameter container for the Lag-Llama transformer backbone."""

    feature_size: int = 3 + 6  # target + loc + scale + time features
    block_size: int = 2048  # maximum sequence length (RoPE cache size)
    n_layer: int = 32  # number of transformer blocks
    n_head: int = 32  # attention heads per block
    n_embd_per_head: int = 128  # embedding width is n_head * n_embd_per_head
    rope_scaling: Optional[dict] = None  # e.g. {"type": "linear", "factor": 2.0}
    dropout: float = 0.0  # dropout probability
39 |
40 |
class Block(nn.Module):
    """A single pre-norm transformer block: RMSNorm -> causal self-attention
    and RMSNorm -> MLP, each wrapped in a residual connection."""

    def __init__(self, config: LTSMConfig) -> None:
        super().__init__()
        embed_dim = config.n_embd_per_head * config.n_head
        self.rms_1 = RMSNorm(embed_dim)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(embed_dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, use_kv_cache: bool) -> torch.Tensor:
        # Residual around attention, then residual around the feed-forward.
        hidden = x + self.attn(self.rms_1(x), use_kv_cache)
        return hidden + self.mlp(self.rms_2(hidden))
53 |
54 |
class LlamaRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding (RoPE) with a lazily-grown cos/sin cache.

    The cache is built eagerly in ``__init__`` (so ``torch.jit.trace`` works)
    for ``max_position_embeddings`` positions and regrown on demand whenever
    a longer sequence is requested in ``forward``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequency for each even dimension index (RoPE pairs dims).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Precompute cos/sin tables for positions ``[0, seq_len)``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
        )

    def forward(self, device, dtype, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``[1, 1, seq_len, dim]``.

        Fix: the declared default ``seq_len=None`` previously crashed on the
        ``seq_len > self.max_seq_len_cached`` comparison (TypeError); a
        ``None`` now falls back to the currently cached length.
        """
        if seq_len is None:
            seq_len = self.max_seq_len_cached
        # Grow the cache when a longer sequence than ever seen is requested.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype),
        )
99 |
100 |
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the initial cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        # Linear scaling: stretch positions by 1/scaling_factor.
        positions = positions / self.scaling_factor

        # Outer product gives one row of rotation angles per scaled position.
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the angles so both halves of each dim pair rotate together.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer(
            "cos_cached", table.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", table.sin()[None, None, :, :].to(dtype), persistent=False
        )
131 |
132 |
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the initial cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Dynamic NTK: enlarge the frequency base as the sequence grows
            # past the trained window, and refresh inv_freq accordingly.
            scaled_base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                scaled_base
                ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        positions = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        # One row of rotation angles per position, duplicated for both halves
        # of each dimension pair.
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer(
            "cos_cached", table.cos()[None, None, :, :].to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", table.sin()[None, None, :, :].to(dtype), persistent=False
        )
173 |
174 |
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((-upper, lower), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate query/key tensors by the position-dependent cos/sin tables."""
    # The leading two singleton dims of the caches are dropped, then the rows
    # for the requested positions are gathered back with a head axis.
    cos_rows = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin_rows = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    rotated_q = (q * cos_rows) + (rotate_half(q) * sin_rows)
    rotated_k = (k * cos_rows) + (rotate_half(k) * sin_rows)
    return rotated_q, rotated_k
191 |
192 |
193 | class CausalSelfAttention(nn.Module):
    def __init__(self, config: LTSMConfig) -> None:
        """Set up projections, RoPE variant, and an (initially empty) KV cache."""
        super().__init__()
        # query projections for all heads, but in a batch
        self.q_proj = nn.Linear(
            config.n_embd_per_head * config.n_head,
            config.n_embd_per_head * config.n_head,
            bias=False,
        )
        # key, value projections (2x width: split into k and v in forward)
        self.kv_proj = nn.Linear(
            config.n_embd_per_head * config.n_head,
            2 * config.n_embd_per_head * config.n_head,
            bias=False,
        )
        # output projection
        self.c_proj = nn.Linear(
            config.n_embd_per_head * config.n_head,
            config.n_embd_per_head * config.n_head,
            bias=False,
        )

        self.n_head = config.n_head
        self.n_embd_per_head = config.n_embd_per_head
        self.block_size = config.block_size
        self.dropout = config.dropout

        # Validate rope_scaling before building the rotary embedding from it.
        self.rope_scaling = config.rope_scaling
        self._rope_scaling_validation()

        self._init_rope()
        # Populated lazily in forward when use_kv_cache is enabled.
        self.kv_cache = None
225 |
226 | def _init_rope(self):
227 | if self.rope_scaling is None:
228 | self.rotary_emb = LlamaRotaryEmbedding(
229 | self.n_embd_per_head, max_position_embeddings=self.block_size
230 | )
231 | else:
232 | scaling_type = self.rope_scaling["type"]
233 | scaling_factor = self.rope_scaling["factor"]
234 | if scaling_type == "nope":
235 | self.rotary_emb = None
236 | elif scaling_type == "linear":
237 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
238 | self.n_embd_per_head,
239 | max_position_embeddings=self.block_size,
240 | scaling_factor=scaling_factor,
241 | )
242 | elif scaling_type == "dynamic":
243 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
244 | self.n_embd_per_head,
245 | max_position_embeddings=self.block_size,
246 | scaling_factor=scaling_factor,
247 | )
248 | else:
249 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
250 |
251 | def _rope_scaling_validation(self):
252 | """
253 | Validate the `rope_scaling` configuration.
254 | """
255 | if self.rope_scaling is None:
256 | return
257 |
258 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
259 | raise ValueError(
260 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
261 | f"got {self.rope_scaling}"
262 | )
263 | rope_scaling_type = self.rope_scaling.get("type", None)
264 | rope_scaling_factor = self.rope_scaling.get("factor", None)
265 | if rope_scaling_type is None or rope_scaling_type not in [
266 | "linear",
267 | "dynamic",
268 | "nope",
269 | ]:
270 | raise ValueError(
271 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
272 | )
273 | if rope_scaling_type in ["linear", "dynamic"]:
274 | if (
275 | rope_scaling_factor is None
276 | or not isinstance(rope_scaling_factor, float)
277 | or rope_scaling_factor < 1.0
278 | ):
279 | raise ValueError(
280 | f"`rope_scaling`'s factor field must be an float >= 1, got {rope_scaling_factor}"
281 | )
282 |
283 | def forward(self, x: torch.Tensor, use_kv_cache: bool) -> torch.Tensor:
284 | # batch size, sequence length, embedding dimensionality (n_embd)
285 | # sequence length will be 1 when using kv_cache and kv_cache has been initialised
286 | (B, T, C) = x.size()
287 |
288 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
289 | q = self.q_proj(x)
290 | k, v = self.kv_proj(x).split(self.n_embd_per_head * self.n_head, dim=2)
291 |
292 | cache_initialized = self.kv_cache is not None
293 | if use_kv_cache:
294 | # Optimized for single next prediction
295 | if cache_initialized:
296 | # Update cache
297 | k = torch.cat([self.kv_cache[0], k], dim=1)[:, 1:]
298 | v = torch.cat([self.kv_cache[1], v], dim=1)[:, 1:]
299 | self.kv_cache = k, v
300 | else:
301 | # Build cache
302 | self.kv_cache = k, v
303 |
304 | k = k.view(B, -1, self.n_head, self.n_embd_per_head).transpose(
305 | 1, 2
306 | ) # (B, nh, T, hs)
307 | q = q.view(B, -1, self.n_head, self.n_embd_per_head).transpose(
308 | 1, 2
309 | ) # (B, nh, T, hs)
310 | v = v.view(B, -1, self.n_head, self.n_embd_per_head).transpose(
311 | 1, 2
312 | ) # (B, nh, T, hs)
313 |
314 | # This the true sequence length after concatenation with kv_cache,
315 | # will be the same as `T` when kv_cache is not in use
316 | true_seq_len = k.size(2)
317 | if self.rotary_emb is not None:
318 | if use_kv_cache and cache_initialized:
319 | # When kv_cache is in use and we're working with only the last token (T = 1 instead of full sequence length `true_seq_len``)
320 | # Use the full sequence length for positional embeddings (true_seq_len)
321 | # q is the query vector for the last token, so it's position is the last index (-1)
322 | cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=true_seq_len)
323 | q, _ = apply_rotary_pos_emb(q, k, cos, sin, position_ids=[-1])
324 |
325 | # k is the key matrix after concatenation with cache, so no position_ids
326 | cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=true_seq_len)
327 | _, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None)
328 | else:
329 | cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=T)
330 | q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None)
331 |
332 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
333 | # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
334 | # att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
335 | # att = F.softmax(att, dim=-1)
336 | # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
337 |
338 | # efficient attention using Flash Attention CUDA kernels
339 | # When using kv cache at inference, is_causal=False since decoder is causal, at each generation step we want
340 | # to avoid recalculating the same previous token attention
341 |
342 | if use_kv_cache and cache_initialized:
343 | y = F.scaled_dot_product_attention(
344 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
345 | )
346 | else:
347 | y = F.scaled_dot_product_attention(
348 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
349 | )
350 |
351 | # re-assemble all head outputs side by side
352 | y = y.transpose(1, 2).contiguous().view(B, T, C)
353 |
354 | # output projection
355 | y = self.c_proj(y)
356 |
357 | return y
358 |
359 |
def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of ``k`` that is greater than or equal to ``n``."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
364 |
365 |
class MLP(nn.Module):
    """SwiGLU feed-forward block used inside each transformer layer.

    The hidden width follows the LLaMA recipe: two thirds of ``4 * d_model``,
    rounded up to the next multiple of 256.
    """

    def __init__(self, config: LTSMConfig) -> None:
        super().__init__()
        d_model = config.n_embd_per_head * config.n_head
        n_hidden = find_multiple(int(2 * (4 * d_model) / 3), 256)

        self.c_fc1 = nn.Linear(d_model, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(d_model, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu-gated branch multiplied by a linear branch,
        # then projected back to the model width.
        gate = F.silu(self.c_fc1(x))
        return self.c_proj(gate * self.c_fc2(x))
387 |
388 |
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        # Learnable per-feature gain; RMSNorm has no bias term.
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the original RMSNorm paper implementation is not equivalent
        # norm_x = x.norm(2, dim=self.dim, keepdim=True)
        # rms_x = norm_x * d_x ** (-1. / 2)
        # x_normed = x / (rms_x + self.eps)
        # Mean of squares is computed in float32 for numerical stability;
        # the input itself is rescaled in its original dtype.
        mean_sq = torch.mean(x.to(torch.float32) ** 2, dim=self.dim, keepdim=True)
        normalized = x * torch.rsqrt(mean_sq + self.eps)
        return (self.scale * normalized).type_as(x)
411 |
412 |
class LagLlamaModel(nn.Module):
    """Lag-Llama: a decoder-only transformer over lagged time-series features.

    Each position's input vector is built from lagged target values, static
    features derived from the scaler statistics (log1p(|loc|) and log(scale)),
    and, when ``time_feat`` is enabled, 6 additional time features. The
    backbone is a stack of LLaMA-style ``Block`` modules; the head projects to
    the parameters of ``distr_output``.
    """

    def __init__(
        self,
        context_length: int,
        max_context_length: int,
        scaling: str,
        input_size: int,
        n_layer: int,
        n_embd_per_head: int,
        n_head: int,
        lags_seq: List[int],
        distr_output: DistributionOutput,
        rope_scaling=None,
        num_parallel_samples: int = 100,
        time_feat: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.context_length = context_length
        self.lags_seq = lags_seq
        # Per-position feature width: one column per lag, plus 2 static
        # features (loc/scale) per input dim, plus 6 time features if enabled.
        if time_feat:
            feature_size = input_size * (len(self.lags_seq)) + 2 * input_size + 6
        else:
            feature_size = input_size * (len(self.lags_seq)) + 2 * input_size

        config = LTSMConfig(
            n_layer=n_layer,
            n_embd_per_head=n_embd_per_head,
            n_head=n_head,
            block_size=max_context_length,
            feature_size=feature_size,
            rope_scaling=rope_scaling,
            dropout=dropout,
        )
        # Store the config on the module: _init_weights reads
        # self.config.n_layer, which previously raised AttributeError
        # because config was only a local variable.
        self.config = config
        self.num_parallel_samples = num_parallel_samples

        if scaling == "mean":
            self.scaler = MeanScaler(keepdim=True, dim=1)
        elif scaling == "std":
            self.scaler = StdScaler(keepdim=True, dim=1)
        elif scaling == "robust":
            self.scaler = RobustScaler(keepdim=True, dim=1)
        else:
            self.scaler = NOPScaler(keepdim=True, dim=1)

        self.distr_output = distr_output
        self.param_proj = self.distr_output.get_args_proj(
            config.n_embd_per_head * config.n_head
        )

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(
                    config.feature_size, config.n_embd_per_head * config.n_head
                ),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd_per_head * config.n_head),
            )
        )
        self.y_cache = False  # used at time of inference when kv cached is used

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize linear/embedding weights with a depth-scaled normal."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
            )
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
            )

    def prepare_input(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_time_feat: Optional[torch.Tensor] = None,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """Build the transformer input features from raw (un-scaled) targets.

        Returns:
            A tuple ``(features, loc, scale)`` where ``features`` has shape
            (bsz, context_length + (pred_len - 1), feature_size) and
            ``loc``/``scale`` have shape (bsz, 1).
        """
        scaled_past_target, loc, scale = self.scaler(
            past_target, past_observed_values
        )  # Data is standardized (past_observed_values is passed as "weights" parameter) # (bsz, context_length+max(self.lags_seq)

        # In the below code, instead of max(self.lags_seq), it was previously -self.context_length
        if future_target is not None:
            input = torch.cat(
                (
                    scaled_past_target[..., max(self.lags_seq) :],  # Just the context
                    (future_target[..., :-1] - loc)
                    / scale,  # Not sure about the -1 here. Maybe so since the last value isn't used in the model for prediction of any new values. also if the prediction length is 1, this doesn't really affect anything
                ),
                dim=-1,
            )  # Shape is (bsz, context_length+(pred_len-1))
        else:
            input = scaled_past_target[..., max(self.lags_seq) :]
        # NOTE: the guard previously required BOTH past and future time
        # features, which left `time_feat` undefined (NameError) in the
        # return below when only past_time_feat was provided. The inner
        # conditional already handles future_time_feat being None.
        if past_time_feat is not None:
            time_feat = (
                torch.cat(
                    (
                        past_time_feat[..., max(self.lags_seq) :, :],
                        future_time_feat[..., :-1, :],
                    ),
                    dim=1,
                )
                if future_time_feat is not None
                else past_time_feat[..., max(self.lags_seq) :, :]
            )

        prior_input = (
            past_target[..., : max(self.lags_seq)] - loc
        ) / scale  # This the history used to construct lags. # bsz, max(self.lags_seq)

        lags = lagged_sequence_values(
            self.lags_seq, prior_input, input, dim=-1
        )  # Lags are added as an extra dim. Shape is (bsz, context_length+(pred_len-1), len(self.lags_seq))

        static_feat = torch.cat(
            (loc.abs().log1p(), scale.log()), dim=-1
        )  # (bsz, 2) (loc and scale are concatenated)
        expanded_static_feat = unsqueeze_expand(
            static_feat, dim=-2, size=lags.shape[-2]
        )  # (bsz, context_length+(pred_len-1), 2)
        # expanded_static_feat: (bsz, context_length+(pred_len-1), len(self.lags_seq) + 2); (bsz, 1); (bsz, 1)

        if past_time_feat is not None:
            return (
                torch.cat((lags, expanded_static_feat, time_feat), dim=-1),
                loc,
                scale,
            )
        else:
            return torch.cat((lags, expanded_static_feat), dim=-1), loc, scale

    def forward(
        self,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        past_time_feat: Optional[torch.Tensor] = None,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
        use_kv_cache: bool = False,
    ) -> torch.Tensor:
        """Run the model and return distribution parameters plus loc/scale.

        Returns:
            ``(params, loc, scale)`` — the projected distribution parameters
            for every position, and the scaler statistics needed to map
            samples back to the original data scale.
        """
        transformer_input, loc, scale = self.prepare_input(
            past_target=past_target,
            past_observed_values=past_observed_values,
            future_target=future_target,
            past_time_feat=past_time_feat,
            future_time_feat=future_time_feat,
        )  # return: (bsz, context_length+(pred_len-1), len(self.lags_seq) + 2); (bsz, 1); (bsz, 1)
        # To use kv cache for inference and pass recent token to transformer
        if use_kv_cache and self.y_cache:
            # Only use the most recent one, rest is in cache
            transformer_input = transformer_input[:, -1:]

        # forward the LLaMA model itself
        x = self.transformer.wte(
            transformer_input
        )  # token embeddings of shape (b, t, n_embd_per_head*n_head) # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head)

        for block in self.transformer.h:
            x = block(x, use_kv_cache)
        x = self.transformer.ln_f(
            x
        )  # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head)
        if use_kv_cache:
            self.y_cache = True
        params = self.param_proj(
            x
        )  # (bsz, context_length+(pred_len-1)) ; (bsz, context_length+(pred_len-1))
        return params, loc, scale

    def reset_cache(self) -> None:
        """
        Resets all cached key-values in attention.
        Has to be called after prediction loop in predictor
        """
        self.y_cache = False
        for block in self.transformer.h:
            block.attn.kv_cache = None
593 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 |
6 | [project]
7 | name = "lag-llama"
8 | dynamic = ["version", "dependencies"]
9 | description = "Lag-Llama is the first open-source foundation model for time series forecasting!"
10 | readme = "README.md"
11 | license = {file = "LICENSE"}
12 | authors = [
13 | {name = "Arjun Ashok", email = "arjun.ashok@servicenow.com"},
14 | {name = "Kashif Rasul", email = "kashif.rasul@gmail.com"}
15 | ]
16 | keywords = ["llama", "time series", "forecasting", "machine learning", "open-source", "foundation model", "lag-llama"]
17 |
18 |
19 | [project.urls]
20 | "Homepage" = "https://github.com/time-series-foundation-models/lag-llama"
21 |
22 |
23 | [tool.setuptools.dynamic]
24 | dependencies = {file = "requirements.txt"}
25 | version = {file = "VERSION"}
26 |
27 | [tool.setuptools]
28 | package-data = {"data" = ["data/*"], "configs" = ["configs/*"]}
29 |
30 | [tool.setuptools.packages.find]
31 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gluonts[torch]<=0.14.4
2 | numpy>=1.23.5
3 | torch>=2.0.0
4 | wandb
5 | scipy
6 | pandas
7 | huggingface_hub[cli]
8 | matplotlib
9 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | warnings.simplefilter(action="ignore", category=FutureWarning)
18 | warnings.simplefilter(action="ignore", category=UserWarning)
19 |
20 | import argparse
21 | import gc
22 | import json
23 | import os
24 | from hashlib import sha1
25 |
26 | import lightning
27 | import torch
28 | import wandb
29 | from gluonts.evaluation import Evaluator, make_evaluation_predictions
30 | from gluonts.evaluation._base import aggregate_valid
31 | from gluonts.transform import ExpectedNumInstanceSampler
32 | from lightning.pytorch.callbacks import (
33 | EarlyStopping,
34 | ModelCheckpoint,
35 | StochasticWeightAveraging,
36 | LearningRateMonitor
37 | )
38 | from lightning.pytorch.loggers import WandbLogger
39 |
40 | from data.data_utils import (
41 | CombinedDataset,
42 | SingleInstanceSampler,
43 | create_test_dataset,
44 | create_train_and_val_datasets_with_dates,
45 | )
46 |
47 | from data.dataset_list import ALL_DATASETS
48 | from utils.utils import plot_forecasts, set_seed
49 |
50 |
51 | from lag_llama.gluon.estimator import LagLlamaEstimator
52 |
53 |
54 | def train(args):
55 | # Set seed
56 | set_seed(args.seed)
57 | lightning.seed_everything(args.seed)
58 |
59 | # # Print GPU stats
60 | # print_gpu_stats()
61 |
62 | # Create a directory to store the results in
63 | # This string is made independent of hyperparameters here, as more hyperparameters / arguments may be added later
64 | # The name should be created in the calling bash script
65 | # This way, when that same script is executed again, automatically the model training is resumed from a checkpoint if available
66 | experiment_name = args.experiment_name
67 | fulldir_experiments = os.path.join(args.results_dir, experiment_name, str(args.seed))
68 | if os.path.exists(fulldir_experiments): print(fulldir_experiments, "already exists.")
69 | os.makedirs(fulldir_experiments, exist_ok=True)
70 |
71 | # Create directory for checkpoints
72 | checkpoint_dir = os.path.join(fulldir_experiments, "checkpoints")
73 | os.makedirs(checkpoint_dir, exist_ok=True)
74 |
75 | # Code to retrieve the version with the highest #epoch stored and restore it incl directory and its checkpoint
76 | if args.ckpt_path:
77 | ckpt_path = args.ckpt_path
78 | elif args.get_ckpt_path_from_experiment_name:
79 | fulldir_experiments_for_ckpt_path = os.path.join(args.results_dir, args.get_ckpt_path_from_experiment_name, str(args.seed))
80 | full_experiment_name_original = args.get_ckpt_path_from_experiment_name + "-seed-" + str(args.seed)
81 | experiment_id_original = sha1(full_experiment_name_original.encode("utf-8")).hexdigest()[:8]
82 | checkpoint_dir_wandb = os.path.join(fulldir_experiments_for_ckpt_path, "lag-llama", experiment_id_original, "checkpoints")
83 | file = os.listdir(checkpoint_dir_wandb)[-1]
84 | if file: ckpt_path = os.path.join(checkpoint_dir_wandb, file)
85 | if not ckpt_path: raise Exception("ckpt_path not found from experiment name")
86 | # Delete the EarlyStoppingCallback and save it in the current checkpoint_dir
87 | new_ckpt_path = checkpoint_dir + "/pretrained_ckpt.ckpt"
88 | print("Moving", ckpt_path, "to", new_ckpt_path)
89 | ckpt_loaded = torch.load(ckpt_path)
90 | del ckpt_loaded['callbacks']["EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"]
91 | ckpt_loaded['callbacks']["ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"]["best_model_path"] = new_ckpt_path
92 | ckpt_loaded['callbacks']["ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"]["dirpath"] = checkpoint_dir
93 | del ckpt_loaded['callbacks']["ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1, 'train_time_interval': None}"]["last_model_path"]
94 | torch.save(ckpt_loaded, checkpoint_dir + "/pretrained_ckpt.ckpt")
95 | ckpt_path = checkpoint_dir + "/pretrained_ckpt.ckpt"
96 | else:
97 | ckpt_path = None
98 | if not args.evaluate_only:
99 | ckpt_path = checkpoint_dir + "/last.ckpt"
100 | if not os.path.isfile(ckpt_path): ckpt_path = None
101 | else:
102 | if args.evaluate_only:
103 | full_experiment_name_original = experiment_name + "-seed-" + str(args.seed)
104 | experiment_id_original = sha1(full_experiment_name_original.encode("utf-8")).hexdigest()[:8]
105 | checkpoint_dir_wandb = os.path.join(fulldir_experiments, "lag-llama", experiment_id_original, "checkpoints")
106 | file = os.listdir(checkpoint_dir_wandb)[-1]
107 | if file: ckpt_path = os.path.join(checkpoint_dir_wandb, file)
108 | elif args.evaluate_only:
109 | for file in os.listdir(checkpoint_dir):
110 | if "best" in file:
111 | ckpt_path = checkpoint_dir + "/" + file
112 | break
113 |
114 | if ckpt_path:
115 | print("Checkpoint", ckpt_path, "retrieved from experiment directory")
116 | else:
117 | print("No checkpoints found. Training from scratch.")
118 |
119 | # W&B logging
120 | # NOTE: Caution when using `full_experiment_name` after this
121 | if args.eval_prefix and (args.evaluate_only): experiment_name = args.eval_prefix + "_" + experiment_name
122 | full_experiment_name = experiment_name + "-seed-" + str(args.seed)
123 | experiment_id = sha1(full_experiment_name.encode("utf-8")).hexdigest()[:8]
124 | logger = WandbLogger(name=full_experiment_name, \
125 | save_dir=fulldir_experiments, group=experiment_name, \
126 | tags=args.wandb_tags, entity=args.wandb_entity, \
127 | project=args.wandb_project, allow_val_change=True, \
128 | config=vars(args), id=experiment_id, \
129 | mode=args.wandb_mode, settings=wandb.Settings(code_dir="."))
130 |
131 | # Callbacks
132 | swa_callbacks = StochasticWeightAveraging(
133 | swa_lrs=args.swa_lrs,
134 | swa_epoch_start=args.swa_epoch_start,
135 | annealing_epochs=args.annealing_epochs,
136 | annealing_strategy=args.annealing_strategy,
137 | )
138 | early_stop_callback = EarlyStopping(
139 | monitor="val_loss",
140 | min_delta=0.00,
141 | patience=int(args.early_stopping_patience),
142 | verbose=True,
143 | mode="min",
144 | )
145 | model_checkpointing = ModelCheckpoint(
146 | dirpath=checkpoint_dir,
147 | save_last=True,
148 | save_top_k=1,
149 | filename="best-{epoch}-{val_loss:.2f}",
150 | )
151 | lr_monitor = LearningRateMonitor(logging_interval='step')
152 | callbacks = [early_stop_callback,
153 | lr_monitor,
154 | model_checkpointing
155 | ]
156 | if args.swa:
157 | print("Using SWA")
158 | callbacks.append(swa_callbacks)
159 |
160 | # Create train and test datasets
161 | if not args.single_dataset:
162 | train_dataset_names = args.all_datasets
163 | for test_dataset in args.test_datasets:
164 | train_dataset_names.remove(test_dataset)
165 | print("Training datasets:", train_dataset_names)
166 | print("Test datasets:", args.test_datasets)
167 | data_id_to_name_map = {}
168 | name_to_data_id_map = {}
169 | for data_id, name in enumerate(train_dataset_names):
170 | data_id_to_name_map[data_id] = name
171 | name_to_data_id_map[name] = data_id
172 | test_data_id = -1
173 | for name in args.test_datasets:
174 | data_id_to_name_map[test_data_id] = name
175 | name_to_data_id_map[name] = test_data_id
176 | test_data_id -= 1
177 | else:
178 | print("Training and test on", args.single_dataset)
179 | data_id_to_name_map = {}
180 | name_to_data_id_map = {}
181 | data_id_to_name_map[0] = args.single_dataset
182 | name_to_data_id_map[args.single_dataset] = 0
183 |
184 | # Get prediction length and set it if we are in the single dataset
185 | if args.single_dataset and args.use_dataset_prediction_length:
186 | _, prediction_length, _ = create_test_dataset(
187 | args.single_dataset, args.dataset_path, 0
188 | )
189 | args.prediction_length = prediction_length
190 |
191 | # Cosine Annealing LR
192 | if args.use_cosine_annealing_lr:
193 | cosine_annealing_lr_args = {"T_max": args.cosine_annealing_lr_t_max, \
194 | "eta_min": args.cosine_annealing_lr_eta_min}
195 | else:
196 | cosine_annealing_lr_args = {}
197 |
198 | # Create the estimator
199 | estimator = LagLlamaEstimator(
200 | prediction_length=args.prediction_length,
201 | context_length=args.context_length,
202 | input_size=1,
203 | batch_size=args.batch_size,
204 | n_layer=args.n_layer,
205 | n_embd_per_head=args.n_embd_per_head,
206 | n_head=args.n_head,
207 | max_context_length=2048,
208 | rope_scaling=None,
209 | scaling=args.data_normalization,
210 | lr=args.lr,
211 | weight_decay=args.weight_decay,
212 | distr_output=args.distr_output,
213 | # augmentations
214 | aug_prob=args.aug_prob,
215 | freq_mask_rate=args.freq_mask_rate,
216 | freq_mixing_rate=args.freq_mixing_rate,
217 | jitter_prob=args.jitter_prob,
218 | jitter_sigma=args.jitter_sigma,
219 | scaling_prob=args.scaling_prob,
220 | scaling_sigma=args.scaling_sigma,
221 | rotation_prob=args.rotation_prob,
222 | permutation_prob=args.permutation_prob,
223 | permutation_max_segments=args.permutation_max_segments,
224 | permutation_seg_mode=args.permutation_seg_mode,
225 | magnitude_warp_prob=args.magnitude_warp_prob,
226 | magnitude_warp_sigma=args.magnitude_warp_sigma,
227 | magnitude_warp_knot=args.magnitude_warp_knot,
228 | time_warp_prob=args.time_warp_prob,
229 | time_warp_sigma=args.time_warp_sigma,
230 | time_warp_knot=args.time_warp_knot,
231 | window_slice_prob=args.window_slice_prob,
232 | window_slice_reduce_ratio=args.window_slice_reduce_ratio,
233 | window_warp_prob=args.window_warp_prob,
234 | window_warp_window_ratio=args.window_warp_window_ratio,
235 | window_warp_scales=args.window_warp_scales,
236 | # others
237 | num_batches_per_epoch=args.num_batches_per_epoch,
238 | num_parallel_samples=args.num_parallel_samples,
239 | time_feat=args.time_feat,
240 | dropout=args.dropout,
241 | lags_seq=args.lags_seq,
242 | data_id_to_name_map=data_id_to_name_map,
243 | use_cosine_annealing_lr=args.use_cosine_annealing_lr,
244 | cosine_annealing_lr_args=cosine_annealing_lr_args,
245 | track_loss_per_series=args.single_dataset != None,
246 | ckpt_path=ckpt_path,
247 | trainer_kwargs=dict(
248 | max_epochs=args.max_epochs,
249 | accelerator="gpu",
250 | devices=[args.gpu],
251 | limit_val_batches=args.limit_val_batches,
252 | logger=logger,
253 | callbacks=callbacks,
254 | default_root_dir=fulldir_experiments,
255 | ),
256 | )
257 |
258 | # Save the args as config to the directory
259 | config_filepath = fulldir_experiments + "/args.json"
260 | with open(config_filepath, "w") as config_savefile:
261 | json.dump(vars(args), config_savefile, indent=4)
262 |
263 | # Save the number of parameters to the directory for easy retrieval
264 | num_parameters = sum(
265 | p.numel() for p in estimator.create_lightning_module().parameters()
266 | )
267 | num_parameters_path = fulldir_experiments + "/num_parameters.txt"
268 | with open(num_parameters_path, "w") as num_parameters_savefile:
269 | num_parameters_savefile.write(str(num_parameters))
270 | # Log num_parameters
271 | logger.log_metrics({"num_parameters": num_parameters})
272 |
273 | # Create samplers
274 | # Here we make a window slightly bigger so that instance sampler can sample from each window
275 | # An alternative is to have exact size and use different instance sampler (e.g. ValidationSplitSampler)
276 | # We change ValidationSplitSampler to add min_past
277 | history_length = estimator.context_length + max(estimator.lags_seq)
278 | prediction_length = args.prediction_length
279 | window_size = history_length + prediction_length
280 | print(
281 | "Context length:",
282 | estimator.context_length,
283 | "Prediction Length:",
284 | estimator.prediction_length,
285 | "max(lags_seq):",
286 | max(estimator.lags_seq),
287 | "Therefore, window size:",
288 | window_size,
289 | )
290 |
291 | # Remove max(estimator.lags_seq) if the dataset is too small
292 | if args.use_single_instance_sampler:
293 | estimator.train_sampler = SingleInstanceSampler(
294 | min_past=estimator.context_length + max(estimator.lags_seq),
295 | min_future=estimator.prediction_length,
296 | )
297 | estimator.validation_sampler = SingleInstanceSampler(
298 | min_past=estimator.context_length + max(estimator.lags_seq),
299 | min_future=estimator.prediction_length,
300 | )
301 | else:
302 | estimator.train_sampler = ExpectedNumInstanceSampler(
303 | num_instances=1.0,
304 | min_past=estimator.context_length + max(estimator.lags_seq),
305 | min_future=estimator.prediction_length,
306 | )
307 | estimator.validation_sampler = ExpectedNumInstanceSampler(
308 | num_instances=1.0,
309 | min_past=estimator.context_length + max(estimator.lags_seq),
310 | min_future=estimator.prediction_length,
311 | )
312 |
313 | ## Batch size
314 | batch_size = args.batch_size
315 |
316 | if args.evaluate_only:
317 | pass
318 | else:
319 | if not args.single_dataset:
320 | # Create training and validation data
321 | all_datasets, val_datasets, dataset_num_series = [], [], []
322 | dataset_train_num_points, dataset_val_num_points = [], []
323 |
324 | for data_id, name in enumerate(train_dataset_names):
325 | data_id = name_to_data_id_map[name]
326 | (
327 | train_dataset,
328 | val_dataset,
329 | total_train_points,
330 | total_val_points,
331 | total_val_windows,
332 | max_train_end_date,
333 | total_points,
334 | ) = create_train_and_val_datasets_with_dates(
335 | name,
336 | args.dataset_path,
337 | data_id,
338 | history_length,
339 | prediction_length,
340 | num_val_windows=args.num_validation_windows,
341 | last_k_percentage=args.single_dataset_last_k_percentage
342 | )
343 | print(
344 | "Dataset:",
345 | name,
346 | "Total train points:", total_train_points,
347 | "Total val points:", total_val_points,
348 | )
349 | all_datasets.append(train_dataset)
350 | val_datasets.append(val_dataset)
351 | dataset_num_series.append(len(train_dataset))
352 | dataset_train_num_points.append(total_train_points)
353 | dataset_val_num_points.append(total_val_points)
354 |
355 | # Add test splits of test data to validation dataset, just for tracking purposes
356 | test_datasets_num_series = []
357 | test_datasets_num_points = []
358 | test_datasets = []
359 |
360 | if args.stratified_sampling:
361 | if args.stratified_sampling == "series":
362 | train_weights = dataset_num_series
363 | val_weights = dataset_num_series + test_datasets_num_series # If there is just 1 series (airpassengers or saugeenday) this will fail
364 | elif args.stratified_sampling == "series_inverse":
365 | train_weights = [1/x for x in dataset_num_series]
366 | val_weights = [1/x for x in dataset_num_series + test_datasets_num_series] # If there is just 1 series (airpassengers or saugeenday) this will fail
367 | elif args.stratified_sampling == "timesteps":
368 | train_weights = dataset_train_num_points
369 | val_weights = dataset_val_num_points + test_datasets_num_points
370 | elif args.stratified_sampling == "timesteps_inverse":
371 | train_weights = [1 / x for x in dataset_train_num_points]
372 | val_weights = [1 / x for x in dataset_val_num_points + test_datasets_num_points]
373 | else:
374 | train_weights = val_weights = None
375 |
376 | train_data = CombinedDataset(all_datasets, weights=train_weights)
377 | val_data = CombinedDataset(val_datasets+test_datasets, weights=val_weights)
378 | else:
379 | (
380 | train_data,
381 | val_data,
382 | total_train_points,
383 | total_val_points,
384 | total_val_windows,
385 | max_train_end_date,
386 | total_points,
387 | ) = create_train_and_val_datasets_with_dates(
388 | args.single_dataset,
389 | args.dataset_path,
390 | 0,
391 | history_length,
392 | prediction_length,
393 | num_val_windows=args.num_validation_windows,
394 | last_k_percentage=args.single_dataset_last_k_percentage
395 | )
396 | print(
397 | "Dataset:",
398 | args.single_dataset,
399 | "Total train points:", total_train_points,
400 | "Total val points:", total_val_points,
401 | )
402 |
403 | # Batch size search since when we scale up, we might not be able to use the same batch size for all models
404 | if args.search_batch_size:
405 | estimator.num_batches_per_epoch = 10
406 | estimator.limit_val_batches = 10
407 | estimator.trainer_kwargs["max_epochs"] = 1
408 | estimator.trainer_kwargs["callbacks"] = []
409 | estimator.trainer_kwargs["logger"] = None
410 | fulldir_batchsize_search = os.path.join(
411 | fulldir_experiments, "batch-size-search"
412 | )
413 | os.makedirs(fulldir_batchsize_search, exist_ok=True)
414 | while batch_size >= 1:
415 | try:
416 | print("Trying batch size:", batch_size)
417 | batch_size_search_dir = os.path.join(
418 | fulldir_batchsize_search, "batch-size-search-" + str(batch_size)
419 | )
420 | os.makedirs(batch_size_search_dir, exist_ok=True)
421 | estimator.batch_size = batch_size
422 | estimator.trainer_kwargs[
423 | "default_root_dir"
424 | ] = fulldir_batchsize_search
425 | # Train
426 | train_output = estimator.train_model(
427 | training_data=train_data,
428 | validation_data=val_data,
429 | shuffle_buffer_length=None,
430 | ckpt_path=None,
431 | )
432 | break
433 | except RuntimeError as e:
434 | if "out of memory" in str(e):
435 | gc.collect()
436 | torch.cuda.empty_cache()
437 | if batch_size == 1:
438 | print(
439 | "Batch is already at the minimum. Cannot reduce further. Exiting..."
440 | )
441 | exit(0)
442 | else:
443 | print("Caught OutOfMemoryError. Reducing batch size...")
444 | batch_size //= 2
445 | continue
446 | else:
447 | print(e)
448 | exit(1)
449 | estimator.num_batches_per_epoch = args.num_batches_per_epoch
450 | estimator.limit_val_batches = args.limit_val_batches
451 | estimator.trainer_kwargs["max_epochs"] = args.max_epochs
452 | estimator.trainer_kwargs["callbacks"] = callbacks
453 | estimator.trainer_kwargs["logger"] = logger
454 | estimator.trainer_kwargs["default_root_dir"] = fulldir_experiments
455 | if batch_size > 1: batch_size //= 2
456 | estimator.batch_size = batch_size
457 | print("\nUsing a batch size of", batch_size, "\n")
458 | wandb.config.update({"batch_size": batch_size}, allow_val_change=True)
459 |
460 |
461 | # Train
462 | train_output = estimator.train_model(
463 | training_data=train_data,
464 | validation_data=val_data,
465 | shuffle_buffer_length=None,
466 | ckpt_path=ckpt_path,
467 | )
468 |
469 | # Set checkpoint path before evaluating
470 | best_model_path = train_output.trainer.checkpoint_callback.best_model_path
471 | estimator.ckpt_path = best_model_path
472 |
473 |
474 | print("Using checkpoint:", estimator.ckpt_path, "for evaluation")
475 | # Make directory to store metrics
476 | metrics_dir = os.path.join(fulldir_experiments, "metrics")
477 | os.makedirs(metrics_dir, exist_ok=True)
478 |
479 | # Evaluate
480 | evaluation_datasets = args.test_datasets + train_dataset_names if not args.single_dataset else [args.single_dataset]
481 |
482 | for name in evaluation_datasets: # [test_dataset]:
483 | print("Evaluating on", name)
484 | test_data, prediction_length, total_points = create_test_dataset(
485 | name, args.dataset_path, window_size
486 | )
487 | print("# of Series in the test data:", len(test_data))
488 |
489 | # Adapt evaluator to new dataset
490 | estimator.prediction_length = prediction_length
491 | # Batch size loop just in case. This is mandatory as it involves sampling etc.
492 | # NOTE: In case can't do sampling with even batch size of 1, then keep reducing num_parallel_samples until we can (keeping batch size at 1)
493 | while batch_size >= 1:
494 | try:
495 | # Batch size
496 | print("Trying batch size:", batch_size)
497 | estimator.batch_size = batch_size
498 | predictor = estimator.create_predictor(
499 | estimator.create_transformation(),
500 | estimator.create_lightning_module(),
501 | )
502 | # Make evaluations
503 | forecast_it, ts_it = make_evaluation_predictions(
504 | dataset=test_data, predictor=predictor, num_samples=args.num_samples
505 | )
506 | forecasts = list(forecast_it)
507 | tss = list(ts_it)
508 | break
509 | except RuntimeError as e:
510 | if "out of memory" in str(e):
511 | gc.collect()
512 | torch.cuda.empty_cache()
513 | if batch_size == 1:
514 | print(
515 | "Batch is already at the minimum. Cannot reduce further. Exiting..."
516 | )
517 | exit(0)
518 | else:
519 | print("Caught OutOfMemoryError. Reducing batch size...")
520 | batch_size //= 2
521 | continue
522 | else:
523 | print(e)
524 | exit(1)
525 |
526 | if args.plot_test_forecasts:
527 | print("Plotting forecasts")
528 | figure = plot_forecasts(forecasts, tss, prediction_length)
529 | wandb.log({f"Forecast plot of {name}": wandb.Image(figure)})
530 |
531 | # Get metrics
532 | evaluator = Evaluator(
533 | num_workers=args.num_workers, aggregation_strategy=aggregate_valid
534 | )
535 | agg_metrics, _ = evaluator(
536 | iter(tss), iter(forecasts), num_series=len(test_data)
537 | )
538 | # Save metrics
539 | metrics_savepath = metrics_dir + "/" + name + ".json"
540 | with open(metrics_savepath, "w") as metrics_savefile:
541 | json.dump(agg_metrics, metrics_savefile)
542 |
543 | # Log metrics. For now only CRPS is logged.
544 | wandb_metrics = {}
545 | wandb_metrics["test/" + name + "/" + "CRPS"] = agg_metrics["mean_wQuantileLoss"]
546 | logger.log_metrics(wandb_metrics)
547 |
548 | wandb.finish()
549 |
if __name__ == "__main__":
    # CLI entry point: builds the full argument set for pretraining /
    # finetuning, optionally overrides args from a JSON config, then trains.
    parser = argparse.ArgumentParser()

    # Experiment args
    parser.add_argument("-e", "--experiment_name", type=str, required=True)

    # Data arguments
    parser.add_argument(
        "-d",
        "--dataset_path",
        type=str,
        default="datasets",
        help="Enter the datasets folder path here"
    )
    parser.add_argument("--all_datasets", type=str, nargs="+", default=ALL_DATASETS)
    parser.add_argument("-t", "--test_datasets", type=str, nargs="+", default=[])
    parser.add_argument(
        "--stratified_sampling",
        type=str,
        choices=["series", "series_inverse", "timesteps", "timesteps_inverse"],
    )

    # Seed
    parser.add_argument("--seed", type=int, default=42)

    # Model hyperparameters
    parser.add_argument("--context_length", type=int, default=256)
    parser.add_argument("--prediction_length", type=int, default=1)
    parser.add_argument("--max_prediction_length", type=int, default=1024)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--num_encoder_layer", type=int, default=4, help="Only for lag-transformer")
    parser.add_argument("--n_embd_per_head", type=int, default=64)
    parser.add_argument("--n_head", type=int, default=4)
    parser.add_argument("--dim_feedforward", type=int, default=256)
    parser.add_argument("--lags_seq", type=str, nargs="+", default=["Q", "M", "W", "D", "H", "T", "S"])

    # Data normalization
    parser.add_argument(
        "--data_normalization", default=None, choices=["mean", "std", "robust", "none"]
    )

    ## Augmentation hyperparameters
    # Augmentation probability
    parser.add_argument("--aug_prob", type=float, default=0)

    # Frequency Masking
    parser.add_argument(
        "--freq_mask_rate", type=float, default=0.1, help="Rate of frequency masking"
    )

    # Frequency Mixing
    parser.add_argument(
        "--freq_mixing_rate", type=float, default=0.1, help="Rate of frequency mixing"
    )

    # Jitter
    parser.add_argument(
        "--jitter_prob",
        type=float,
        default=0,
        help="Probability of applying Jitter augmentation",
    )
    parser.add_argument(
        "--jitter_sigma",
        type=float,
        default=0.03,
        help="Standard deviation for Jitter augmentation",
    )

    # Scaling
    parser.add_argument(
        "--scaling_prob",
        type=float,
        default=0,
        help="Probability of applying Scaling augmentation",
    )
    parser.add_argument(
        "--scaling_sigma",
        type=float,
        default=0.1,
        help="Standard deviation for Scaling augmentation",
    )

    # Rotation
    parser.add_argument(
        "--rotation_prob",
        type=float,
        default=0,
        help="Probability of applying Rotation augmentation",
    )

    # Permutation
    parser.add_argument(
        "--permutation_prob",
        type=float,
        default=0,
        help="Probability of applying Permutation augmentation",
    )
    parser.add_argument(
        "--permutation_max_segments",
        type=int,
        default=5,
        help="Maximum segments for Permutation augmentation",
    )
    parser.add_argument(
        "--permutation_seg_mode",
        type=str,
        default="equal",
        choices=["equal", "random"],
        help="Segment mode for Permutation augmentation",
    )

    # MagnitudeWarp
    parser.add_argument(
        "--magnitude_warp_prob",
        type=float,
        default=0,
        help="Probability of applying MagnitudeWarp augmentation",
    )
    parser.add_argument(
        "--magnitude_warp_sigma",
        type=float,
        default=0.2,
        help="Standard deviation for MagnitudeWarp augmentation",
    )
    parser.add_argument(
        "--magnitude_warp_knot",
        type=int,
        default=4,
        help="Number of knots for MagnitudeWarp augmentation",
    )

    # TimeWarp
    parser.add_argument(
        "--time_warp_prob",
        type=float,
        default=0,
        help="Probability of applying TimeWarp augmentation",
    )
    parser.add_argument(
        "--time_warp_sigma",
        type=float,
        default=0.2,
        help="Standard deviation for TimeWarp augmentation",
    )
    parser.add_argument(
        "--time_warp_knot",
        type=int,
        default=4,
        help="Number of knots for TimeWarp augmentation",
    )

    # WindowSlice
    parser.add_argument(
        "--window_slice_prob",
        type=float,
        default=0,
        help="Probability of applying WindowSlice augmentation",
    )
    parser.add_argument(
        "--window_slice_reduce_ratio",
        type=float,
        default=0.9,
        help="Reduce ratio for WindowSlice augmentation",
    )

    # WindowWarp
    parser.add_argument(
        "--window_warp_prob",
        type=float,
        default=0,
        help="Probability of applying WindowWarp augmentation",
    )
    parser.add_argument(
        "--window_warp_window_ratio",
        type=float,
        default=0.1,
        help="Window ratio for WindowWarp augmentation",
    )
    parser.add_argument(
        "--window_warp_scales",
        nargs="+",
        type=float,
        default=[0.5, 2.0],
        help="Scales for WindowWarp augmentation",
    )

    # Argument to include time-features
    parser.add_argument(
        "--time_feat",
        help="include time features",
        action="store_true",
    )

    # Training arguments
    parser.add_argument("-b", "--batch_size", type=int, default=256)
    parser.add_argument("-m", "--max_epochs", type=int, default=10000)
    parser.add_argument("-n", "--num_batches_per_epoch", type=int, default=100)
    parser.add_argument("--limit_val_batches", type=int)
    # BUG FIX: `type=int` was missing, so a CLI-supplied value arrived as a
    # str (the int default masked the problem unless the flag was used).
    parser.add_argument("--early_stopping_patience", type=int, default=50)
    parser.add_argument("--dropout", type=float, default=0.0)

    # Evaluation arguments
    parser.add_argument("--num_parallel_samples", type=int, default=100)
    parser.add_argument("--num_samples", type=int, default=100)
    parser.add_argument("--num_workers", type=int, default=1)

    # GPU ID
    parser.add_argument("--gpu", type=int, default=0)

    # Directory to save everything in
    parser.add_argument("-r", "--results_dir", type=str, required=True)

    # W&B
    parser.add_argument("-w", "--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default="lag-llama-test")
    parser.add_argument("--wandb_tags", nargs="+")
    parser.add_argument(
        "--wandb_mode", type=str, default="online", choices=["offline", "online"]
    )

    # Other arguments
    parser.add_argument(
        "--evaluate_only", action="store_true", help="Only evaluate, do not train"
    )
    # NOTE(review): `store_true` combined with `default=True` means this flag
    # can never be turned off from the CLI — confirm this is intended.
    parser.add_argument(
        "--use_kv_cache",
        help="KV caching during inference. Only for Lag-Llama.",
        action="store_true",
        default=True
    )

    # SWA arguments
    parser.add_argument(
        "--swa", action="store_true", help="Using Stochastic Weight Averaging"
    )
    parser.add_argument("--swa_lrs", type=float, default=1e-2)
    parser.add_argument("--swa_epoch_start", type=float, default=0.8)
    parser.add_argument("--annealing_epochs", type=int, default=10)
    parser.add_argument(
        "--annealing_strategy", type=str, default="cos", choices=["cos", "linear"]
    )

    # Training/validation iterator type switching
    # NOTE(review): same `store_true` + `default=True` caveat as above.
    parser.add_argument("--use_single_instance_sampler", action="store_true", default=True)

    # Plot forecasts
    parser.add_argument("--plot_test_forecasts", action="store_true", default=True)

    # Search search_batch_size
    parser.add_argument("--search_batch_size", action="store_true", default=False)

    # Number of validation windows
    parser.add_argument("--num_validation_windows", type=int, default=14)

    # Training KWARGS
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-8)

    # Override arguments with a dictionary file with args
    parser.add_argument('--args_from_dict_path', type=str)

    # Evaluation utils
    parser.add_argument("--eval_prefix", type=str)

    # Checkpoints args
    parser.add_argument("--ckpt_path", type=str)
    parser.add_argument("--get_ckpt_path_from_experiment_name", type=str)

    # Single dataset setup: used typically for finetuning
    parser.add_argument("--single_dataset", type=str)
    parser.add_argument("--use_dataset_prediction_length", action="store_true", default=False)
    parser.add_argument("--single_dataset_last_k_percentage", type=float)

    # CosineAnnealingLR
    parser.add_argument("--use_cosine_annealing_lr", action="store_true", default=False)
    parser.add_argument("--cosine_annealing_lr_t_max", type=int, default=10000)
    parser.add_argument("--cosine_annealing_lr_eta_min", type=float, default=1e-2)

    # Distribution output
    parser.add_argument('--distr_output', type=str, default="studentT", choices=["studentT"])

    args = parser.parse_args()

    # Values from the JSON config take precedence over CLI-provided values.
    if args.args_from_dict_path:
        with open(args.args_from_dict_path, "r") as read_file: loaded_args = json.load(read_file)
        for key, value in loaded_args.items():
            setattr(args, key, value)

    # print args for logging
    for arg in vars(args):
        print(arg, ":", getattr(args, arg))

    train(args)
844 |
--------------------------------------------------------------------------------
/scripts/finetune.sh:
--------------------------------------------------------------------------------
# This script can only be executed once you have trained a model. The experiment name used for the trained model should be specified in line 4

CONFIGPATH="configs/lag_llama.json"
PRETRAINING_EXP_NAME="pretraining_lag_llama"
PERCENTAGE=100 # Change to lesser value to limit the history. Use 20, 40, 60, 80 to reproduce experiments in the paper.

for FINETUNE_DATASET in "weather" "pedestrian_counts" "exchange_rate" "ett_m2" "platform_delay_minute" "requests_minute" "beijing_pm25"
do
    EXP_NAME="${PRETRAINING_EXP_NAME}_finetune_on_${FINETUNE_DATASET}"

    # We reuse the same seeds as used for pretraining
    FILENAME="experiments/seeds/${PRETRAINING_EXP_NAME}"
    echo "$PRETRAINING_EXP_NAME"

    # Get the seeds (expansions quoted to guard against spaces/globbing)
    if [ -f "$FILENAME" ]; then
        echo "${FILENAME} found. Reading seeds."
        SEEDS=()
        while read -r LINE; do
            SEEDS+=("$LINE")
        done < "$FILENAME"
        echo "Found ${#SEEDS[@]} seeds for finetuning."
    else
        echo "${FILENAME} does not exist. Cannot perform finetuning."
        # BUG FIX: report failure to the caller (was `exit 0`, i.e. success).
        exit 1
    fi

    # Iterate through all training dataset
    for SEED in "${SEEDS[@]}"
    do
        EXPERIMENT_NAME="${EXP_NAME}_seed_${SEED}"

        python run.py \
            -e "$EXPERIMENT_NAME" -d "datasets" --seed "$SEED" \
            -r "experiments/results" \
            --batch_size 512 -m 1000 -n 128 \
            --wandb_entity "enter-wandb-entity" --wandb_project "enter-wandb-project" --wandb_tags "enter-wandb-tags-or-remove-this-argument" \
            --num_workers 2 --args_from_dict_path "$CONFIGPATH" --search_batch_size \
            --single_dataset "$FINETUNE_DATASET" \
            --get_ckpt_path_from_experiment_name "$PRETRAINING_EXP_NAME" --lr 0.00001 --use_dataset_prediction_length --num_validation_windows 1 \
            --single_dataset_last_k_percentage "$PERCENTAGE"
    done
done
--------------------------------------------------------------------------------
/scripts/pretrain.sh:
--------------------------------------------------------------------------------
# PLEASE FOLLOW THE BELOW INSTRUCTIONS FIRST

# 1. Install the requirements. It is recommended to use a new Anaconda environment with Python 3.10.8. Execute the below command (remove the #)
# !pip install -r requirements.txt

# 2. Please download https://drive.google.com/file/d/1JrDWMZyoPsc6d1wAAjgm3PosbGus-jCE/view?usp=sharing and use the below command to download the non-monash datasets (remove the #)
# tar -xvzf nonmonash_datasets.tar.gz -C datasets

# 3. Edit the Weights and Biases arguments on line 59 of this script

mkdir -p experiments
mkdir -p experiments/seeds
mkdir -p experiments/results

EXP_NAME="pretraining_lag_llama"
FILENAME="experiments/seeds/${EXP_NAME}"
CONFIGPATH="configs/lag_llama.json"

echo "$EXP_NAME"

# NUM_SEEDS used only if it is a new experiment
NUM_SEEDS=1

# Create seeds (expansions quoted to guard against spaces/globbing)
if [ -f "$FILENAME" ]; then
    echo "${FILENAME} already exists."

    SEEDS=()
    while read -r LINE; do
        SEEDS+=("$LINE")
    done < "$FILENAME"
    echo "Found ${#SEEDS[@]} seeds for training."
else
    # Write seeds
    echo "${FILENAME} created. Writing seeds."
    touch "$FILENAME"
    for (( i = 0; i < NUM_SEEDS; i++ ))
    do
        SEED=$((RANDOM + 1))
        echo "$SEED" >> "$FILENAME"
    done

    # Read them
    SEEDS=()
    while read -r LINE; do
        SEEDS+=("$LINE")
    done < "$FILENAME"
fi

# Train
for SEED in "${SEEDS[@]}"
do
    # NOTE(review): EXPERIMENT_NAME (with seed suffix) is computed but the base
    # EXP_NAME is what is passed to -e below. If run.py does not itself append
    # the seed to the experiment directory, runs with different seeds would
    # collide — confirm against run.py's result-directory logic.
    EXPERIMENT_NAME="${EXP_NAME}_seed_${SEED}"

    python run.py \
        -e "$EXP_NAME" -d "datasets" --seed "$SEED" \
        -r "experiments/results" \
        --batch_size 512 -m 1000 -n 128 \
        --wandb_entity "enter-wandb-entity" --wandb_project "enter-wandb-project" --wandb_tags "enter-wandb-tags-or-remove-this-argument" \
        --all_datasets "australian_electricity_demand" "electricity_hourly" "london_smart_meters_without_missing" "solar_10_minutes" "wind_farms_without_missing" "pedestrian_counts" "uber_tlc_hourly" "traffic" "kdd_cup_2018_without_missing" "saugeenday" "sunspot_without_missing" "exchange_rate" "cpu_limit_minute" "cpu_usage_minute" "function_delay_minute" "instances_minute" "memory_limit_minute" "memory_usage_minute" "platform_delay_minute" "requests_minute" "ett_h1" "ett_h2" "ett_m1" "ett_m2" "beijing_pm25" "AirQualityUCI" "beijing_multisite" "weather" \
        --test_datasets "weather" "pedestrian_counts" "exchange_rate" "ett_m2" "platform_delay_minute" "requests_minute" "beijing_pm25" \
        --num_workers 2 --args_from_dict_path "$CONFIGPATH" --search_batch_size \
        --lr 0.0001
done
65 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Arjun Ashok
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import random
17 | from itertools import islice
18 |
19 | import matplotlib.lines as mlines
20 | import matplotlib.pyplot as plt
21 | import numpy as np
22 | import torch
23 |
24 |
def set_seed(seed, deterministic=True):
    """Seed every RNG in use (hash, python, numpy, torch, CUDA).

    When ``deterministic`` is True, cuDNN is put into deterministic mode
    and benchmarking is disabled so repeated runs produce identical results.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = False
35 |
36 |
def print_gpu_stats():
    """Print total and currently-allocated memory (in GB) of the current CUDA device.

    Requires a CUDA device; `torch.cuda.current_device()` raises without one.
    """
    device = torch.cuda.current_device()
    memory_stats = torch.cuda.memory_stats(device=device)
    # BUG FIX: total memory was queried from hard-coded device 0; on a
    # multi-GPU machine that could report a different device than the one
    # whose allocation stats are printed below.
    t = torch.cuda.get_device_properties(device).total_memory / (1024**3)
    allocated_memory_gb = memory_stats["allocated_bytes.all.current"] / (1024**3)
    print(f"Total Memory: {t:.2f} GB")
    print(f"Allocated Memory: {allocated_memory_gb:.2f} GB")
45 |
46 |
def plot_forecasts(forecasts, tss, prediction_length):
    """Plot up to nine forecasts against their target series in a 3x3 grid.

    Each subplot shows one forecast (green) over the last three prediction
    windows of its target series (blue) and is titled with the item id.
    Returns the matplotlib figure so the caller can log or save it.
    """
    plt.figure(figsize=(20, 15))
    plt.rcParams.update({"font.size": 15})

    # Shared legend handles so every subplot gets an identical legend.
    legend_handles = [
        mlines.Line2D([], [], color="g", label="Forecast"),
        mlines.Line2D([], [], color="blue", label="Target"),
    ]

    # Only the first nine series are plotted (one per grid cell).
    for cell, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
        ax = plt.subplot(3, 3, cell + 1)
        forecast.plot(color="g")
        # Show the target over the trailing 3 prediction windows only.
        ts[-3 * prediction_length :][0].plot(label="target", ax=ax)
        plt.xticks(rotation=60)
        ax.set_title(forecast.item_id)
        ax.legend(handles=legend_handles)

    plt.gcf().tight_layout()
    return plt.gcf()
68 |
--------------------------------------------------------------------------------