├── .gitignore
├── LICENSE
├── README.md
├── data_provider
│   ├── __init__.py
│   ├── data_factory.py
│   ├── data_loader.py
│   └── m4.py
├── exp
│   ├── exp_basic.py
│   ├── exp_in_context_forecasting.py
│   ├── exp_long_term_forecasting.py
│   ├── exp_short_term_forecasting.py
│   └── exp_zero_shot_forecasting.py
├── figures
│   ├── ablation.png
│   ├── ablation_llm.png
│   ├── adaption_efficiency.png
│   ├── comparison.png
│   ├── formulation.png
│   ├── icon.png
│   ├── illustration.png
│   ├── in-context.png
│   ├── lora.png
│   ├── method.png
│   ├── motivation.png
│   ├── one-for-all_results.png
│   ├── param.png
│   ├── showcases.png
│   ├── subway_icf.png
│   └── zeroshot_results.png
├── layers
│   ├── __init__.py
│   └── mlp.py
├── models
│   ├── AutoTimes_Gpt2.py
│   ├── AutoTimes_Llama.py
│   ├── AutoTimes_Opt_1b.py
│   ├── Preprocess_Llama.py
│   └── __init__.py
├── predict.ipynb
├── preprocess.py
├── requirements.txt
├── run.py
├── scripts
│   ├── in_context_forecasting
│   │   └── M3.sh
│   ├── method_generality
│   │   ├── gpt2.sh
│   │   └── opt.sh
│   ├── time_series_forecasting
│   │   ├── long_term
│   │   │   ├── AutoTimes_ECL.sh
│   │   │   ├── AutoTimes_ETTh1.sh
│   │   │   ├── AutoTimes_Solar.sh
│   │   │   ├── AutoTimes_Traffic.sh
│   │   │   └── AutoTimes_Weather.sh
│   │   └── short_term
│   │       └── AutoTimes_M4.sh
│   └── zero_shot_forecasting
│       ├── sM3_tM4.sh
│       └── sM4_tM3.sh
└── utils
    ├── __init__.py
    ├── losses.py
    ├── m4_summary.py
    ├── metrics.py
    └── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | dataset/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | */.DS_Store
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | /dataset/
137 | /llama/
138 | /checkpoints/
139 | /test_results/
140 | /m4_results/
141 | /models--gpt2/
142 | /models--facebook--opt-1.3b/
143 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 THUML @ Tsinghua University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AutoTimes (Large Language Models for Time Series Forecasting)
2 |
3 | Official implementation: [AutoTimes: Autoregressive Time Series Forecasters via Large Language Models](https://arxiv.org/abs/2402.02370). [[Slides]](https://cloud.tsinghua.edu.cn/f/7689d30f92594ded84f0/), [[Poster]](https://cloud.tsinghua.edu.cn/f/f2c18ae34fef4e74ad46/)
4 |
5 |
6 | > **[Time Series Forecasting](./scripts/time_series_forecasting/)**: AutoTimes converts LLMs into autoregressive time series forecasters. Unlike previous methods, the forecaster can accommodate lookback and prediction windows of arbitrary length.
7 |
8 | > **[Zero-Shot Forecasting](./scripts/zero_shot_forecasting/)**: AutoTimes takes advantage of the LLM's general-purpose token transition as the future extrapolation of time series, demonstrating good performance without downstream samples.
9 |
10 | > **[In-Context Forecasting](./scripts/in_context_forecasting/)**: We propose in-context forecasting for the first time, where time series prompts can be incorporated into the input context to enhance forecasting.
11 |
12 | > **[Easy-to-Use](scripts/method_generality)**: AutoTimes is compatible with any decoder-only large language model, demonstrating generality and proper scaling behavior.
13 |
14 | ## Updates
15 |
16 | :triangular_flag_on_post: News (2024.10): An introduction to our work is available in these [[Slides]](https://cloud.tsinghua.edu.cn/f/7689d30f92594ded84f0/). See you at **NeurIPS 2024**!
17 |
18 | :triangular_flag_on_post: News (2024.10): AutoTimes has been accepted by **NeurIPS 2024**. [A revised version](https://arxiv.org/pdf/2402.02370) (**25 pages**) is now available, including prompt engineering for in-context forecasting, adaptation cost evaluations, textual embeddings of metadata, and low-rank adaptation techniques.
19 |
20 | :triangular_flag_on_post: News (2024.08): [Recent work](https://arxiv.org/abs/2406.16964) [(code)](https://github.com/bennytmt/ts_models) has also raised questions about previous non-autoregressive LLM4TS methods. We conduct ablations [here](./figures/ablation_llm.png), highlighting that AutoTimes can truly utilize LLMs. Instead of adopting LLMs in a BERT style, **the general-purpose token transition is transferable between time series and natural language**.
21 |
22 |
23 |
24 |
25 |
26 | :triangular_flag_on_post: **News** (2024.2): Scripts for all the above tasks in our [paper](https://arxiv.org/pdf/2402.02370.pdf) are available.
27 |
28 | ## Introduction
29 |
30 | 🌟 While prevalent LLM4TS methods adapt LLMs as encoder-only and non-autoregressive forecasters, we propose to **stay consistent with the inherent autoregressive property and model architecture**.
31 |
32 |
33 |
34 |
35 |
36 | 💪 We aim to **fully revitalize LLMs as foundation models for time series forecasting**, including multi-step forecasting, zero-shot capability, **in-context forecasting**, and multimodal utilization.
37 |
38 | 🏆 AutoTimes achieves **state-of-the-art performance** with **0.1% trainable parameters and over 5× training/inference speedup** compared to advanced LLM-based forecasters.
39 |
40 | ## Usage
41 |
42 | 1. Install PyTorch and the necessary dependencies.
43 |
44 | ```
45 | pip install -r requirements.txt
46 | ```
47 |
48 | 2. Put the datasets [[Google Drive]](https://drive.google.com/file/d/1t7jOkctNJ0rt3VMwZaqmxSuA75TFEo96/view?usp=sharing)
49 | [[Tsinghua Cloud]](https://cloud.tsinghua.edu.cn/f/0a758154e0d44de890e3/) under the folder ```./dataset/```.
50 |
51 | 3. Download the large language models from [Hugging Face](https://huggingface.co/). The default LLM is LLaMA-7B; you can change the `llm_ckp_dir` in `run.py` to use other LLMs. A minimal download sketch is given after the example directory layout below.
52 | * [LLaMA-7B](https://huggingface.co/meta-llama/Llama-2-7b)
53 | * [OPT Family](https://huggingface.co/facebook/opt-125m)
54 | * [GPT2](https://huggingface.co/openai-community/gpt2)
55 |
56 | For example, if the LLaMA checkpoint is downloaded and placed correctly, the directory structure is as follows:
57 | - data_provider
58 | - dataset
59 | - llama
60 | - config.json
61 | - pytorch_model-00001-of-00002.bin
62 | - pytorch_model-00002-of-00002.bin
63 | - ...
64 | - ...
65 | - run.py
66 |
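A minimal download sketch (not from the repository; it assumes the `huggingface_hub` package is installed and, for LLaMA, that your account has been granted access):

```python
# Illustrative sketch: fetch a checkpoint into the local folder shown in the layout above.
# meta-llama/Llama-2-7b is gated, so run `huggingface-cli login` (or pass a token) first.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="meta-llama/Llama-2-7b",   # or "facebook/opt-125m", "openai-community/gpt2"
    local_dir="./llama",               # matches the `llama` folder expected by `llm_ckp_dir`
)
```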
67 | 4. Prepare the position embeddings generated from textual timestamps. Note that we have provided the embeddings of the given datasets in the download links, which are generated by LLaMA and suffixed by `{dataset_name}.pt`. If you want to generate the embeddings for your customized datasets, please refer to the following command:
68 | ```
69 | # preprocess timestamps to generate text embedding
70 | python ./preprocess.py --gpu 0 --dataset ETTh1
71 | ```
72 |
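The data loaders later read these embeddings with `torch.load(os.path.join(root_path, f'{dataset_name}.pt'))`, so a quick sanity check of the generated file might look like this (the path below is an assumption depending on where the dataset lives):

```python
# Illustrative sanity check (path assumed): make sure the timestamp embedding exists and loads.
import torch

emb = torch.load("./dataset/ETTh1.pt")
print(type(emb), getattr(emb, "shape", None))
```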
73 | 5. Train and evaluate the model. We provide scripts for all the above tasks under the folder ```./scripts/```.
74 |
75 | ```
76 | # the default large language model is LLaMA-7B
77 |
78 | # long-term forecasting
79 | bash ./scripts/time_series_forecasting/long_term/AutoTimes_ETTh1.sh
80 |
81 | # short-term forecasting
82 | bash ./scripts/time_series_forecasting/short_term/AutoTimes_M4.sh
83 |
84 | # zero-shot forecasting
85 | # note that sM4_tM3 utilizes models trained on the
86 | # short-term task, so you should run AutoTimes_M4 first
87 | bash ./scripts/zero_shot_forecasting/sM4_tM3.sh
88 | bash ./scripts/zero_shot_forecasting/sM3_tM4.sh
89 |
90 | # in-context forecasting
91 | bash ./scripts/in_context_forecasting/M3.sh
92 |
93 | # try on other large language models
94 | bash ./scripts/method_generality/opt.sh
95 | ```
96 |
97 | > Due to the simple tokenization and the frozen LLM blocks, AutoTimes is highly compatible with LLMs. For example, it takes only **15 min** for AutoTimes to repurpose LLaMA-7B on ETTh1 on a single RTX 3090-24G.
98 |
99 | ### A Usage Example
100 | See ```predict.ipynb``` for a simple training and inference workflow.
101 |
102 | ## Overall Approach
103 |
104 |
105 |
106 |
107 |
108 | ## Comparison
109 |
110 |
111 |
112 |
113 |
114 |
115 | ## Time Series Forecasting
116 |
117 | **One-for-all** benchmark: a single forecaster is trained on one dataset and subsequently used for all prediction lengths.
118 |
119 |
120 |
121 |
122 |
123 |
124 | ## In-Context Forecasting
125 |
126 |
127 |
128 |
129 |
130 | Benefiting from time series prompts from the target domain, AutoTimes achieves an average **13.3%** SMAPE reduction compared with zero-shot forecasting.
131 |
132 | ## Method Efficiency
133 |
134 |
135 |
136 |
137 |
138 | ## Showcases
139 | We investigate different prompt retrieval strategies. Insightful results are provided to reveal the influence of using time series prompts for interactive prediction.
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 | ## Citation
150 |
151 | If you find this repo helpful, please cite our paper.
152 |
153 | ```
154 | @article{liu2024autotimes,
155 | title={AutoTimes: Autoregressive Time Series Forecasters via Large Language Models},
156 | author={Liu, Yong and Qin, Guo and Huang, Xiangdong and Wang, Jianmin and Long, Mingsheng},
157 | journal={arXiv preprint arXiv:2402.02370},
158 | year={2024}
159 | }
160 | ```
161 |
162 | ## Acknowledgement
163 |
164 | We appreciate the following GitHub repos for their valuable code and efforts.
165 | - Time-Series-Library (https://github.com/thuml/Time-Series-Library)
166 | - FPT (https://github.com/DAMO-DI-ML/NeurIPS2023-One-Fits-All)
167 |
168 | ## Contact
169 |
170 | If you have any questions or want to use the code, feel free to contact:
171 | * Yong Liu (liuyong21@mails.tsinghua.edu.cn)
172 | * Guo Qin (qinguo24@mails.tsinghua.edu.cn)
173 |
--------------------------------------------------------------------------------
/data_provider/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/data_provider/__init__.py
--------------------------------------------------------------------------------
/data_provider/data_factory.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_Custom, Dataset_M4, Dataset_Solar, Dataset_TSF, Dataset_TSF_ICL
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.distributed import DistributedSampler
4 |
5 | data_dict = {
6 | 'ETTh1': Dataset_ETT_hour,
7 | 'custom': Dataset_Custom,
8 | 'm4': Dataset_M4,
9 | 'Solar': Dataset_Solar,
10 | 'tsf': Dataset_TSF,
11 | 'tsf_icl': Dataset_TSF_ICL
12 | }
13 |
14 |
15 | def data_provider(args, flag):
16 | Data = data_dict[args.data]
17 |
18 | if flag == 'test':
19 | shuffle_flag = False
20 | drop_last = False
21 | batch_size = args.batch_size
22 | elif flag == 'val':
23 | shuffle_flag = args.val_set_shuffle
24 | drop_last = False
25 | batch_size = args.batch_size
26 | else:
27 | shuffle_flag = True
28 | drop_last = args.drop_last
29 | batch_size = args.batch_size
30 |
31 | if flag in ['train', 'val']:
32 | data_set = Data(
33 | root_path=args.root_path,
34 | data_path=args.data_path,
35 | flag=flag,
36 | size=[args.seq_len, args.label_len, args.token_len],
37 | seasonal_patterns=args.seasonal_patterns,
38 | drop_short=args.drop_short,
39 | )
40 | else:
41 | data_set = Data(
42 | root_path=args.root_path,
43 | data_path=args.data_path,
44 | flag=flag,
45 | size=[args.test_seq_len, args.test_label_len, args.test_pred_len],
46 | seasonal_patterns=args.seasonal_patterns,
47 | drop_short=args.drop_short,
48 | )
49 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu:
50 | print(flag, len(data_set))
51 | if args.use_multi_gpu:
52 | train_datasampler = DistributedSampler(data_set, shuffle=shuffle_flag)
53 | data_loader = DataLoader(data_set,
54 | batch_size=batch_size,
55 | sampler=train_datasampler,
56 | num_workers=args.num_workers,
57 | persistent_workers=True,
58 | pin_memory=True,
59 | drop_last=drop_last,
60 | )
61 | else:
62 | data_loader = DataLoader(
63 | data_set,
64 | batch_size=batch_size,
65 | shuffle=shuffle_flag,
66 | num_workers=args.num_workers,
67 | drop_last=drop_last)
68 | return data_set, data_loader
--------------------------------------------------------------------------------
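A minimal usage sketch of the `data_provider` factory above (illustrative, not part of the repository): the argument names mirror those read inside the function, while the concrete values, paths, and shapes are assumptions; the precomputed `ETTh1.pt` timestamp embedding must already sit in `root_path`.

```python
# Illustrative sketch (not part of the repository): calling data_provider for the
# ETTh1 training split on a single GPU; the values and paths below are assumptions.
from types import SimpleNamespace
from data_provider.data_factory import data_provider

args = SimpleNamespace(
    data='ETTh1', root_path='./dataset/', data_path='ETTh1.csv',
    seq_len=672, label_len=576, token_len=96,   # token_len = seq_len - label_len
    seasonal_patterns=None, drop_short=False,
    batch_size=32, num_workers=4, drop_last=True, use_multi_gpu=False,
)
train_set, train_loader = data_provider(args, flag='train')
# Each batch: lookback window, target window, and their timestamp-embedding marks.
seq_x, seq_y, seq_x_mark, seq_y_mark = next(iter(train_loader))
print(seq_x.shape, seq_y.shape)  # (32, 672, 1) and (32, 672, 1), since label_len + pred_len = 672
```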
/data_provider/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import datetime
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | from torch.utils.data import Dataset
7 | from data_provider.m4 import M4Dataset, M4Meta
8 | from sklearn.preprocessing import StandardScaler
9 | from utils.tools import convert_tsf_to_dataframe
10 | import warnings
11 |
12 | warnings.filterwarnings('ignore')
13 |
14 |
15 | class Dataset_ETT_hour(Dataset):
16 | def __init__(self, root_path, flag='train', size=None, data_path='ETTh1.csv',
17 | scale=True, seasonal_patterns=None, drop_short=False):
18 | self.seq_len = size[0]
19 | self.label_len = size[1]
20 | self.pred_len = size[2]
21 | self.token_len = self.seq_len - self.label_len
22 | self.token_num = self.seq_len // self.token_len
23 | self.flag = flag
24 | # init
25 | assert flag in ['train', 'test', 'val']
26 | type_map = {'train': 0, 'val': 1, 'test': 2}
27 | self.set_type = type_map[flag]
28 |
29 | self.scale = scale
30 |
31 | self.root_path = root_path
32 | self.data_path = data_path
33 | self.__read_data__()
34 | self.enc_in = self.data_x.shape[-1]
35 | self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1
36 |
37 | def __read_data__(self):
38 | self.scaler = StandardScaler()
39 | df_raw = pd.read_csv(os.path.join(self.root_path,
40 | self.data_path))
41 |
42 | border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
43 | border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
44 | border1 = border1s[self.set_type]
45 | border2 = border2s[self.set_type]
46 |
47 | cols_data = df_raw.columns[1:]
48 | df_data = df_raw[cols_data]
49 |
50 | if self.scale:
51 | train_data = df_data[border1s[0]:border2s[0]]
52 | self.scaler.fit(train_data.values)
53 | data = self.scaler.transform(df_data.values)
54 | else:
55 | data = df_data.values
56 |
57 | data_name = self.data_path.split('.')[0]
58 | self.data_stamp = torch.load(os.path.join(self.root_path, f'{data_name}.pt'))
59 | self.data_stamp = self.data_stamp[border1:border2]
60 | self.data_x = data[border1:border2]
61 | self.data_y = data[border1:border2]
62 |
63 | def __getitem__(self, index):
64 | feat_id = index // self.tot_len
65 | s_begin = index % self.tot_len
66 |
67 | s_end = s_begin + self.seq_len
68 | r_begin = s_end - self.label_len
69 | r_end = r_begin + self.label_len + self.pred_len
70 | seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
71 | seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
72 | seq_x_mark = self.data_stamp[s_begin:s_end:self.token_len]
73 | seq_y_mark = self.data_stamp[s_end:r_end:self.token_len]
74 |
75 | return seq_x, seq_y, seq_x_mark, seq_y_mark
76 |
77 | def __len__(self):
78 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) * self.enc_in
79 |
80 | def inverse_transform(self, data):
81 | return self.scaler.inverse_transform(data)
82 |
83 | class Dataset_Custom(Dataset):
84 | def __init__(self, root_path, flag='train', size=None, data_path='ETTh1.csv',
85 | scale=True, seasonal_patterns=None, drop_short=False):
86 | self.seq_len = size[0]
87 | self.label_len = size[1]
88 | self.pred_len = size[2]
89 | self.token_len = self.seq_len - self.label_len
90 | self.token_num = self.seq_len // self.token_len
91 | self.flag = flag
92 | # init
93 | assert flag in ['train', 'test', 'val']
94 | type_map = {'train': 0, 'val': 1, 'test': 2}
95 | self.set_type = type_map[flag]
96 |
97 | self.scale = scale
98 |
99 | self.root_path = root_path
100 | self.data_path = data_path
101 | self.__read_data__()
102 | self.enc_in = self.data_x.shape[-1]
103 | self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1
104 |
105 | def __read_data__(self):
106 | self.scaler = StandardScaler()
107 | df_raw = pd.read_csv(os.path.join(self.root_path,
108 | self.data_path))
109 | num_train = int(len(df_raw) * 0.7)
110 | num_test = int(len(df_raw) * 0.2)
111 | num_vali = len(df_raw) - num_train - num_test
112 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
113 | border2s = [num_train, num_train + num_vali, len(df_raw)]
114 | border1 = border1s[self.set_type]
115 | border2 = border2s[self.set_type]
116 |
117 |
118 | cols_data = df_raw.columns[1:]
119 | df_data = df_raw[cols_data]
120 |
121 | if self.scale:
122 | train_data = df_data[border1s[0]:border2s[0]]
123 | self.scaler.fit(train_data.values)
124 | data = self.scaler.transform(df_data.values)
125 | else:
126 | data = df_data.values
127 | data_name = self.data_path.split('.')[0]
128 | self.data_stamp = torch.load(os.path.join(self.root_path, f'{data_name}.pt'))
129 | self.data_stamp = self.data_stamp[border1:border2]
130 | self.data_x = data[border1:border2]
131 | self.data_y = data[border1:border2]
132 |
133 |
134 | def __getitem__(self, index):
135 | feat_id = index // self.tot_len
136 | s_begin = index % self.tot_len
137 |
138 | s_end = s_begin + self.seq_len
139 | r_begin = s_end - self.label_len
140 | r_end = r_begin + self.label_len + self.pred_len
141 | seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
142 | seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
143 | seq_x_mark = self.data_stamp[s_begin:s_end:self.token_len]
144 | seq_y_mark = self.data_stamp[s_end:r_end:self.token_len]
145 | return seq_x, seq_y, seq_x_mark, seq_y_mark
146 |
147 | def __len__(self):
148 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) * self.enc_in
149 |
150 | def inverse_transform(self, data):
151 | return self.scaler.inverse_transform(data)
152 |
153 |
154 | class Dataset_Solar(Dataset):
155 | def __init__(self, root_path, flag='train', size=None, data_path='ETTh1.csv',
156 | seasonal_patterns=None, scale=True, drop_short=False):
157 | # size [seq_len, label_len, pred_len]
158 | # info
159 | self.seq_len = size[0]
160 | self.label_len = size[1]
161 | self.pred_len = size[2]
162 |
163 | self.token_len = self.seq_len - self.label_len
164 | self.token_num = self.seq_len // self.token_len
165 | self.flag = flag
166 | # init
167 | assert flag in ['train', 'test', 'val']
168 | type_map = {'train': 0, 'val': 1, 'test': 2}
169 | self.set_type = type_map[flag]
170 |
171 | self.scale = scale
172 |
173 | self.root_path = root_path
174 | self.data_path = data_path
175 | self.__read_data__()
176 | self.enc_in = self.data_x.shape[-1]
177 | self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1
178 |
179 | def __read_data__(self):
180 | self.scaler = StandardScaler()
181 | df_raw = []
182 | with open(os.path.join(self.root_path, self.data_path), "r", encoding='utf-8') as f:
183 | for line in f.readlines():
184 | line = line.strip('\n').split(',')
185 | data_line = np.stack([float(i) for i in line])
186 | df_raw.append(data_line)
187 | df_raw = np.stack(df_raw, 0)
188 | df_raw = pd.DataFrame(df_raw)
189 |
190 | num_train = int(len(df_raw) * 0.7)
191 | num_test = int(len(df_raw) * 0.2)
192 | num_valid = int(len(df_raw) * 0.1)
193 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
194 | border2s = [num_train, num_train + num_valid, len(df_raw)]
195 | border1 = border1s[self.set_type]
196 | border2 = border2s[self.set_type]
197 |
198 | df_data = df_raw.values
199 |
200 | if self.scale:
201 | train_data = df_data[border1s[0]:border2s[0]]
202 | self.scaler.fit(train_data)
203 | data = self.scaler.transform(df_data)
204 | else:
205 | data = df_data
206 |
207 | self.data_x = data[border1:border2]
208 | self.data_y = data[border1:border2]
209 |
210 | def __getitem__(self, index):
211 | feat_id = index // self.tot_len
212 | s_begin = index % self.tot_len
213 |
214 | s_end = s_begin + self.seq_len
215 | r_begin = s_end - self.label_len
216 | r_end = r_begin + self.label_len + self.pred_len
217 | seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1]
218 | seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1]
219 | seq_x_mark = torch.zeros((seq_x.shape[0], 1))
220 | seq_y_mark = torch.zeros((seq_x.shape[0], 1))
221 |
222 | return seq_x, seq_y, seq_x_mark, seq_y_mark
223 |
224 | def __len__(self):
225 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) * self.enc_in
226 |
227 | def inverse_transform(self, data):
228 | return self.scaler.inverse_transform(data)
229 |
230 |
231 | class Dataset_M4(Dataset):
232 | def __init__(self, root_path, flag='pred', size=None, data_path='ETTh1.csv',
233 | scale=False, inverse=False, seasonal_patterns='Yearly', drop_short=False):
234 | self.scale = scale
235 | self.inverse = inverse
236 | self.root_path = root_path
237 |
238 | self.seq_len = size[0]
239 | self.label_len = size[1]
240 | self.pred_len = size[2]
241 |
242 | self.seasonal_patterns = seasonal_patterns
243 | self.history_size = M4Meta.history_size[seasonal_patterns]
244 | self.window_sampling_limit = int(self.history_size * self.pred_len)
245 | self.flag = flag
246 |
247 | self.__read_data__()
248 |
249 | def __read_data__(self):
250 | # M4Dataset.initialize()
251 | if self.flag == 'train':
252 | dataset = M4Dataset.load(training=True, dataset_file=self.root_path)
253 | else:
254 | dataset = M4Dataset.load(training=False, dataset_file=self.root_path)
255 | training_values = np.array(
256 | [v[~np.isnan(v)] for v in
257 | dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies
258 | self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]])
259 | self.timeseries = [ts for ts in training_values]
260 |
261 | def __getitem__(self, index):
262 | insample = np.zeros((self.seq_len, 1))
263 | insample_mask = np.zeros((self.seq_len, 1))
264 | outsample = np.zeros((self.pred_len + self.label_len, 1))
265 | outsample_mask = np.zeros((self.pred_len + self.label_len, 1)) # m4 dataset
266 |
267 | sampled_timeseries = self.timeseries[index]
268 | cut_point = np.random.randint(low=max(1, len(sampled_timeseries) - self.window_sampling_limit),
269 | high=len(sampled_timeseries),
270 | size=1)[0]
271 |
272 | insample_window = sampled_timeseries[max(0, cut_point - self.seq_len):cut_point]
273 | insample[-len(insample_window):, 0] = insample_window
274 | insample_mask[-len(insample_window):, 0] = 1.0
275 | outsample_window = sampled_timeseries[
276 | cut_point - self.label_len:min(len(sampled_timeseries), cut_point + self.pred_len)]
277 | outsample[:len(outsample_window), 0] = outsample_window
278 | outsample_mask[:len(outsample_window), 0] = 1.0
279 | return insample, outsample, insample_mask, outsample_mask
280 |
281 | def __len__(self):
282 | return len(self.timeseries)
283 |
284 | def inverse_transform(self, data):
285 | return self.scaler.inverse_transform(data)
286 |
287 | def last_insample_window(self):
288 | """
289 | The last window of insample size of all timeseries.
290 | This function does not support batching and does not reshuffle timeseries.
291 |
292 | :return: Last insample window of all timeseries. Shape "timeseries, insample size"
293 | """
294 | insample = np.zeros((len(self.timeseries), self.seq_len))
295 | insample_mask = np.zeros((len(self.timeseries), self.seq_len))
296 | for i, ts in enumerate(self.timeseries):
297 | ts_last_window = ts[-self.seq_len:]
298 | insample[i, -len(ts):] = ts_last_window
299 | insample_mask[i, -len(ts):] = 1.0
300 | return insample, insample_mask
301 |
302 |
303 | class Dataset_TSF(Dataset):
304 | def __init__(self, root_path, flag='train', size=None, data_path=None,
305 | scale=True, seasonal_patterns=None, drop_short=False):
306 |
307 | self.seq_len = size[0]
308 | self.label_len = size[1]
309 | self.pred_len = size[2]
310 | self.token_len = self.pred_len
311 | self.context_len = 4 * self.token_len
312 | print(self.seq_len, self.label_len, self.pred_len)
313 | type_map = {'train': 0, 'val': 1, 'test': 2}
314 | self.set_type = type_map[flag]
315 |
316 | self.root_path = root_path
317 | self.data_path = data_path
318 | self.drop_short = drop_short
319 | self.timeseries = self.__read_data__()
320 |
321 |
322 | def __read_data__(self):
323 | df, _, _, _, _ = convert_tsf_to_dataframe(os.path.join(self.root_path, self.data_path))
324 | def dropna(x):
325 | return x[~np.isnan(x)]
326 | timeseries = [dropna(ts).astype(np.float32) for ts in df.series_value]
327 | if self.drop_short:
328 | timeseries = [ts for ts in timeseries if ts.shape[0] > self.context_len]
329 | self.tot_len = 0
330 | self.len_seq = []
331 | self.seq_id = []
332 | for i in range(len(timeseries)):
333 | res_len = max(self.pred_len + self.seq_len - timeseries[i].shape[0], 0)
334 | pad_zeros = np.zeros(res_len)
335 | timeseries[i] = np.hstack([pad_zeros, timeseries[i]])
336 |
337 | _len = timeseries[i].shape[0]
338 | train_len = _len-self.pred_len
339 | border1s = [0, train_len - self.seq_len - self.pred_len, train_len-self.seq_len]
340 | border2s = [train_len - self.pred_len, train_len, _len]
341 |
342 | curr_len = border2s[self.set_type] - max(border1s[self.set_type], 0) - self.pred_len - self.seq_len + 1
343 | curr_len = max(0, curr_len)
344 |
345 | self.len_seq.append(np.zeros(curr_len) + self.tot_len)
346 | self.seq_id.append(np.zeros(curr_len) + i)
347 | self.tot_len += curr_len
348 |
349 | self.len_seq = np.hstack(self.len_seq)
350 | self.seq_id = np.hstack(self.seq_id)
351 |
352 | return timeseries
353 |
354 | def __getitem__(self, index):
355 | len_seq = self.len_seq[index]
356 | seq_id = int(self.seq_id[index])
357 | index = index - int(len_seq)
358 |
359 | _len = self.timeseries[seq_id].shape[0]
360 | train_len = _len - self.pred_len
361 | border1s = [0, train_len - self.seq_len - self.pred_len, train_len-self.seq_len]
362 |
363 | s_begin = index + border1s[self.set_type]
364 | s_end = s_begin + self.seq_len
365 | r_begin = s_end - self.label_len
366 | r_end = s_end + self.pred_len
367 |
368 | data_x = self.timeseries[seq_id][s_begin:s_end]
369 | data_y = self.timeseries[seq_id][r_begin:r_end]
370 | data_x = np.expand_dims(data_x, axis=-1)
371 | data_y = np.expand_dims(data_y, axis=-1)
372 |
373 | return data_x, data_y, data_x, data_y
374 |
375 | def __len__(self):
376 | return self.tot_len
377 |
378 | class Dataset_TSF_ICL(Dataset):
379 | def __init__(self, root_path, flag='train', size=None, data_path=None,
380 | scale=True, seasonal_patterns=None, drop_short=True):
381 |
382 | self.pred_len = size[2]
383 | self.token_len = self.pred_len
384 | self.context_len = 4 * self.token_len
385 |
386 | self.root_path = root_path
387 | self.data_path = data_path
388 | self.timeseries = self.__read_data__()
389 |
390 | def __read_data__(self):
391 | df, _, _, _, _ = convert_tsf_to_dataframe(os.path.join(self.root_path, self.data_path))
392 | def dropna(x):
393 | return x[~np.isnan(x)]
394 | timeseries = [dropna(ts).astype(np.float32) for ts in df.series_value]
395 | timeseries = [ts for ts in timeseries if ts.shape[0] > self.context_len]
396 | return timeseries
397 |
398 | # we uniformly adopt the first time points of the time series as the corresponding prompt.
399 | def __getitem__(self, index):
400 | data_x1 = self.timeseries[index][:2*self.token_len]
401 | data_x2 = self.timeseries[index][-2*self.token_len:-1*self.token_len]
402 | data_x = np.concatenate((data_x1, data_x2))
403 | data_y = self.timeseries[index][-1*self.token_len:]
404 | data_x = np.expand_dims(data_x, axis=-1)
405 | data_y = np.expand_dims(data_y, axis=-1)
406 | return data_x, data_y, data_x, data_y
407 |
408 | def __len__(self):
409 | return len(self.timeseries)
410 |
411 | class Dataset_Preprocess(Dataset):
412 | def __init__(self, root_path, flag='train', size=None,
413 | data_path='ETTh1.csv', scale=True, seasonal_patterns=None):
414 | self.seq_len = size[0]
415 | self.label_len = size[1]
416 | self.pred_len = size[2]
417 | self.token_len = self.seq_len - self.label_len
418 | self.token_num = self.seq_len // self.token_len
419 | self.flag = flag
420 | self.data_set_type = data_path.split('.')[0]
421 | # init
422 | assert flag in ['train', 'test', 'val']
423 | type_map = {'train': 0, 'val': 1, 'test': 2}
424 | self.set_type = type_map[flag]
425 |
426 | self.scale = scale
427 |
428 | self.root_path = root_path
429 | self.data_path = data_path
430 | self.__read_data__()
431 | self.tot_len = len(self.data_stamp)
432 |
433 | def __read_data__(self):
434 | df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
435 | df_stamp = df_raw[['date']]
436 | df_stamp['date'] = pd.to_datetime(df_stamp.date).apply(str)
437 | self.data_stamp = df_stamp['date'].values
438 | self.data_stamp = [str(x) for x in self.data_stamp]
439 |
440 |
441 | def __getitem__(self, index):
442 | s_begin = index % self.tot_len
443 | s_end = s_begin + self.token_len
444 | start = datetime.datetime.strptime(self.data_stamp[s_begin], "%Y-%m-%d %H:%M:%S")
445 | if self.data_set_type in ['traffic', 'electricity', 'ETTh1', 'ETTh2']:
446 | end = (start + datetime.timedelta(hours=self.token_len-1)).strftime("%Y-%m-%d %H:%M:%S")
447 | elif self.data_set_type == 'weather':
448 | end = (start + datetime.timedelta(minutes=10*(self.token_len-1))).strftime("%Y-%m-%d %H:%M:%S")
449 | elif self.data_set_type in ['ETTm1', 'ETTm2']:
450 | end = (start + datetime.timedelta(minutes=15*(self.token_len-1))).strftime("%Y-%m-%d %H:%M:%S")
451 | seq_x_mark = f"This is Time Series from {self.data_stamp[s_begin]} to {end}"
452 | return seq_x_mark
453 |
454 | def __len__(self):
455 | return len(self.data_stamp)
456 |
--------------------------------------------------------------------------------
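To make the prompt construction in `Dataset_TSF_ICL.__getitem__` concrete, the following sketch mirrors its slicing on a toy series (the series values and `token_len` are assumptions):

```python
# Illustrative sketch: how Dataset_TSF_ICL assembles one in-context sample.
import numpy as np

token_len = 4
ts = np.arange(20, dtype=np.float32)          # one target-domain series (toy values)

prompt   = ts[:2 * token_len]                 # first 2*token_len points serve as the time series prompt
lookback = ts[-2 * token_len:-token_len]      # the token_len points right before the horizon
data_x = np.concatenate([prompt, lookback])   # model input: prompt followed by the lookback window
data_y = ts[-token_len:]                      # ground-truth horizon of token_len points

print(data_x)  # -> prompt (0..7) followed by lookback (12..15)
print(data_y)  # -> horizon (16..19)
```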
/data_provider/m4.py:
--------------------------------------------------------------------------------
1 | # This source code is provided for the purposes of scientific reproducibility
2 | # under the following limited license from Element AI Inc. The code is an
3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
4 | # expansion analysis for interpretable time series forecasting,
5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is
6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0
7 | # International license (CC BY-NC 4.0):
8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
9 | # for the benefit of third parties or internally in production) requires an
10 | # explicit license. The subject-matter of the N-BEATS model and associated
11 | # materials are the property of Element AI Inc. and may be subject to patent
12 | # protection. No license to patents is granted hereunder (whether express or
13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved.
14 |
15 | """
16 | M4 Dataset
17 | """
18 | import logging
19 | import os
20 | from collections import OrderedDict
21 | from dataclasses import dataclass
22 | from glob import glob
23 |
24 | import numpy as np
25 | import pandas as pd
26 | import patoolib
27 | from tqdm import tqdm
28 | import logging
29 | import os
30 | import pathlib
31 | import sys
32 | from urllib import request
33 |
34 |
35 | def url_file_name(url: str) -> str:
36 | """
37 | Extract file name from url.
38 |
39 | :param url: URL to extract file name from.
40 | :return: File name.
41 | """
42 | return url.split('/')[-1] if len(url) > 0 else ''
43 |
44 |
45 | def download(url: str, file_path: str) -> None:
46 | """
47 | Download a file to the given path.
48 |
49 | :param url: URL to download
50 | :param file_path: Where to download the content.
51 | """
52 |
53 | def progress(count, block_size, total_size):
54 | progress_pct = float(count * block_size) / float(total_size) * 100.0
55 | sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct))
56 | sys.stdout.flush()
57 |
58 | if not os.path.isfile(file_path):
59 | opener = request.build_opener()
60 | opener.addheaders = [('User-agent', 'Mozilla/5.0')]
61 | request.install_opener(opener)
62 | pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True)
63 | f, _ = request.urlretrieve(url, file_path, progress)
64 | sys.stdout.write('\n')
65 | sys.stdout.flush()
66 | file_info = os.stat(f)
67 | logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.')
68 | else:
69 | file_info = os.stat(file_path)
70 | logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.')
71 |
72 |
73 | @dataclass()
74 | class M4Dataset:
75 | ids: np.ndarray
76 | groups: np.ndarray
77 | frequencies: np.ndarray
78 | horizons: np.ndarray
79 | values: np.ndarray
80 |
81 | @staticmethod
82 | def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset':
83 | """
84 | Load cached dataset.
85 |
86 | :param training: Load training part if training is True, test part otherwise.
87 | """
88 | info_file = os.path.join(dataset_file, 'M4-info.csv')
89 | train_cache_file = os.path.join(dataset_file, 'training.npz')
90 | test_cache_file = os.path.join(dataset_file, 'test.npz')
91 | m4_info = pd.read_csv(info_file)
92 | return M4Dataset(ids=m4_info.M4id.values,
93 | groups=m4_info.SP.values,
94 | frequencies=m4_info.Frequency.values,
95 | horizons=m4_info.Horizon.values,
96 | values=np.load(
97 | train_cache_file if training else test_cache_file,
98 | allow_pickle=True))
99 |
100 |
101 | @dataclass()
102 | class M4Meta:
103 | seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly']
104 | horizons = [6, 8, 18, 13, 14, 48]
105 | frequencies = [1, 4, 12, 1, 1, 24]
106 | horizons_map = {
107 | 'Yearly': 6,
108 | 'Quarterly': 8,
109 | 'Monthly': 18,
110 | 'Weekly': 13,
111 | 'Daily': 14,
112 | 'Hourly': 48
113 | } # different predict length
114 | frequency_map = {
115 | 'Yearly': 1,
116 | 'Quarterly': 4,
117 | 'Monthly': 12,
118 | 'Weekly': 1,
119 | 'Daily': 1,
120 | 'Hourly': 24
121 | }
122 | history_size = {
123 | 'Yearly': 1.5,
124 | 'Quarterly': 1.5,
125 | 'Monthly': 1.5,
126 | 'Weekly': 10,
127 | 'Daily': 10,
128 | 'Hourly': 10
129 | } # from interpretable.gin
130 |
131 |
132 | def load_m4_info(info_file_path: str = os.path.join('../dataset/m4', 'M4-info.csv')) -> pd.DataFrame:
133 | """
134 | Load M4Info file.
135 | :param info_file_path: path to M4-info.csv; the default mirrors M4Dataset.load.
136 | :return: Pandas DataFrame of M4Info.
137 | """
138 | return pd.read_csv(info_file_path)
--------------------------------------------------------------------------------
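A small usage sketch of the helpers above (illustrative, not part of the repository; the dataset path is an assumption), mirroring how `Dataset_M4.__read_data__` consumes them:

```python
# Illustrative sketch: load the cached M4 data and pick one frequency group,
# the same way Dataset_M4.__read_data__ does. The dataset path is an assumption.
import numpy as np
from data_provider.m4 import M4Dataset, M4Meta

dataset = M4Dataset.load(training=True, dataset_file='./dataset/m4')
pattern = 'Monthly'
values = dataset.values[dataset.groups == pattern]   # select series of one seasonal pattern
series = [v[~np.isnan(v)] for v in values]           # strip NaN padding per series
print(len(series), 'series; forecast horizon =', M4Meta.horizons_map[pattern])
```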
/exp/exp_basic.py:
--------------------------------------------------------------------------------
1 | from models import AutoTimes_Llama, AutoTimes_Gpt2, AutoTimes_Opt_1b
2 |
3 |
4 | class Exp_Basic(object):
5 | def __init__(self, args):
6 | self.args = args
7 | self.model_dict = {
8 | 'AutoTimes_Llama': AutoTimes_Llama,
9 | 'AutoTimes_Gpt2': AutoTimes_Gpt2,
10 | 'AutoTimes_Opt_1b': AutoTimes_Opt_1b
11 | }
12 | self.model = self._build_model()
13 |
14 | def _build_model(self):
15 | raise NotImplementedError
16 |
17 | def _get_data(self):
18 | pass
19 |
20 | def vali(self):
21 | pass
22 |
23 | def train(self):
24 | pass
25 |
26 | def test(self):
27 | pass
28 |
--------------------------------------------------------------------------------
/exp/exp_in_context_forecasting.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_factory import data_provider
2 | from data_provider.m4 import M4Meta
3 | from exp.exp_basic import Exp_Basic
4 | from utils.tools import EarlyStopping, adjust_learning_rate
5 | from utils.losses import mape_loss, mase_loss, smape_loss, zero_shot_smape_loss
6 | import torch
7 | import torch.nn as nn
8 | from torch import optim
9 | import os
10 | import time
11 | import warnings
12 | import numpy as np
13 |
14 | # In our in-context learning setting,
15 | # the task is to apply a forecaster, trained on a source dataset, to an unseen target dataset.
16 | # Additionally, several task demonstrations from the target domain,
17 | # referred to as time series prompts, are available during inference.
18 | # Concretely, AutoTimes trains LLMs on the source domain with a larger context length to accommodate the additional time series prompt.
19 | # See ```Dataset_TSF_ICL``` in ```data_loader.py``` for the construction of time series prompts.
20 |
21 | warnings.filterwarnings('ignore')
22 |
23 | def SMAPE(pred, true):
24 | return np.mean(200 * np.abs(pred - true) / (np.abs(pred) + np.abs(true) + 1e-8))
25 | def MAPE(pred, true):
26 | return np.mean(np.abs(100 * (pred - true) / (true +1e-8)))
27 |
28 | class Exp_In_Context_Forecast(Exp_Basic):
29 | def __init__(self, args):
30 | super(Exp_In_Context_Forecast, self).__init__(args)
31 |
32 | def _build_model(self):
33 | if self.args.data == 'm4':
34 | self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns]
35 | self.device = self.args.gpu
36 | model = self.model_dict[self.args.model].Model(self.args).to(self.device)
37 | return model
38 |
39 | def _get_data(self, flag):
40 | data_set, data_loader = data_provider(self.args, flag)
41 | return data_set, data_loader
42 |
43 | def _select_optimizer(self):
44 | p_list = []
45 | for n, p in self.model.named_parameters():
46 | if not p.requires_grad:
47 | continue
48 | else:
49 | p_list.append(p)
50 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
51 | print(n, p.dtype, p.shape)
52 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay)
53 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
54 | print('next learning rate is {}'.format(self.args.learning_rate))
55 | return model_optim
56 |
57 | def _select_criterion(self, loss_name='MSE'):
58 | if loss_name == 'MSE':
59 | return nn.MSELoss()
60 | elif loss_name == 'MAPE':
61 | return mape_loss()
62 | elif loss_name == 'MASE':
63 | return mase_loss()
64 | elif loss_name == 'SMAPE':
65 | return smape_loss()
66 |
67 | def train(self, setting):
68 | train_data, train_loader = self._get_data(flag='train')
69 | vali_data, vali_loader = self._get_data(flag='val')
70 |
71 | self.args.root_path = './dataset/tsf'
72 | self.args.data_path = self.args.test_data_path
73 | self.args.data = 'tsf'
74 | test_data2, test_loader2 = self._get_data(flag='test')
75 |
76 | self.args.data = 'tsf_icl'
77 | test_data3, test_loader3 = self._get_data(flag="test")
78 | path = os.path.join(self.args.checkpoints, setting)
79 | if not os.path.exists(path):
80 | os.makedirs(path)
81 |
82 | time_now = time.time()
83 |
84 | train_steps = len(train_loader)
85 | early_stopping = EarlyStopping(self.args, verbose=True)
86 |
87 | model_optim = self._select_optimizer()
88 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8)
89 | criterion = self._select_criterion(self.args.loss)
90 | if self.args.use_amp:
91 | scaler = torch.cuda.amp.GradScaler()
92 |
93 | for epoch in range(self.args.train_epochs):
94 | iter_count = 0
95 | train_loss = []
96 |
97 | self.model.train()
98 | epoch_time = time.time()
99 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
100 | iter_count += 1
101 | model_optim.zero_grad()
102 | batch_x = batch_x.float().to(self.device)
103 |
104 | batch_y = batch_y.float().to(self.device)
105 | batch_y_mark = batch_y_mark.float().to(self.device)
106 |
107 | if self.args.use_amp:
108 | with torch.cuda.amp.autocast():
109 | outputs = self.model(batch_x, None, None, None)
110 | else:
111 | outputs = self.model(batch_x, None, None, None)
112 |
113 | loss = criterion(batch_x, self.args.frequency_map, outputs, batch_y, batch_y_mark)
114 | train_loss.append(loss.item())
115 |
116 | if (i + 1) % 100 == 0:
117 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
118 | speed = (time.time() - time_now) / iter_count
119 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
120 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
121 | iter_count = 0
122 | time_now = time.time()
123 |
124 | if self.args.use_amp:
125 | scaler.scale(loss).backward()
126 | scaler.step(model_optim)
127 | scaler.update()
128 | else:
129 | loss.backward()
130 | model_optim.step()
131 |
132 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
133 | train_loss = np.average(train_loss)
134 | vali_loss = self.vali(train_loader, vali_loader, criterion) # vali_loss indicates the result on the source dataset
135 | test_loss = vali_loss
136 | test_loss2 = self.vali2(test_data2, test_loader2, zero_shot_smape_loss()) # test_loss2 indicates the result on the target datasets
137 | test_loss3 = self.vali2(test_data3, test_loader3, zero_shot_smape_loss()) # test_loss3 indicates the result on the target datasets with time series prompts
138 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Zero Shot Test Loss: {4:.7f} In Context Test Loss: {5:.7f}".format(
139 | epoch + 1, train_steps, train_loss, vali_loss, test_loss2, test_loss3))
140 | early_stopping(vali_loss, self.model, path)
141 | if early_stopping.early_stop:
142 | print("Early stopping")
143 | break
144 | if self.args.cosine:
145 | scheduler.step()
146 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
147 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr']))
148 | else:
149 | adjust_learning_rate(model_optim, epoch + 1, self.args)
150 |
151 | best_model_path = path + '/' + f'checkpoint.pth'
152 |
153 | self.model.load_state_dict(torch.load(best_model_path), strict=False)
154 |
155 | return self.model
156 |
157 | def vali(self, train_loader, vali_loader, criterion):
158 | x, _ = train_loader.dataset.last_insample_window()
159 | y = vali_loader.dataset.timeseries
160 | x = torch.tensor(x, dtype=torch.float32).to(self.device)
161 | x = x.unsqueeze(-1)
162 |
163 | self.model.eval()
164 | with torch.no_grad():
165 | # decoder input
166 | B, _, C = x.shape
167 |
168 | outputs = torch.zeros((B, self.args.seq_len, C)).float() # .to(self.device)
169 | id_list = np.arange(0, B, 500) # validation set size
170 | id_list = np.append(id_list, B)
171 | if self.args.use_amp:
172 | with torch.cuda.amp.autocast():
173 | for i in range(len(id_list) - 1):
174 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, None, None).detach().cpu()
175 | else:
176 | for i in range(len(id_list) - 1):
177 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, None, None).detach().cpu()
178 | pred = outputs[:, -self.args.token_len:, :]
179 | true = torch.from_numpy(np.array(y))
180 | batch_y_mark = torch.ones(true.shape)
181 | loss = criterion(x.detach().cpu()[:, :, 0], self.args.frequency_map, pred[:, :, 0], true, batch_y_mark)
182 |
183 | self.model.train()
184 | return loss
185 |
186 | def vali2(self, vali_data, vali_loader, criterion):
187 | total_loss = []
188 | count= []
189 | self.model.eval()
190 | with torch.no_grad():
191 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
192 | batch_x = batch_x.float().to(self.device)
193 | batch_y = batch_y.float()
194 |
195 | batch_x_mark = batch_x_mark.float().to(self.device)
196 | batch_y_mark = batch_y_mark.float().to(self.device)
197 |
198 | if self.args.use_amp:
199 | with torch.cuda.amp.autocast():
200 | outputs = self.model(batch_x, None, None, None)
201 | else:
202 | outputs = self.model(batch_x, None, None, None)
203 |
204 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device)
205 |
206 | pred = outputs[:, -self.args.test_pred_len:, :].detach().cpu()
207 | true = batch_y.detach().cpu()
208 |
209 | loss = criterion(pred, true)
210 |
211 | total_loss.append(loss)
212 | count.append(batch_x.shape[0])
213 | total_loss = np.average(total_loss, weights=count)
214 | self.model.train()
215 |
216 | return total_loss
217 |
218 | def test_(self, test_loader):
219 | preds = []
220 | trues = []
221 |
222 | self.model.eval()
223 | with torch.no_grad():
224 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
225 | batch_x = batch_x.float().to(self.device)
226 | batch_y = batch_y.float()
227 |
228 | if self.args.use_amp:
229 | with torch.cuda.amp.autocast():
230 | outputs = self.model(batch_x, None, None, None)
231 | else:
232 | outputs = self.model(batch_x, None, None, None)
233 |
234 | outputs = outputs[:, -self.args.test_pred_len:, :]
235 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device)
236 |
237 | pred = outputs.detach().cpu().numpy()
238 | true = batch_y.detach().cpu().numpy()
239 |
240 | preds.append(pred)
241 | trues.append(true)
242 |
243 | preds = np.concatenate(preds, axis=0)
244 | trues = np.concatenate(trues, axis=0)
245 | print('test shape:', preds.shape, trues.shape)
246 |
247 | smape = SMAPE(preds, trues)
248 | mape = MAPE(preds, trues)
249 | print('mape:{:4f}, smape:{:.4f}'.format(mape, smape))
250 |
251 | def test(self, setting, test=0):
252 | if test:
253 | print('loading model')
254 | setting = self.args.test_dir
255 | best_model_path = self.args.test_file_name
256 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
257 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path)))
258 | self.model.load_state_dict(torch.load(os.path.join(self.args.checkpoints, setting, best_model_path)), strict=False)
259 |
260 | self.args.data_path = self.args.test_data_path
261 |
262 | self.args.root_path = './dataset/tsf'
263 | self.args.data_path = self.args.test_data_path
264 | self.args.data = 'tsf'
265 | test_data, test_loader = self._get_data('test')
266 | self.args.data = 'tsf_icl'
267 | test_data2, test_loader2 = self._get_data('test')
268 |
269 | print("zero shot forecasting")
270 | self.test_(test_loader)
271 | print("in context forecasting")
272 | self.test_(test_loader2)
273 |
274 |
--------------------------------------------------------------------------------
/exp/exp_long_term_forecasting.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_factory import data_provider
2 | from exp.exp_basic import Exp_Basic
3 | from utils.tools import EarlyStopping, adjust_learning_rate, visual
4 | from utils.metrics import metric
5 | import torch
6 | import torch.nn as nn
7 | from torch import optim
8 | import os
9 | import time
10 | import warnings
11 | import numpy as np
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | import torch.distributed as dist
14 |
15 | warnings.filterwarnings('ignore')
16 |
17 |
18 | class Exp_Long_Term_Forecast(Exp_Basic):
19 | def __init__(self, args):
20 | super(Exp_Long_Term_Forecast, self).__init__(args)
21 |
22 | def _build_model(self):
23 | model = self.model_dict[self.args.model].Model(self.args)
24 | if self.args.use_multi_gpu:
25 | self.device = torch.device('cuda:{}'.format(self.args.local_rank))
26 | model = DDP(model.cuda(), device_ids=[self.args.local_rank])
27 | else:
28 | self.device = self.args.gpu
29 | model = model.to(self.device)
30 | return model
31 |
32 | def _get_data(self, flag):
33 | data_set, data_loader = data_provider(self.args, flag)
34 | return data_set, data_loader
35 |
36 | def _select_optimizer(self):
37 | p_list = []
38 | for n, p in self.model.named_parameters():
39 | if not p.requires_grad:
40 | continue
41 | else:
42 | p_list.append(p)
43 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
44 | print(n, p.dtype, p.shape)
45 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay)
46 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
47 | print('next learning rate is {}'.format(self.args.learning_rate))
48 | return model_optim
49 |
50 | def _select_criterion(self):
51 | criterion = nn.MSELoss()
52 | return criterion
53 |
54 | def vali(self, vali_data, vali_loader, criterion, is_test=False):
55 | total_loss = []
56 | total_count = []
57 | time_now = time.time()
58 | test_steps = len(vali_loader)
59 | iter_count = 0
60 | self.model.eval()
61 | with torch.no_grad():
62 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
63 | iter_count += 1
64 | batch_x = batch_x.float().to(self.device)
65 | batch_y = batch_y.float()
66 | batch_x_mark = batch_x_mark.float().to(self.device)
67 | batch_y_mark = batch_y_mark.float().to(self.device)
68 |
69 | if self.args.use_amp:
70 | with torch.cuda.amp.autocast():
71 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
72 | else:
73 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
74 | if is_test:
75 | outputs = outputs[:, -self.args.token_len:, :]
76 | batch_y = batch_y[:, -self.args.token_len:, :].to(self.device)
77 | else:
78 | outputs = outputs[:, :, :]
79 | batch_y = batch_y[:, :, :].to(self.device)
80 |
81 | loss = criterion(outputs, batch_y)
82 |
83 | loss = loss.detach().cpu()
84 | total_loss.append(loss)
85 | total_count.append(batch_x.shape[0])
86 | if (i + 1) % 100 == 0:
87 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
88 | speed = (time.time() - time_now) / iter_count
89 | left_time = speed * (test_steps - i)
90 | print("\titers: {}, speed: {:.4f}s/iter, left time: {:.4f}s".format(i + 1, speed, left_time))
91 | iter_count = 0
92 | time_now = time.time()
93 | if self.args.use_multi_gpu:
94 | total_loss = torch.tensor(np.average(total_loss, weights=total_count)).to(self.device)
95 | dist.barrier()
96 | dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
97 | total_loss = total_loss.item() / dist.get_world_size()
98 | else:
99 | total_loss = np.average(total_loss, weights=total_count)
100 | self.model.train()
101 | return total_loss
102 |
103 | def train(self, setting):
104 | train_data, train_loader = self._get_data(flag='train')
105 | vali_data, vali_loader = self._get_data(flag='val')
106 | test_data, test_loader = self._get_data(flag='test')
107 |
108 | path = os.path.join(self.args.checkpoints, setting)
109 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
110 | if not os.path.exists(path):
111 | os.makedirs(path)
112 |
113 | time_now = time.time()
114 |
115 | train_steps = len(train_loader)
116 | early_stopping = EarlyStopping(self.args, verbose=True)
117 |
118 | model_optim = self._select_optimizer()
119 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8)
120 | criterion = self._select_criterion()
121 | if self.args.use_amp:
122 | scaler = torch.cuda.amp.GradScaler()
123 |
124 | for epoch in range(self.args.train_epochs):
125 | iter_count = 0
126 |
127 | loss_val = torch.tensor(0., device="cuda")
128 | count = torch.tensor(0., device="cuda")
129 |
130 | self.model.train()
131 | epoch_time = time.time()
132 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
133 | iter_count += 1
134 | model_optim.zero_grad()
135 | batch_x = batch_x.float().to(self.device)
136 | batch_y = batch_y.float().to(self.device)
137 | batch_x_mark = batch_x_mark.float().to(self.device)
138 | batch_y_mark = batch_y_mark.float().to(self.device)
139 |
140 | if self.args.use_amp:
141 | with torch.cuda.amp.autocast():
142 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
143 | loss = criterion(outputs, batch_y)
144 | loss_val += loss
145 | count += 1
146 | else:
147 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
148 | loss = criterion(outputs, batch_y)
149 | loss_val += loss
150 | count += 1
151 |
152 | if (i + 1) % 100 == 0:
153 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
154 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
155 | speed = (time.time() - time_now) / iter_count
156 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
157 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
158 | iter_count = 0
159 | time_now = time.time()
160 |
161 | if self.args.use_amp:
162 | scaler.scale(loss).backward()
163 | scaler.step(model_optim)
164 | scaler.update()
165 | else:
166 | loss.backward()
167 | model_optim.step()
168 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
169 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
170 | if self.args.use_multi_gpu:
171 | dist.barrier()
172 | dist.all_reduce(loss_val, op=dist.ReduceOp.SUM)
173 | dist.all_reduce(count, op=dist.ReduceOp.SUM)
174 | train_loss = loss_val.item() / count.item()
175 |
176 | vali_loss = self.vali(vali_data, vali_loader, criterion)
177 | test_loss = self.vali(test_data, test_loader, criterion, is_test=True)
178 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
179 | print("Epoch: {}, Steps: {} | Train Loss: {:.7f} Vali Loss: {:.7f} Test Loss: {:.7f}".format(
180 | epoch + 1, train_steps, train_loss, vali_loss, test_loss))
181 | early_stopping(vali_loss, self.model, path)
182 | if early_stopping.early_stop:
183 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
184 | print("Early stopping")
185 | break
186 | if self.args.cosine:
187 | scheduler.step()
188 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
189 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr']))
190 | else:
191 | adjust_learning_rate(model_optim, epoch + 1, self.args)
192 | if self.args.use_multi_gpu:
193 | train_loader.sampler.set_epoch(epoch + 1)
194 |
195 | best_model_path = path + '/' + 'checkpoint.pth'
196 | if self.args.use_multi_gpu:
197 | dist.barrier()
198 | self.model.load_state_dict(torch.load(best_model_path), strict=False)
199 | else:
200 | self.model.load_state_dict(torch.load(best_model_path), strict=False)
201 | return self.model
202 |
203 | def test(self, setting, test=0):
204 | test_data, test_loader = self._get_data(flag='test')
205 |
206 | print("info:", self.args.test_seq_len, self.args.test_label_len, self.args.token_len, self.args.test_pred_len)
207 | if test:
208 | print('loading model')
209 | setting = self.args.test_dir
210 | best_model_path = self.args.test_file_name
211 |
212 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path)))
213 | load_item = torch.load(os.path.join(self.args.checkpoints, setting, best_model_path))
214 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in load_item.items()}, strict=False)
215 |
216 | preds = []
217 | trues = []
218 | folder_path = './test_results/' + setting + '/'
219 | if not os.path.exists(folder_path):
220 | os.makedirs(folder_path)
221 | time_now = time.time()
222 | test_steps = len(test_loader)
223 | iter_count = 0
224 | self.model.eval()
225 | with torch.no_grad():
226 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
227 | iter_count += 1
228 | batch_x = batch_x.float().to(self.device)
229 | batch_y = batch_y.float().to(self.device)
230 | batch_x_mark = batch_x_mark.float().to(self.device)
231 | batch_y_mark = batch_y_mark.float().to(self.device)
232 |
233 | inference_steps = self.args.test_pred_len // self.args.token_len
234 | dis = self.args.test_pred_len - inference_steps * self.args.token_len
235 | if dis != 0:
236 | inference_steps += 1
237 | pred_y = []
238 | for j in range(inference_steps):
239 | if len(pred_y) != 0:
240 | batch_x = torch.cat([batch_x[:, self.args.token_len:, :], pred_y[-1]], dim=1)
241 | tmp = batch_y_mark[:, j-1:j, :]
242 | batch_x_mark = torch.cat([batch_x_mark[:, 1:, :], tmp], dim=1)
243 |
244 | if self.args.use_amp:
245 | with torch.cuda.amp.autocast():
246 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
247 | else:
248 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark)
249 | pred_y.append(outputs[:, -self.args.token_len:, :])
250 | pred_y = torch.cat(pred_y, dim=1)
251 | if dis != 0:
252 | pred_y = pred_y[:, :-(self.args.token_len - dis), :]
253 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device)
254 | outputs = pred_y.detach().cpu()
255 | batch_y = batch_y.detach().cpu()
256 |
257 | pred = outputs
258 | true = batch_y
259 |
260 | preds.append(pred)
261 | trues.append(true)
262 | if (i + 1) % 100 == 0:
263 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu:
264 | speed = (time.time() - time_now) / iter_count
265 | left_time = speed * (test_steps - i)
266 | print("\titers: {}, speed: {:.4f}s/iter, left time: {:.4f}s".format(i + 1, speed, left_time))
267 | iter_count = 0
268 | time_now = time.time()
269 |
270 | if self.args.visualize and i == 0:
271 | gt = np.array(true[0, :, -1])
272 | pd = np.array(pred[0, :, -1])
273 | lookback = batch_x[0, :, -1].detach().cpu().numpy()
274 | gt = np.concatenate([lookback, gt], axis=0)
275 | pd = np.concatenate([lookback, pd], axis=0)
276 | dir_path = folder_path + f'{self.args.test_pred_len}/'
277 | if not os.path.exists(dir_path):
278 | os.makedirs(dir_path)
279 | visual(gt, pd, os.path.join(dir_path, f'{i}.png'))
280 |
281 | preds = torch.cat(preds, dim=0).numpy()
282 | trues = torch.cat(trues, dim=0).numpy()
283 |
284 | mae, mse, rmse, mape, mspe = metric(preds, trues)
285 | print('mse:{}, mae:{}'.format(mse, mae))
286 | f = open("result_long_term_forecast.txt", 'a')
287 | f.write(setting + " \n")
288 | f.write('mse:{}, mae:{}'.format(mse, mae))
289 | f.write('\n')
290 | f.write('\n')
291 | f.close()
292 | return
293 |
--------------------------------------------------------------------------------
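The test loop above performs a segment-wise autoregressive rollout: each pass predicts the next token_len points, which are appended to the sliding context until test_pred_len points are covered, and any overshoot is trimmed when test_pred_len is not a multiple of token_len. A minimal standalone sketch of that rollout (the lengths and the echoing dummy model are illustrative only, not taken from the repository):

import torch

token_len, test_pred_len, seq_len = 96, 250, 672    # illustrative values only

def dummy_model(x):
    # stand-in for self.model: a real model forecasts the next token for every input token
    return x

batch_x = torch.randn(4, seq_len, 7)                 # [batch, context length, variables]
inference_steps = test_pred_len // token_len
dis = test_pred_len - inference_steps * token_len
if dis != 0:
    inference_steps += 1                             # one extra step, trimmed below

pred_y = []
for _ in range(inference_steps):
    if pred_y:
        # slide the window: drop the oldest token, append the newest prediction
        batch_x = torch.cat([batch_x[:, token_len:, :], pred_y[-1]], dim=1)
    outputs = dummy_model(batch_x)
    pred_y.append(outputs[:, -token_len:, :])

pred_y = torch.cat(pred_y, dim=1)
if dis != 0:
    pred_y = pred_y[:, :-(token_len - dis), :]       # keep exactly test_pred_len points
assert pred_y.shape[1] == test_pred_len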
/exp/exp_short_term_forecasting.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_factory import data_provider
2 | from data_provider.m4 import M4Meta
3 | from exp.exp_basic import Exp_Basic
4 | from utils.tools import EarlyStopping, adjust_learning_rate, visual
5 | from utils.losses import mape_loss, mase_loss, smape_loss
6 | from utils.m4_summary import M4Summary
7 | import torch
8 | import torch.nn as nn
9 | from torch import optim
10 | import os
11 | import time
12 | import warnings
13 | import numpy as np
14 | import pandas
15 |
16 |
17 | warnings.filterwarnings('ignore')
18 |
19 | class Exp_Short_Term_Forecast(Exp_Basic):
20 | def __init__(self, args):
21 | super(Exp_Short_Term_Forecast, self).__init__(args)
22 |
23 | def _build_model(self):
24 | if self.args.data == 'm4':
25 | self.args.token_len = M4Meta.horizons_map[self.args.seasonal_patterns]  # horizon given by the M4 config
26 | self.args.seq_len = 2 * self.args.token_len # input_len = 2*token_len
27 | self.args.label_len = self.args.token_len
28 | self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns]
29 | self.device = self.args.gpu
30 | model = self.model_dict[self.args.model].Model(self.args).to(self.device)
31 | return model
32 |
33 | def _get_data(self, flag):
34 | data_set, data_loader = data_provider(self.args, flag)
35 | return data_set, data_loader
36 |
37 | def _select_optimizer(self):
38 | p_list = []
39 | for n, p in self.model.named_parameters():
40 | if not p.requires_grad:
41 | continue
42 | else:
43 | p_list.append(p)
44 | print(n, p.dtype, p.shape)
45 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay)
46 | print('next learning rate is {}'.format(self.args.learning_rate))
47 | return model_optim
48 |
49 | def _select_criterion(self, loss_name='MSE'):
50 | if loss_name == 'MSE':
51 | return nn.MSELoss()
52 | elif loss_name == 'MAPE':
53 | return mape_loss()
54 | elif loss_name == 'MASE':
55 | return mase_loss()
56 | elif loss_name == 'SMAPE':
57 | return smape_loss()
58 |
59 | def train(self, setting):
60 | train_data, train_loader = self._get_data(flag='train')
61 | vali_data, vali_loader = self._get_data(flag='val')
62 |
63 | path = os.path.join(self.args.checkpoints, setting)
64 | if not os.path.exists(path):
65 | os.makedirs(path)
66 |
67 | time_now = time.time()
68 |
69 | train_steps = len(train_loader)
70 | early_stopping = EarlyStopping(self.args, verbose=True)
71 |
72 | model_optim = self._select_optimizer()
73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8)
74 | criterion = self._select_criterion(self.args.loss)
75 | if self.args.use_amp:
76 | scaler = torch.cuda.amp.GradScaler()
77 |
78 | for epoch in range(self.args.train_epochs):
79 | iter_count = 0
80 | train_loss = []
81 |
82 | self.model.train()
83 | epoch_time = time.time()
84 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
85 | iter_count += 1
86 | model_optim.zero_grad()
87 | batch_x = batch_x.float().to(self.device)
88 |
89 | batch_y = batch_y.float().to(self.device)
90 | batch_y_mark = batch_y_mark.float().to(self.device)
91 |
92 | if self.args.use_amp:
93 | with torch.cuda.amp.autocast():
94 | outputs = self.model(batch_x, None, None, None)
95 | else:
96 | outputs = self.model(batch_x, None, None, None)
97 |
98 | loss = criterion(batch_x, self.args.frequency_map, outputs, batch_y, batch_y_mark)
99 | train_loss.append(loss.item())
100 |
101 | if (i + 1) % 100 == 0:
102 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
103 | speed = (time.time() - time_now) / iter_count
104 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
105 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
106 | iter_count = 0
107 | time_now = time.time()
108 |
109 | if self.args.use_amp:
110 | scaler.scale(loss).backward()
111 | scaler.step(model_optim)
112 | scaler.update()
113 | else:
114 | loss.backward()
115 | model_optim.step()
116 |
117 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
118 | train_loss = np.average(train_loss)
119 | vali_loss = self.vali(train_loader, vali_loader, criterion)
120 | test_loss = vali_loss
121 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
122 | epoch + 1, train_steps, train_loss, vali_loss, test_loss))
123 | early_stopping(vali_loss, self.model, path)
124 | if early_stopping.early_stop:
125 | print("Early stopping")
126 | break
127 | if self.args.cosine:
128 | scheduler.step()
129 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr']))
130 | else:
131 | adjust_learning_rate(model_optim, epoch + 1, self.args)
132 |
133 | best_model_path = path + '/' + f'checkpoint.pth'
134 | self.model.load_state_dict(torch.load(best_model_path), strict=False)
135 | return self.model
136 |
137 | def vali(self, train_loader, vali_loader, criterion):
138 | x, _ = train_loader.dataset.last_insample_window()
139 | y = vali_loader.dataset.timeseries
140 | x = torch.tensor(x, dtype=torch.float32).to(self.device)
141 | x = x.unsqueeze(-1)
142 |
143 | self.model.eval()
144 | with torch.no_grad():
145 | B, _, C = x.shape
146 |
147 | outputs = torch.zeros((B, self.args.seq_len, C)).float() # .to(self.device)
148 | id_list = np.arange(0, B, 500)  # run inference in chunks of 500 series to bound memory
149 | id_list = np.append(id_list, B)
150 | if self.args.use_amp:
151 | with torch.cuda.amp.autocast():
152 | for i in range(len(id_list) - 1):
153 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None,
154 | None,
155 | None).detach().cpu()
156 | else:
157 | for i in range(len(id_list) - 1):
158 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None,
159 | None,
160 | None).detach().cpu()
161 | pred = outputs[:, -self.args.token_len:, :]
162 | true = torch.from_numpy(np.array(y))
163 | batch_y_mark = torch.ones(true.shape)
164 |
165 | loss = criterion(x.detach().cpu()[:, :, 0], self.args.frequency_map, pred[:, :, 0], true, batch_y_mark)
166 |
167 | self.model.train()
168 | return loss
169 |
170 | def test(self, setting, test=0):
171 | _, train_loader = self._get_data(flag='train')
172 | _, test_loader = self._get_data(flag='test')
173 | x, _ = train_loader.dataset.last_insample_window()
174 | y = test_loader.dataset.timeseries
175 | x = torch.tensor(x, dtype=torch.float32).to(self.device)
176 | x = x.unsqueeze(-1)
177 |
178 | if test:
179 | print('loading model')
180 | setting = self.args.test_dir
181 | best_model_path = self.args.test_file_name
182 |
183 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path)))
184 | load_item = torch.load(os.path.join(self.args.checkpoints, setting, best_model_path))
185 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in load_item.items()}, strict=False)
186 |
187 | folder_path = './test_results/' + setting + '/'
188 | if not os.path.exists(folder_path):
189 | os.makedirs(folder_path)
190 |
191 | self.model.eval()
192 | with torch.no_grad():
193 | B, _, C = x.shape
194 |
195 | outputs = torch.zeros((B, self.args.seq_len, C)).float().to(self.device)
196 | id_list = np.arange(0, B, 1)
197 | id_list = np.append(id_list, B)
198 | if self.args.use_amp:
199 | with torch.cuda.amp.autocast():
200 | for i in range(len(id_list) - 1):
201 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None,
202 | None, None)
203 | else:
204 | for i in range(len(id_list) - 1):
205 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None,
206 | None, None)
207 | outputs = outputs[:, -self.args.token_len:, :]
208 | preds = outputs.detach().cpu().numpy()
209 | trues = y
210 | x = x.detach().cpu().numpy()
211 |
212 | if self.args.visualize and i % 2 == 0:
213 | gt = np.concatenate((x[i, :, 0], trues[i]), axis=0)
214 | pd = np.concatenate((x[i, :, 0], preds[i, :, 0]), axis=0)
215 | visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))
216 |
217 | print('test shape:', preds.shape)
218 |
219 | # result save
220 | folder_path = './m4_results/' + self.args.model + '/'
221 | if not os.path.exists(folder_path):
222 | os.makedirs(folder_path)
223 |
224 | forecasts_df = pandas.DataFrame(preds[:, :, 0], columns=[f'V{i + 1}' for i in range(self.args.token_len)])
225 | forecasts_df.index = test_loader.dataset.ids[:preds.shape[0]]
226 | forecasts_df.index.name = 'id'
227 | forecasts_df.set_index(forecasts_df.columns[0], inplace=True)
228 | forecasts_df.to_csv(folder_path + self.args.seasonal_patterns + '_forecast.csv')
229 |
230 | print(self.args.model)
231 | file_path = './m4_results/' + self.args.model + '/'
232 | if 'Weekly_forecast.csv' in os.listdir(file_path) \
233 | and 'Monthly_forecast.csv' in os.listdir(file_path) \
234 | and 'Yearly_forecast.csv' in os.listdir(file_path) \
235 | and 'Daily_forecast.csv' in os.listdir(file_path) \
236 | and 'Hourly_forecast.csv' in os.listdir(file_path) \
237 | and 'Quarterly_forecast.csv' in os.listdir(file_path):
238 | m4_summary = M4Summary(file_path, self.args.root_path)
239 | # m4_forecast.set_index(m4_winner_forecast.columns[0], inplace=True)
240 | smape_results, owa_results, mape, mase = m4_summary.evaluate()
241 | print('smape:', smape_results)
242 | print('mape:', mape)
243 | print('mase:', mase)
244 | print('owa:', owa_results)
245 | else:
246 | print('After all 6 M4 subsets have been forecast, the averaged metrics can be computed')
247 | return
248 |
--------------------------------------------------------------------------------
/exp/exp_zero_shot_forecasting.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_factory import data_provider
2 | from data_provider.m4 import M4Meta
3 | from exp.exp_basic import Exp_Basic
4 | from utils.tools import EarlyStopping, adjust_learning_rate
5 | from utils.losses import zero_shot_smape_loss
6 | import torch
7 | import torch.nn as nn
8 | from torch import optim
9 | import os
10 | import time
11 | import warnings
12 | import numpy as np
13 |
14 |
15 | warnings.filterwarnings('ignore')
16 |
17 | def SMAPE(pred, true):
18 | return np.mean(200 * np.abs(pred - true) / (np.abs(pred) + np.abs(true) + 1e-8))
19 | def MAPE(pred, true):
20 | return np.mean(np.abs(100 * (pred - true) / (true + 1e-8)))
21 |
22 | class Exp_Zero_Shot_Forecast(Exp_Basic):
23 | def __init__(self, args):
24 | super(Exp_Zero_Shot_Forecast, self).__init__(args)
25 |
26 | def _build_model(self):
27 | if self.args.data == 'tsf':
28 | self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns]
29 | self.device = self.args.gpu
30 | model = self.model_dict[self.args.model].Model(self.args).to(self.device)
31 | return model
32 |
33 | def _get_data(self, flag):
34 | data_set, data_loader = data_provider(self.args, flag)
35 | return data_set, data_loader
36 |
37 | def _select_optimizer(self):
38 | p_list = []
39 | for n, p in self.model.named_parameters():
40 | if not p.requires_grad:
41 | continue
42 | else:
43 | p_list.append(p)
44 | print(n, p.dtype, p.shape)
45 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay)
46 | print('next learning rate is {}'.format(self.args.learning_rate))
47 | return model_optim
48 |
49 | def _select_criterion(self, loss_name='MSE'):
50 | if loss_name == 'MSE':
51 | return nn.MSELoss()
52 | elif loss_name == 'SMAPE':
53 | return zero_shot_smape_loss()
54 |
55 | def train(self, setting):
56 | train_data, train_loader = self._get_data(flag='train')
57 | vali_data, vali_loader = self._get_data(flag='val')
58 | test_data, test_loader = self._get_data(flag='test')
59 |
60 | self.args.data_path = self.args.test_data_path
61 | test_data2, test_loader2 = self._get_data(flag='test')
62 |
63 | path = os.path.join(self.args.checkpoints, setting)
64 | if not os.path.exists(path):
65 | os.makedirs(path)
66 |
67 | time_now = time.time()
68 |
69 | train_steps = len(train_loader)
70 | early_stopping = EarlyStopping(self.args, verbose=True)
71 |
72 | model_optim = self._select_optimizer()
73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8)
74 | criterion = self._select_criterion(self.args.loss)
75 | if self.args.use_amp:
76 | scaler = torch.cuda.amp.GradScaler()
77 |
78 | for epoch in range(self.args.train_epochs):
79 | iter_count = 0
80 | train_loss = []
81 |
82 | self.model.train()
83 | epoch_time = time.time()
84 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
85 | iter_count += 1
86 | model_optim.zero_grad()
87 | batch_x = batch_x.float().to(self.device)
88 | batch_y = batch_y.float().to(self.device)
89 |
90 | if self.args.use_amp:
91 | with torch.cuda.amp.autocast():
92 | outputs = self.model(batch_x, None, None, None)
93 | else:
94 | outputs = self.model(batch_x, None, None, None)
95 |
96 | loss = criterion(outputs, batch_y)
97 | train_loss.append(loss.item())
98 |
99 | if (i + 1) % 100 == 0:
100 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
101 | speed = (time.time() - time_now) / iter_count
102 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
103 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
104 | iter_count = 0
105 | time_now = time.time()
106 |
107 | if self.args.use_amp:
108 | scaler.scale(loss).backward()
109 | scaler.step(model_optim)
110 | scaler.update()
111 | else:
112 | loss.backward()
113 | model_optim.step()
114 |
115 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
116 | train_loss = np.average(train_loss)
117 | vali_loss = self.vali(vali_data, vali_loader, criterion)
118 | test_loss = self.vali2(test_data, test_loader, criterion) # test_loss: performance on the source dataset
119 | test_loss2 = self.vali2(test_data2, test_loader2, criterion) # test_loss2: performance on the target dataset, the quantity of interest in zero-shot forecasting
120 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f} Test Loss2: {5:.7f}".format(
121 | epoch + 1, train_steps, train_loss, vali_loss, test_loss, test_loss2))
122 | early_stopping(vali_loss, self.model, path)
123 | if early_stopping.early_stop:
124 | print("Early stopping")
125 | break
126 | if self.args.cosine:
127 | scheduler.step()
128 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr']))
129 | else:
130 | adjust_learning_rate(model_optim, epoch + 1, self.args)
131 |
132 | best_model_path = path + '/' + f'checkpoint.pth'
133 | self.model.load_state_dict(torch.load(best_model_path), strict=False)
134 |
135 | return self.model
136 |
137 | def vali(self, vali_data, vali_loader, criterion):
138 | total_loss = []
139 | self.model.eval()
140 | with torch.no_grad():
141 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
142 | batch_x = batch_x.float().to(self.device)
143 | batch_y = batch_y.float()
144 |
145 | if self.args.use_amp:
146 | with torch.cuda.amp.autocast():
147 | outputs = self.model(batch_x, None, None, None)
148 | else:
149 | outputs = self.model(batch_x, None, None, None)
150 |
151 | pred = outputs.detach().cpu()
152 | true = batch_y.detach().cpu()
153 |
154 | loss = criterion(pred, true)
155 | total_loss.append(loss)
156 | total_loss = np.average(total_loss)
157 | self.model.train()
158 |
159 | return total_loss
160 |
161 | def vali2(self, vali_data, vali_loader, criterion):
162 | total_loss = []
163 | count = []
164 | self.model.eval()
165 | with torch.no_grad():
166 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
167 | batch_x = batch_x.float().to(self.device)
168 | batch_y = batch_y.float()
169 |
170 | inference_steps = self.args.test_pred_len // self.args.token_len
171 | dis = self.args.test_pred_len - inference_steps * self.args.token_len
172 | if dis != 0:
173 | inference_steps += 1
174 | pred_y = []
175 |
176 | for j in range(inference_steps):
177 | if len(pred_y) != 0:
178 | batch_x = torch.cat([batch_x[:, self.args.token_len:, :], pred_y[-1]], dim=1)
179 | if self.args.use_amp:
180 | with torch.cuda.amp.autocast():
181 | outputs = self.model(batch_x, None, None, None)
182 | else:
183 | outputs = self.model(batch_x, None, None, None)
184 | pred_y.append(outputs[:, -self.args.token_len:, :])
185 | pred_y = torch.cat(pred_y, dim=1)
186 | if dis != 0:
187 | pred_y = pred_y[:, :-(self.args.token_len - dis), :]  # trim the overshoot so exactly test_pred_len steps remain
188 |
189 | outputs = pred_y
190 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device)
191 |
192 | pred = outputs.detach().cpu()
193 | true = batch_y.detach().cpu()
194 |
195 | loss = criterion(pred, true)
196 | total_loss.append(loss)
197 | count.append(batch_x.shape[0])
198 | total_loss = np.average(total_loss, weights=count)
199 | self.model.train()
200 |
201 | return total_loss
202 |
203 | def test(self, setting, test=0):
204 | if test:
205 | print('loading model')
206 | setting = self.args.test_dir
207 | best_model_path = self.args.test_file_name
208 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path)))
209 | self.model.load_state_dict(torch.load(os.path.join(self.args.checkpoints, setting, best_model_path)), strict=False)
210 |
211 | self.args.data_path = self.args.test_data_path
212 | test_data, test_loader = self._get_data('test')
213 |
214 | preds = []
215 | trues = []
216 |
217 | self.model.eval()
218 | with torch.no_grad():
219 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
220 | batch_x = batch_x.float().to(self.device)
221 | batch_y = batch_y.float()
222 |
223 | inference_steps = self.args.test_pred_len // self.args.token_len
224 | dis = self.args.test_pred_len - inference_steps * self.args.token_len
225 | if dis != 0:
226 | inference_steps += 1
227 | pred_y = []
228 |
229 | for j in range(inference_steps):
230 | if len(pred_y) != 0:
231 | batch_x = torch.cat([batch_x[:, self.args.token_len:, :], pred_y[-1]], dim=1)
232 | if self.args.use_amp:
233 | with torch.cuda.amp.autocast():
234 | outputs = self.model(batch_x, None, None, None)
235 | else:
236 | outputs = self.model(batch_x, None, None, None)
237 | pred_y.append(outputs[:, -self.args.token_len:, :])
238 | pred_y = torch.cat(pred_y, dim=1)
239 | if dis != 0:
240 | pred_y = pred_y[:, :-(self.args.token_len - dis), :]  # trim the overshoot so exactly test_pred_len steps remain
241 |
242 | outputs = pred_y
243 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device)
244 |
245 | pred = outputs.detach().cpu().numpy()
246 | true = batch_y.detach().cpu().numpy()
247 |
248 | preds.append(pred)
249 | trues.append(true)
250 |
251 | preds = np.concatenate(preds, axis=0)
252 | trues = np.concatenate(trues, axis=0)
253 | print('test shape:', preds.shape, trues.shape)
254 |
255 | smape = SMAPE(preds, trues)
256 | mape = MAPE(preds, trues)
257 | print('mape:{:.4f}, smape:{:.4f}'.format(mape, smape))
258 |
--------------------------------------------------------------------------------
/figures/ablation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/ablation.png
--------------------------------------------------------------------------------
/figures/ablation_llm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/ablation_llm.png
--------------------------------------------------------------------------------
/figures/adaption_efficiency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/adaption_efficiency.png
--------------------------------------------------------------------------------
/figures/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/comparison.png
--------------------------------------------------------------------------------
/figures/formulation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/formulation.png
--------------------------------------------------------------------------------
/figures/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/icon.png
--------------------------------------------------------------------------------
/figures/illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/illustration.png
--------------------------------------------------------------------------------
/figures/in-context.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/in-context.png
--------------------------------------------------------------------------------
/figures/lora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/lora.png
--------------------------------------------------------------------------------
/figures/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/method.png
--------------------------------------------------------------------------------
/figures/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/motivation.png
--------------------------------------------------------------------------------
/figures/one-for-all_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/one-for-all_results.png
--------------------------------------------------------------------------------
/figures/param.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/param.png
--------------------------------------------------------------------------------
/figures/showcases.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/showcases.png
--------------------------------------------------------------------------------
/figures/subway_icf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/subway_icf.png
--------------------------------------------------------------------------------
/figures/zeroshot_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/zeroshot_results.png
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/layers/__init__.py
--------------------------------------------------------------------------------
/layers/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class MLP(nn.Module):
4 | '''
5 | Multilayer perceptron to encode/decode high-dimensional representations of sequential data
6 | '''
7 | def __init__(self,
8 | f_in,
9 | f_out,
10 | hidden_dim=256,
11 | hidden_layers=2,
12 | dropout=0.1,
13 | activation='tanh'):
14 | super(MLP, self).__init__()
15 | self.f_in = f_in
16 | self.f_out = f_out
17 | self.hidden_dim = hidden_dim
18 | self.hidden_layers = hidden_layers
19 | self.dropout = dropout
20 | if activation == 'relu':
21 | self.activation = nn.ReLU()
22 | elif activation == 'tanh':
23 | self.activation = nn.Tanh()
24 | elif activation == 'gelu':
25 | self.activation = nn.GELU()
26 | else:
27 | raise NotImplementedError
28 |
29 | layers = [nn.Linear(self.f_in, self.hidden_dim),
30 | self.activation, nn.Dropout(self.dropout)]
31 | for i in range(self.hidden_layers-2):
32 | layers += [nn.Linear(self.hidden_dim, self.hidden_dim),
33 | self.activation, nn.Dropout(dropout)]
34 |
35 | layers += [nn.Linear(hidden_dim, f_out)]
36 | self.layers = nn.Sequential(*layers)
37 |
38 | def forward(self, x):
39 | # x: B x S x f_in
40 | # y: B x S x f_out
41 | y = self.layers(x)
42 | return y
43 |
--------------------------------------------------------------------------------
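The MLP above serves as the tokenizer and detokenizer in the model files that follow: the encoder maps each token_len-step segment to the backbone's hidden size, and the decoder maps hidden states back to token_len forecast values. A minimal shape check, with illustrative sizes (96-step tokens and a 768-dimensional backbone, matching the GPT-2 variant):

import torch
from layers.mlp import MLP

encoder = MLP(f_in=96, f_out=768, hidden_dim=256, hidden_layers=2, dropout=0.1, activation='tanh')
decoder = MLP(f_in=768, f_out=96, hidden_dim=256, hidden_layers=2, dropout=0.1, activation='tanh')

segments = torch.randn(56, 7, 96)        # [batch * variables, token_num, token_len]
embeds = encoder(segments)               # -> [56, 7, 768], fed to the frozen LLM
recovered = decoder(embeds)              # -> [56, 7, 96], one forecast token per input token
print(embeds.shape, recovered.shape)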
/models/AutoTimes_Gpt2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers.models.gpt2.modeling_gpt2 import GPT2Model
4 | from layers.mlp import MLP
5 |
6 | class Model(nn.Module):
7 | def __init__(self, configs):
8 | super(Model, self).__init__()
9 | self.token_len = configs.token_len
10 | if configs.use_multi_gpu:
11 | self.device = f"cuda:{configs.local_rank}"
12 | else:
13 | self.device = f"cuda:{configs.gpu}"
14 | print(self.device)
15 |
16 | self.gpt2 = GPT2Model.from_pretrained(configs.llm_ckp_dir)
17 | self.hidden_dim_of_gpt2 = 768
18 | self.mix = configs.mix_embeds
19 |
20 | if self.mix:
21 | self.add_scale = nn.Parameter(torch.ones([]))
22 |
23 | for name, param in self.gpt2.named_parameters():
24 | param.requires_grad = False
25 |
26 | if configs.mlp_hidden_layers == 0:
27 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0):
28 | print("use linear as tokenizer and detokenizer")
29 | self.encoder = nn.Linear(self.token_len, self.hidden_dim_of_gpt2)
30 | self.decoder = nn.Linear(self.hidden_dim_of_gpt2, self.token_len)
31 | else:
32 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0):
33 | print("use mlp as tokenizer and detokenizer")
34 | self.encoder = MLP(self.token_len, self.hidden_dim_of_gpt2,
35 | configs.mlp_hidden_dim, configs.mlp_hidden_layers,
36 | configs.dropout, configs.mlp_activation)
37 | self.decoder = MLP(self.hidden_dim_of_gpt2, self.token_len,
38 | configs.mlp_hidden_dim, configs.mlp_hidden_layers,
39 | configs.dropout, configs.mlp_activation)
40 |
41 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
42 | means = x_enc.mean(1, keepdim=True).detach()
43 | x_enc = x_enc - means
44 | stdev = torch.sqrt(
45 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
46 | x_enc /= stdev
47 |
48 | bs, _, n_vars = x_enc.shape
49 | # x_enc: [bs x nvars x seq_len]
50 | x_enc = x_enc.permute(0, 2, 1)
51 | # x_enc: [bs * nvars x seq_len]
52 | x_enc = x_enc.reshape(x_enc.shape[0] * x_enc.shape[1], -1)
53 | # fold_out: [bs * n_vars x token_num x token_len]
54 | fold_out = x_enc.unfold(dimension=-1, size=self.token_len, step=self.token_len)
55 | token_num = fold_out.shape[1]
56 | # times_embeds: [bs * n_vars x token_num x hidden_dim_of_gpt2]
57 | times_embeds = self.encoder(fold_out)
58 | if self.mix:
59 | times_embeds = times_embeds / times_embeds.norm(dim=2, keepdim=True)
60 | x_mark_enc = x_mark_enc / x_mark_enc.norm(dim=2, keepdim=True)
61 | times_embeds = times_embeds + self.add_scale * x_mark_enc
62 | # outputs: [bs * n_vars x token_num x hidden_dim_of_gpt2]
63 | outputs = self.gpt2(
64 | inputs_embeds=times_embeds).last_hidden_state
65 | # dec_out: [bs * n_vars x token_num x token_len]
66 | dec_out = self.decoder(outputs)
67 | dec_out = dec_out.reshape(bs, n_vars, -1)
68 | # dec_out: [bs x token_num * token_len x n_vars]
69 | dec_out = dec_out.permute(0, 2, 1)
70 |
71 | dec_out = dec_out * \
72 | (stdev[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1))
73 | dec_out = dec_out + \
74 | (means[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1))
75 |
76 | return dec_out
77 |
78 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
79 | return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
--------------------------------------------------------------------------------
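The forecast path above (mirrored by the Llama and OPT variants below) first applies per-series instance normalization and then slices each context into non-overlapping token_len segments with unfold. A short sketch of just that reshaping, using illustrative sizes:

import torch

bs, seq_len, n_vars, token_len = 8, 672, 7, 96
x_enc = torch.randn(bs, seq_len, n_vars)

# normalize each series over its time dimension
means = x_enc.mean(1, keepdim=True)
stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_norm = (x_enc - means) / stdev

# [bs, seq_len, n_vars] -> [bs * n_vars, seq_len] -> [bs * n_vars, token_num, token_len]
flat = x_norm.permute(0, 2, 1).reshape(bs * n_vars, seq_len)
fold_out = flat.unfold(dimension=-1, size=token_len, step=token_len)
print(fold_out.shape)   # torch.Size([56, 7, 96]); 56 = bs * n_vars, 7 tokens of 96 steps

The statistics (means, stdev) are kept so the decoded forecasts can be de-normalized back to the original scale, as done at the end of forecast().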
/models/AutoTimes_Llama.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import LlamaForCausalLM
4 | from layers.mlp import MLP
5 |
6 | class Model(nn.Module):
7 | def __init__(self, configs):
8 | super(Model, self).__init__()
9 | self.token_len = configs.token_len
10 | if configs.use_multi_gpu:
11 | self.device = f"cuda:{configs.local_rank}"
12 | else:
13 | self.device = f"cuda:{configs.gpu}"
14 | print(self.device)
15 |
16 | self.llama = LlamaForCausalLM.from_pretrained(
17 | configs.llm_ckp_dir,
18 | device_map=self.device,
19 | torch_dtype=torch.float16 if configs.use_amp else torch.float32,
20 | )
21 | self.hidden_dim_of_llama = 4096
22 | self.mix = configs.mix_embeds
23 | if self.mix:
24 | self.add_scale = nn.Parameter(torch.ones([]))
25 |
26 | for name, param in self.llama.named_parameters():
27 | param.requires_grad = False
28 |
29 | if configs.mlp_hidden_layers == 0:
30 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0):
31 | print("use linear as tokenizer and detokenizer")
32 | self.encoder = nn.Linear(self.token_len, self.hidden_dim_of_llama)
33 | self.decoder = nn.Linear(self.hidden_dim_of_llama, self.token_len)
34 | else:
35 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0):
36 | print("use mlp as tokenizer and detokenizer")
37 | self.encoder = MLP(self.token_len, self.hidden_dim_of_llama,
38 | configs.mlp_hidden_dim, configs.mlp_hidden_layers,
39 | configs.dropout, configs.mlp_activation)
40 | self.decoder = MLP(self.hidden_dim_of_llama, self.token_len,
41 | configs.mlp_hidden_dim, configs.mlp_hidden_layers,
42 | configs.dropout, configs.mlp_activation)
43 |
44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
45 | means = x_enc.mean(1, keepdim=True).detach()
46 | x_enc = x_enc - means
47 | stdev = torch.sqrt(
48 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
49 | x_enc /= stdev
50 |
51 | bs, _, n_vars = x_enc.shape
52 | # x_enc: [bs x nvars x seq_len]
53 | x_enc = x_enc.permute(0, 2, 1)
54 | # x_enc: [bs * nvars x seq_len]
55 | x_enc = x_enc.reshape(x_enc.shape[0] * x_enc.shape[1], -1)
56 | # fold_out: [bs * n_vars x token_num x token_len]
57 | fold_out = x_enc.unfold(dimension=-1, size=self.token_len, step=self.token_len)
58 | token_num = fold_out.shape[1]
59 | # times_embeds: [bs * n_vars x token_num x hidden_dim_of_llama]
60 | times_embeds = self.encoder(fold_out)
61 | if self.mix:
62 | times_embeds = times_embeds / times_embeds.norm(dim=2, keepdim=True)
63 | x_mark_enc = x_mark_enc / x_mark_enc.norm(dim=2, keepdim=True)
64 | times_embeds = times_embeds + self.add_scale * x_mark_enc
65 | # outputs: [bs * n_vars x token_num x hidden_dim_of_llama]
66 | outputs = self.llama.model(
67 | inputs_embeds=times_embeds)[0]
68 | # dec_out: [bs * n_vars x token_num x token_len]
69 | dec_out = self.decoder(outputs)
70 | dec_out = dec_out.reshape(bs, n_vars, -1)
71 | # dec_out: [bs x token_num * token_len x n_vars]
72 | dec_out = dec_out.permute(0, 2, 1)
73 |
74 | dec_out = dec_out * \
75 | (stdev[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1))
76 | dec_out = dec_out + \
77 | (means[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1))
78 |
79 | return dec_out
80 |
81 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
82 | return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
--------------------------------------------------------------------------------
/models/AutoTimes_Opt_1b.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import OPTForCausalLM
4 | from layers.mlp import MLP
5 |
6 | class Model(nn.Module):
7 | def __init__(self, configs):
8 | super(Model, self).__init__()
9 | self.token_len = configs.token_len
10 | if configs.use_multi_gpu:
11 | self.device = f"cuda:{configs.local_rank}"
12 | else:
13 | self.device = f"cuda:{configs.gpu}"
14 | print(self.device)
15 |
16 | self.opt = OPTForCausalLM.from_pretrained(configs.llm_ckp_dir, torch_dtype=torch.float16)
17 | self.opt.model.decoder.project_in = None
18 | self.opt.model.decoder.project_out = None
19 |
20 | self.hidden_dim_of_opt1b = 2048
21 | self.mix = configs.mix_embeds
22 |
23 | if self.mix:
24 | self.add_scale = nn.Parameter(torch.ones([]))
25 |
26 | for name, param in self.opt.named_parameters():
27 | param.requires_grad = False
28 |
29 | if configs.mlp_hidden_layers == 0:
30 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0):
31 | print("use linear as tokenizer and detokenizer")
32 | self.encoder = nn.Linear(self.token_len, self.hidden_dim_of_opt1b)
33 | self.decoder = nn.Linear(self.hidden_dim_of_opt1b, self.token_len)
34 | else:
35 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0):
36 | print("use mlp as tokenizer and detokenizer")
37 | self.encoder = MLP(self.token_len, self.hidden_dim_of_opt1b,
38 | configs.mlp_hidden_dim, configs.mlp_hidden_layers,
39 | configs.dropout, configs.mlp_activation)
40 | self.decoder = MLP(self.hidden_dim_of_opt1b, self.token_len,
41 | configs.mlp_hidden_dim, configs.mlp_hidden_layers,
42 | configs.dropout, configs.mlp_activation)
43 |
44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
45 | means = x_enc.mean(1, keepdim=True).detach()
46 | x_enc = x_enc - means
47 | stdev = torch.sqrt(
48 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
49 | x_enc /= stdev
50 |
51 | bs, _, n_vars = x_enc.shape
52 | # x_enc: [bs x nvars x seq_len]
53 | x_enc = x_enc.permute(0, 2, 1)
54 | # x_enc: [bs * nvars x seq_len]
55 | x_enc = x_enc.reshape(x_enc.shape[0] * x_enc.shape[1], -1)
56 | # fold_out: [bs * n_vars x token_num x token_len]
57 | fold_out = x_enc.unfold(dimension=-1, size=self.token_len, step=self.token_len)
58 | token_num = fold_out.shape[1]
59 | # times_embeds: [bs * n_vars x token_num x hidden_dim_of_opt1b]
60 | times_embeds = self.encoder(fold_out)
61 | if self.mix:
62 | times_embeds = times_embeds / times_embeds.norm(dim=2, keepdim=True)
63 | x_mark_enc = x_mark_enc / x_mark_enc.norm(dim=2, keepdim=True)
64 | times_embeds = times_embeds + self.add_scale * x_mark_enc
65 | # outputs: [bs * n_vars x token_num x hidden_dim_of_opt1b]
66 | outputs = self.opt.model(
67 | inputs_embeds=times_embeds).last_hidden_state
68 | # dec_out: [bs * n_vars x token_num x token_len]
69 | dec_out = self.decoder(outputs)
70 | dec_out = dec_out.reshape(bs, n_vars, -1)
71 | # dec_out: [bs x token_num * token_len x n_vars]
72 | dec_out = dec_out.permute(0, 2, 1)
73 |
74 | dec_out = dec_out * \
75 | (stdev[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1))
76 | dec_out = dec_out + \
77 | (means[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1))
78 |
79 | return dec_out
80 |
81 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
82 | return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
--------------------------------------------------------------------------------
/models/Preprocess_Llama.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import (
4 | LlamaForCausalLM,
5 | LlamaTokenizer,
6 | )
7 |
8 | class Model(nn.Module):
9 | def __init__(self, configs):
10 | super(Model, self).__init__()
11 | self.device = configs.gpu
12 | print(self.device)
13 |
14 | self.llama = LlamaForCausalLM.from_pretrained(
15 | configs.llm_ckp_dir,
16 | device_map=self.device,
17 | torch_dtype=torch.float16,
18 | )
19 | self.llama_tokenizer = LlamaTokenizer.from_pretrained(configs.llm_ckp_dir)
20 | self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
21 | self.vocab_size = self.llama_tokenizer.vocab_size
22 | self.hidden_dim_of_llama = 4096
23 |
24 | for name, param in self.llama.named_parameters():
25 | param.requires_grad = False
26 |
27 | def tokenizer(self, x):
28 | output = self.llama_tokenizer(x, return_tensors="pt")['input_ids'].to(self.device)
29 | result = self.llama.get_input_embeddings()(output)
30 | return result
31 |
32 | def forecast(self, x_mark_enc):
33 | # x_mark_enc: [bs x T x hidden_dim_of_llama]
34 | x_mark_enc = torch.cat([self.tokenizer(x_mark_enc[i]) for i in range(len(x_mark_enc))], 0)
35 | text_outputs = self.llama.model(inputs_embeds=x_mark_enc)[0]
36 | text_outputs = text_outputs[:, -1, :]
37 | return text_outputs
38 |
39 | def forward(self, x_mark_enc):
40 | return self.forecast(x_mark_enc)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/models/__init__.py
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from models.Preprocess_Llama import Model
4 |
5 | from data_provider.data_loader import Dataset_Preprocess
6 | from torch.utils.data import DataLoader
7 |
8 | if __name__ == '__main__':
9 | parser = argparse.ArgumentParser(description='AutoTimes Preprocess')
10 | parser.add_argument('--gpu', type=int, default=0, help='gpu id')
11 | parser.add_argument('--llm_ckp_dir', type=str, default='./llama', help='llm checkpoints dir')
12 | parser.add_argument('--dataset', type=str, default='ETTh1',
13 | help='dataset to preprocess, options:[ETTh1, electricity, weather, traffic]')
14 | args = parser.parse_args()
15 | print(args.dataset)
16 |
17 | model = Model(args)
18 |
19 | seq_len = 672
20 | label_len = 576
21 | pred_len = 96
22 |
23 | assert args.dataset in ['ETTh1', 'electricity', 'weather', 'traffic']
24 | if args.dataset == 'ETTh1':
25 | data_set = Dataset_Preprocess(
26 | root_path='./dataset/ETT-small/',
27 | data_path='ETTh1.csv',
28 | size=[seq_len, label_len, pred_len])
29 | elif args.dataset == 'electricity':
30 | data_set = Dataset_Preprocess(
31 | root_path='./dataset/electricity/',
32 | data_path='electricity.csv',
33 | size=[seq_len, label_len, pred_len])
34 | elif args.dataset == 'weather':
35 | data_set = Dataset_Preprocess(
36 | root_path='./dataset/weather/',
37 | data_path='weather.csv',
38 | size=[seq_len, label_len, pred_len])
39 | elif args.dataset == 'traffic':
40 | data_set = Dataset_Preprocess(
41 | root_path='./dataset/traffic/',
42 | data_path='traffic.csv',
43 | size=[seq_len, label_len, pred_len])
44 |
45 | data_loader = DataLoader(
46 | data_set,
47 | batch_size=128,
48 | shuffle=False,
49 | )
50 |
51 | from tqdm import tqdm
52 | print(len(data_set.data_stamp))
53 | print(data_set.tot_len)
54 | save_dir_path = './dataset/'
55 | output_list = []
56 | for idx, data in tqdm(enumerate(data_loader)):
57 | output = model(data)
58 | output_list.append(output.detach().cpu())
59 | result = torch.cat(output_list, dim=0)
60 | print(result.shape)
61 | torch.save(result, save_dir_path + f'/{args.dataset}.pt')
62 |
--------------------------------------------------------------------------------
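preprocess.py stores one LLaMA embedding per timestamp position in ./dataset/<dataset>.pt; these are the position-wise text embeddings that the models can mix into the series tokens when --mix_embeds is set (the alignment itself lives in data_provider, which is not shown in this section). A quick sanity check of the saved file, assuming ETTh1 has already been preprocessed:

import torch

emb = torch.load('./dataset/ETTh1.pt')   # produced by: python preprocess.py --dataset ETTh1 --gpu 0
print(emb.shape, emb.dtype)              # expected [num_positions, 4096]: one LLaMA hidden vector per timestamp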
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.4.0
2 | matplotlib==3.7.0
3 | numpy==1.23.5
4 | pandas==1.5.3
5 | patool==1.12
6 | reformer-pytorch==1.4.4
7 | scikit-learn==1.2.2
8 | scipy==1.10.1
9 | sktime==0.16.1
10 | sympy==1.11.1
11 | tqdm==4.64.1
12 | torch==2.0.1
13 | transformers==4.35.2
14 | accelerate
15 | sentencepiece
16 | protobuf
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast
8 | from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast
9 | from exp.exp_zero_shot_forecasting import Exp_Zero_Shot_Forecast
10 | from exp.exp_in_context_forecasting import Exp_In_Context_Forecast
11 |
12 | if __name__ == '__main__':
13 | fix_seed = 2021
14 | random.seed(fix_seed)
15 | torch.manual_seed(fix_seed)
16 | np.random.seed(fix_seed)
17 |
18 | parser = argparse.ArgumentParser(description='AutoTimes')
19 |
20 | # basic config
21 | parser.add_argument('--task_name', type=str, required=True, default='long_term_forecast',
22 | help='task name, options:[long_term_forecast, short_term_forecast, zero_shot_forecast, in_context_forecast]')
23 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
24 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
25 | parser.add_argument('--model', type=str, required=True, default='AutoTimes_Llama',
26 | help='model name, options: [AutoTimes_Llama, AutoTimes_Gpt2, AutoTimes_Opt_1b]')
27 |
28 | # data loader
29 | parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
30 | parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
31 | parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
32 | parser.add_argument('--test_data_path', type=str, default='ETTh1.csv', help='test data file used in zero shot forecasting')
33 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
34 | parser.add_argument('--drop_last', action='store_true', default=False, help='drop last batch in data loader')
35 | parser.add_argument('--val_set_shuffle', action='store_false', default=True, help='shuffle validation set')
36 | parser.add_argument('--drop_short', action='store_true', default=False, help='drop too short sequences in dataset')
37 |
38 | # forecasting task
39 | parser.add_argument('--seq_len', type=int, default=672, help='input sequence length')
40 | parser.add_argument('--label_len', type=int, default=576, help='label length')
41 | parser.add_argument('--token_len', type=int, default=96, help='token length')
42 | parser.add_argument('--test_seq_len', type=int, default=672, help='test seq len')
43 | parser.add_argument('--test_label_len', type=int, default=576, help='test label len')
44 | parser.add_argument('--test_pred_len', type=int, default=96, help='test pred len')
45 | parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')
46 |
47 | # model define
48 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
49 | parser.add_argument('--llm_ckp_dir', type=str, default='./llama', help='llm checkpoints dir')
50 | parser.add_argument('--mlp_hidden_dim', type=int, default=256, help='mlp hidden dim')
51 | parser.add_argument('--mlp_hidden_layers', type=int, default=2, help='mlp hidden layers')
52 | parser.add_argument('--mlp_activation', type=str, default='tanh', help='mlp activation')
53 |
54 | # optimization
55 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
56 | parser.add_argument('--itr', type=int, default=1, help='experiments times')
57 | parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
58 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
59 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
60 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
61 | parser.add_argument('--des', type=str, default='test', help='exp description')
62 | parser.add_argument('--loss', type=str, default='MSE', help='loss function')
63 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
64 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
65 | parser.add_argument('--cosine', action='store_true', help='use cosine annealing lr', default=False)
66 | parser.add_argument('--tmax', type=int, default=10, help='tmax in cosine annealing lr')
67 | parser.add_argument('--weight_decay', type=float, default=0)
68 | parser.add_argument('--mix_embeds', action='store_true', help='mix embeds', default=False)
69 | parser.add_argument('--test_dir', type=str, default='./test', help='test dir')
70 | parser.add_argument('--test_file_name', type=str, default='checkpoint.pth', help='test file')
71 |
72 | # GPU
73 | parser.add_argument('--gpu', type=int, default=0, help='gpu')
74 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
75 | parser.add_argument('--visualize', action='store_true', help='visualize', default=False)
76 | args = parser.parse_args()
77 |
78 | if args.use_multi_gpu:
79 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
80 | port = os.environ.get("MASTER_PORT", "64209")
81 | hosts = int(os.environ.get("WORLD_SIZE", "8"))
82 | rank = int(os.environ.get("RANK", "0"))
83 | local_rank = int(os.environ.get("LOCAL_RANK", "0"))
84 | gpus = torch.cuda.device_count()
85 | args.local_rank = local_rank
86 | print(ip, port, hosts, rank, local_rank, gpus)
87 | dist.init_process_group(backend="nccl", init_method=f"tcp://{ip}:{port}", world_size=hosts,
88 | rank=rank)
89 | torch.cuda.set_device(local_rank)
90 |
91 | if args.task_name == 'long_term_forecast':
92 | Exp = Exp_Long_Term_Forecast
93 | elif args.task_name == 'short_term_forecast':
94 | Exp = Exp_Short_Term_Forecast
95 | elif args.task_name == 'zero_shot_forecast':
96 | Exp = Exp_Zero_Shot_Forecast
97 | elif args.task_name == 'in_context_forecast':
98 | Exp = Exp_In_Context_Forecast
99 | else:
100 | Exp = Exp_Long_Term_Forecast
101 |
102 | if args.is_training:
103 | for ii in range(args.itr):
104 | # setting record of experiments
105 | exp = Exp(args) # set experiments
106 | setting = '{}_{}_{}_{}_sl{}_ll{}_tl{}_lr{}_bt{}_wd{}_hd{}_hl{}_cos{}_mix{}_{}_{}'.format(
107 | args.task_name,
108 | args.model_id,
109 | args.model,
110 | args.data,
111 | args.seq_len,
112 | args.label_len,
113 | args.token_len,
114 | args.learning_rate,
115 | args.batch_size,
116 | args.weight_decay,
117 | args.mlp_hidden_dim,
118 | args.mlp_hidden_layers,
119 | args.cosine,
120 | args.mix_embeds,
121 | args.des, ii)
122 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu:
123 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
124 | exp.train(setting)
125 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu:
126 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
127 | exp.test(setting)
128 | torch.cuda.empty_cache()
129 | else:
130 | ii = 0
131 | setting = '{}_{}_{}_{}_sl{}_ll{}_tl{}_lr{}_bt{}_wd{}_hd{}_hl{}_cos{}_mix{}_{}_{}'.format(
132 | args.task_name,
133 | args.model_id,
134 | args.model,
135 | args.data,
136 | args.seq_len,
137 | args.label_len,
138 | args.token_len,
139 | args.learning_rate,
140 | args.batch_size,
141 | args.weight_decay,
142 | args.mlp_hidden_dim,
143 | args.mlp_hidden_layers,
144 | args.cosine,
145 | args.mix_embeds,
146 | args.des, ii)
147 | exp = Exp(args) # set experiments
148 | exp.test(setting, test=1)
149 | torch.cuda.empty_cache()
150 |
--------------------------------------------------------------------------------
/scripts/in_context_forecasting/M3.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | python -u run.py \
4 | --task_name in_context_forecast \
5 | --is_training 1 \
6 | --root_path ./dataset/m4 \
7 | --test_data_path m3_yearly_dataset.tsf \
8 | --seasonal_patterns 'Yearly' \
9 | --model_id m4_Yearly \
10 | --model $model_name \
11 | --data m4 \
12 | --seq_len 18 \
13 | --label_len 12 \
14 | --token_len 6 \
15 | --test_seq_len 6 \
16 | --test_label_len 0 \
17 | --test_pred_len 6 \
18 | --batch_size 16 \
19 | --des 'Exp' \
20 | --itr 1 \
21 | --learning_rate 0.0005 \
22 | --loss 'SMAPE' \
23 | --use_amp \
24 | --mlp_hidden_dim 128 \
25 | --cosine \
26 | --tmax 10 \
27 | --drop_short
28 |
29 | python -u run.py \
30 | --task_name in_context_forecast \
31 | --is_training 1 \
32 | --root_path ./dataset/m4 \
33 | --test_data_path m3_quarterly_dataset.tsf \
34 | --seasonal_patterns 'Quarterly' \
35 | --model_id m4_Quarterly \
36 | --model $model_name \
37 | --data m4 \
38 | --seq_len 24 \
39 | --label_len 16 \
40 | --token_len 8 \
41 | --test_seq_len 8 \
42 | --test_label_len 0 \
43 | --test_pred_len 8 \
44 | --batch_size 16 \
45 | --des 'Exp' \
46 | --itr 1 \
47 | --learning_rate 0.0001 \
48 | --loss 'SMAPE' \
49 | --use_amp \
50 | --mlp_hidden_dim 1024 \
51 | --cosine \
52 | --tmax 10 \
53 | --drop_short
54 |
55 | python -u run.py \
56 | --task_name in_context_forecast \
57 | --is_training 1 \
58 | --root_path ./dataset/m4 \
59 | --test_data_path m3_monthly_dataset.tsf \
60 | --seasonal_patterns 'Monthly' \
61 | --model_id m4_Monthly \
62 | --model $model_name \
63 | --data m4 \
64 | --seq_len 54 \
65 | --label_len 36 \
66 | --token_len 18 \
67 | --test_seq_len 18 \
68 | --test_label_len 0 \
69 | --test_pred_len 18 \
70 | --batch_size 16 \
71 | --des 'Exp' \
72 | --itr 1 \
73 | --learning_rate 0.0005 \
74 | --loss 'SMAPE' \
75 | --use_amp \
76 | --mlp_hidden_dim 256 \
77 | --cosine \
78 | --tmax 10 \
79 | --drop_short
80 |
81 | python -u run.py \
82 | --task_name in_context_forecast \
83 | --is_training 1 \
84 | --root_path ./dataset/m4 \
85 | --test_data_path m3_other_dataset.tsf \
86 | --seasonal_patterns 'Quarterly' \
87 | --model_id m4_Quarterly \
88 | --model $model_name \
89 | --data m4 \
90 | --seq_len 24 \
91 | --label_len 16 \
92 | --token_len 8 \
93 | --test_seq_len 8 \
94 | --test_label_len 0 \
95 | --test_pred_len 8 \
96 | --batch_size 16 \
97 | --des 'Exp' \
98 | --itr 1 \
99 | --learning_rate 0.0005 \
100 | --loss 'SMAPE' \
101 | --use_amp \
102 | --mlp_hidden_dim 512 \
103 | --cosine \
104 | --tmax 10 \
105 | --drop_short
--------------------------------------------------------------------------------
/scripts/method_generality/gpt2.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Gpt2
2 |
3 | python -u run.py \
4 | --task_name long_term_forecast \
5 | --is_training 1 \
6 | --root_path ./dataset/ETT-small/ \
7 | --data_path ETTh1.csv \
8 | --model_id ETTh1_672_96 \
9 | --model $model_name \
10 | --data ETTh1 \
11 | --seq_len 672 \
12 | --label_len 576 \
13 | --token_len 96 \
14 | --test_seq_len 672 \
15 | --test_label_len 576 \
16 | --test_pred_len 96 \
17 | --batch_size 2048 \
18 | --learning_rate 0.002 \
19 | --itr 1 \
20 | --train_epochs 10 \
21 | --use_amp \
22 | --llm_ckp_dir ./models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10 \
23 | --gpu 0 \
24 | --des 'Gpt2' \
25 | --cosine \
26 | --tmax 10 \
27 | --mlp_hidden_dim 512
28 |
29 |
30 | for test_pred_len in 96 192 336 720
31 | do
32 | python -u run.py \
33 | --task_name long_term_forecast \
34 | --is_training 0 \
35 | --root_path ./dataset/ETT-small/ \
36 | --data_path ETTh1.csv \
37 | --model_id ETTh1_672_96 \
38 | --model $model_name \
39 | --data ETTh1 \
40 | --seq_len 672 \
41 | --label_len 576 \
42 | --token_len 96 \
43 | --test_seq_len 672 \
44 | --test_label_len 576 \
45 | --test_pred_len $test_pred_len \
46 | --batch_size 2048 \
47 | --learning_rate 0.002 \
48 | --itr 1 \
49 | --train_epochs 10 \
50 | --use_amp \
51 | --llm_ckp_dir ./models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10 \
52 | --gpu 0 \
53 | --des 'Gpt2' \
54 | --cosine \
55 | --tmax 10 \
56 | --mlp_hidden_dim 512 \
57 | --test_dir long_term_forecast_ETTh1_672_96_AutoTimes_Gpt2_ETTh1_sl672_ll576_tl96_lr0.002_bt2048_wd0_hd512_hl2_cosTrue_mixFalse_Gpt2_0
58 | done
--------------------------------------------------------------------------------
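Note on the train-once, test-many pattern used in gpt2.sh and the other long-term scripts: a single model is trained with token_len 96, and the for-loop at the end of the script re-invokes run.py with --is_training 0 for every test_pred_len, pointing --test_dir at the training checkpoint. Longer horizons are presumably covered by generating ceil(test_pred_len / token_len) segment tokens autoregressively. The sketch below only illustrates that idea; rolling_forecast and the callable model are hypothetical stand-ins, not functions from this repository.

import math
import torch

def rolling_forecast(model, context, token_len, pred_len):
    # `model` is assumed to map a (batch, seq_len) context to the next
    # (batch, token_len) segment of the series.
    steps = math.ceil(pred_len / token_len)
    preds = []
    for _ in range(steps):
        next_token = model(context)                              # (batch, token_len)
        preds.append(next_token)
        # slide the context window forward by one segment token
        context = torch.cat([context[:, token_len:], next_token], dim=1)
    return torch.cat(preds, dim=1)[:, :pred_len]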
/scripts/method_generality/opt.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Opt_1b
2 |
3 | python -u run.py \
4 | --task_name long_term_forecast \
5 | --is_training 1 \
6 | --root_path ./dataset/ETT-small/ \
7 | --data_path ETTh1.csv \
8 | --model_id ETTh1_672_96 \
9 | --model $model_name \
10 | --data ETTh1 \
11 | --seq_len 672 \
12 | --label_len 576 \
13 | --token_len 96 \
14 | --test_seq_len 672 \
15 | --test_label_len 576 \
16 | --test_pred_len 96 \
17 | --batch_size 2048 \
18 | --learning_rate 0.001 \
19 | --itr 1 \
20 | --train_epochs 10 \
21 | --use_amp \
22 | --llm_ckp_dir ./models--facebook--opt-1.3b/snapshots/3f5c25d0bc631cb57ac65913f76e22c2dfb61d62 \
23 | --gpu 0 \
24 | --des 'Opt1b' \
25 | --cosine \
26 | --tmax 10 \
27 | --mlp_hidden_dim 256
28 |
29 | for test_pred_len in 96 192 336 720
30 | do
31 | python -u run.py \
32 | --task_name long_term_forecast \
33 | --is_training 0 \
34 | --root_path ./dataset/ETT-small/ \
35 | --data_path ETTh1.csv \
36 | --model_id ETTh1_672_96 \
37 | --model $model_name \
38 | --data ETTh1 \
39 | --seq_len 672 \
40 | --label_len 576 \
41 | --token_len 96 \
42 | --test_seq_len 672 \
43 | --test_label_len 576 \
44 | --test_pred_len $test_pred_len \
45 | --batch_size 2048 \
46 | --learning_rate 0.001 \
47 | --itr 1 \
48 | --train_epochs 10 \
49 | --use_amp \
50 | --llm_ckp_dir ./models--facebook--opt-1.3b/snapshots/3f5c25d0bc631cb57ac65913f76e22c2dfb61d62 \
51 | --gpu 0 \
52 | --des 'Opt1b' \
53 | --cosine \
54 | --tmax 10 \
55 | --mlp_hidden_dim 256 \
56 | --test_dir long_term_forecast_ETTh1_672_96_AutoTimes_Opt_1b_ETTh1_sl672_ll576_tl96_lr0.001_bt2048_wd0_hd256_hl2_cosTrue_mixFalse_Opt1b_0
57 | done
--------------------------------------------------------------------------------
/scripts/time_series_forecasting/long_term/AutoTimes_ECL.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | # training one model with a context length of 672
4 | torchrun --nnodes 1 --nproc-per-node 8 run.py \
5 | --task_name long_term_forecast \
6 | --is_training 1 \
7 | --root_path ./dataset/electricity/ \
8 | --data_path electricity.csv \
9 | --model_id ECL_672_96 \
10 | --model $model_name \
11 | --data custom \
12 | --seq_len 672 \
13 | --label_len 576 \
14 | --token_len 96 \
15 | --test_seq_len 672 \
16 | --test_label_len 576 \
17 | --test_pred_len 96 \
18 | --batch_size 256 \
19 | --learning_rate 0.001 \
20 | --weight_decay 0.00001 \
21 | --mlp_hidden_dim 1024 \
22 | --train_epochs 10 \
23 | --use_amp \
24 | --mix_embeds \
25 | --use_multi_gpu \
26 | --tmax 10 \
27 | --cosine
28 |
29 | # testing the model on all forecast lengths
30 | for test_pred_len in 96 192 336 720
31 | do
32 | python -u run.py \
33 | --task_name long_term_forecast \
34 | --is_training 0 \
35 | --root_path ./dataset/electricity/ \
36 | --data_path electricity.csv \
37 | --model_id ECL_672_96 \
38 | --model $model_name \
39 | --data custom \
40 | --seq_len 672 \
41 | --label_len 576 \
42 | --token_len 96 \
43 | --test_seq_len 672 \
44 | --test_label_len 576 \
45 | --test_pred_len $test_pred_len \
46 | --batch_size 256 \
47 | --learning_rate 0.001 \
48 | --weight_decay 0.00001 \
49 | --mlp_hidden_dim 1024 \
50 | --train_epochs 10 \
51 | --use_amp \
52 | --mix_embeds \
53 | --tmax 10 \
54 | --cosine \
55 | --test_dir long_term_forecast_ECL_672_96_AutoTimes_Llama_custom_sl672_ll576_tl96_lr0.001_bt256_wd1e-05_hd1024_hl2_cosTrue_mixTrue_test_0
56 | done
--------------------------------------------------------------------------------
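The --test_dir values in these scripts must match the settings string produced during training, which appears to encode the configuration as task_modelid_model_data_sl{seq_len}_ll{label_len}_tl{token_len}_lr{lr}_bt{batch}_wd{weight_decay}_hd{mlp_hidden_dim}_hl{mlp_hidden_layers}_cos{cosine}_mix{mix_embeds}_{des}_{itr}. The exact template is defined in run.py and the exp classes; the snippet below is only a reverse-engineered guess, handy when adapting the evaluation loops to new hyperparameters.

# Hypothetical reconstruction of the checkpoint folder name; verify against run.py.
def setting_name(a):
    return (f"{a['task_name']}_{a['model_id']}_{a['model']}_{a['data']}"
            f"_sl{a['seq_len']}_ll{a['label_len']}_tl{a['token_len']}"
            f"_lr{a['learning_rate']}_bt{a['batch_size']}_wd{a['weight_decay']}"
            f"_hd{a['mlp_hidden_dim']}_hl{a['mlp_hidden_layers']}"
            f"_cos{a['cosine']}_mix{a['mix_embeds']}_{a['des']}_{a['itr']}")

print(setting_name({
    'task_name': 'long_term_forecast', 'model_id': 'ECL_672_96',
    'model': 'AutoTimes_Llama', 'data': 'custom', 'seq_len': 672,
    'label_len': 576, 'token_len': 96, 'learning_rate': 0.001,
    'batch_size': 256, 'weight_decay': '1e-05', 'mlp_hidden_dim': 1024,
    'mlp_hidden_layers': 2, 'cosine': True, 'mix_embeds': True,
    'des': 'test', 'itr': 0,
}))
# long_term_forecast_ECL_672_96_AutoTimes_Llama_custom_sl672_ll576_tl96_lr0.001_bt256_wd1e-05_hd1024_hl2_cosTrue_mixTrue_test_0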
/scripts/time_series_forecasting/long_term/AutoTimes_ETTh1.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | # training one model with a context length of 672
4 | python -u run.py \
5 | --task_name long_term_forecast \
6 | --is_training 1 \
7 | --root_path ./dataset/ETT-small/ \
8 | --data_path ETTh1.csv \
9 | --model_id ETTh1_672_96 \
10 | --model $model_name \
11 | --data ETTh1 \
12 | --seq_len 672 \
13 | --label_len 576 \
14 | --token_len 96 \
15 | --test_seq_len 672 \
16 | --test_label_len 576 \
17 | --test_pred_len 96 \
18 | --batch_size 256 \
19 | --learning_rate 0.0005 \
20 | --mlp_hidden_layers 0 \
21 | --train_epochs 10 \
22 | --use_amp \
23 | --gpu 0 \
24 | --cosine \
25 | --tmax 10 \
26 | --mix_embeds \
27 | --drop_last
28 |
29 | # testing the model on all forecast lengths
30 | for test_pred_len in 96 192 336 720
31 | do
32 | python -u run.py \
33 | --task_name long_term_forecast \
34 | --is_training 0 \
35 | --root_path ./dataset/ETT-small/ \
36 | --data_path ETTh1.csv \
37 | --model_id ETTh1_672_96 \
38 | --model $model_name \
39 | --data ETTh1 \
40 | --seq_len 672 \
41 | --label_len 576 \
42 | --token_len 96 \
43 | --test_seq_len 672 \
44 | --test_label_len 576 \
45 | --test_pred_len $test_pred_len \
46 | --batch_size 256 \
47 | --learning_rate 0.0005 \
48 | --mlp_hidden_layers 0 \
49 | --train_epochs 10 \
50 | --use_amp \
51 | --gpu 0 \
52 | --cosine \
53 | --tmax 10 \
54 | --mix_embeds \
55 | --drop_last \
56 | --test_dir long_term_forecast_ETTh1_672_96_AutoTimes_Llama_ETTh1_sl672_ll576_tl96_lr0.0005_bt256_wd0_hd256_hl0_cosTrue_mixTrue_test_0
57 | done
58 |
--------------------------------------------------------------------------------
/scripts/time_series_forecasting/long_term/AutoTimes_Solar.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | # training one model with a context length of 672
4 | torchrun --nnodes 1 --nproc-per-node 4 run.py \
5 | --task_name long_term_forecast \
6 | --is_training 1 \
7 | --root_path ./dataset/Solar/ \
8 | --data_path solar_AL.txt \
9 | --model_id solar_672_96 \
10 | --model $model_name \
11 | --data Solar \
12 | --seq_len 672 \
13 | --label_len 576 \
14 | --token_len 96 \
15 | --test_seq_len 672 \
16 | --test_label_len 576 \
17 | --test_pred_len 96 \
18 | --batch_size 256 \
19 | --learning_rate 0.000005 \
20 | --train_epochs 2 \
21 | --use_amp \
22 | --mlp_hidden_dim 1024 \
23 | --mlp_activation relu \
24 | --des 'Exp' \
25 | --use_multi_gpu \
26 | --cosine \
27 | --tmax 10
28 |
29 | # testing the model on all forecast lengths
30 | for test_pred_len in 96 192 336 720
31 | do
32 | python -u run.py \
33 | --task_name long_term_forecast \
34 | --is_training 0 \
35 | --root_path ./dataset/Solar/ \
36 | --data_path solar_AL.txt \
37 | --model_id solar_672_96 \
38 | --model $model_name \
39 | --data Solar \
40 | --seq_len 672 \
41 | --label_len 576 \
42 | --token_len 96 \
43 | --test_seq_len 672 \
44 | --test_label_len 576 \
45 | --test_pred_len $test_pred_len \
46 | --batch_size 256 \
47 | --learning_rate 0.000005 \
48 | --train_epochs 2 \
49 | --use_amp \
50 | --mlp_hidden_dim 1024 \
51 | --mlp_activation relu \
52 | --des 'Exp' \
53 | --cosine \
54 | --tmax 10 \
55 | --test_dir long_term_forecast_solar_672_96_AutoTimes_Llama_Solar_sl672_ll576_tl96_lr5e-06_bt256_wd0_hd1024_hl2_cosTrue_mixFalse_Exp_0
56 | done
--------------------------------------------------------------------------------
/scripts/time_series_forecasting/long_term/AutoTimes_Traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | # training one model with a context length of 672
4 | torchrun --nnodes 1 --nproc-per-node 8 run.py \
5 | --task_name long_term_forecast \
6 | --is_training 1 \
7 | --root_path ./dataset/traffic/ \
8 | --data_path traffic.csv \
9 | --model_id traffic_672_96 \
10 | --model $model_name \
11 | --data custom \
12 | --seq_len 672 \
13 | --label_len 576 \
14 | --token_len 96 \
15 | --test_seq_len 672 \
16 | --test_label_len 576 \
17 | --test_pred_len 96 \
18 | --batch_size 256 \
19 | --learning_rate 0.0001 \
20 | --weight_decay 0.00001 \
21 | --mlp_hidden_dim 1024 \
22 | --mlp_activation relu \
23 | --train_epochs 10 \
24 | --use_amp \
25 | --cosine \
26 | --tmax 10 \
27 | --mix_embeds \
28 | --use_multi_gpu
29 |
30 | # testing the model on all forecast lengths
31 | for test_pred_len in 96 192 336 720
32 | do
33 | python -u run.py \
34 | --task_name long_term_forecast \
35 | --is_training 0 \
36 | --root_path ./dataset/traffic/ \
37 | --data_path traffic.csv \
38 | --model_id traffic_672_96 \
39 | --model $model_name \
40 | --data custom \
41 | --seq_len 672 \
42 | --label_len 576 \
43 | --token_len 96 \
44 | --test_seq_len 672 \
45 | --test_label_len 576 \
46 | --test_pred_len $test_pred_len \
47 | --batch_size 256 \
48 | --learning_rate 0.0001 \
49 | --weight_decay 0.00001 \
50 | --mlp_hidden_dim 1024 \
51 | --mlp_activation relu \
52 | --train_epochs 10 \
53 | --use_amp \
54 | --gpu 0 \
55 | --cosine \
56 | --tmax 10 \
57 | --mix_embeds \
58 | --test_dir long_term_forecast_traffic_672_96_AutoTimes_Llama_custom_sl672_ll576_tl96_lr0.0001_bt256_wd1e-05_hd1024_hl2_cosTrue_mixTrue_test_0
59 | done
--------------------------------------------------------------------------------
/scripts/time_series_forecasting/long_term/AutoTimes_Weather.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | # training one model with a context length of 672
4 | torchrun --nnodes 1 --nproc-per-node 4 run.py \
5 | --task_name long_term_forecast \
6 | --is_training 1 \
7 | --root_path ./dataset/weather/ \
8 | --data_path weather.csv \
9 | --model_id weather_672_96 \
10 | --model $model_name \
11 | --data custom \
12 | --seq_len 672 \
13 | --label_len 576 \
14 | --token_len 96 \
15 | --test_seq_len 672 \
16 | --test_label_len 576 \
17 | --test_pred_len 96 \
18 | --batch_size 384 \
19 | --learning_rate 0.0005 \
20 | --train_epochs 10 \
21 | --use_amp \
22 | --lradj type2 \
23 | --des 'Exp' \
24 | --mlp_hidden_dim 512 \
25 | --mlp_activation relu \
26 | --use_multi_gpu \
27 | --mix_embeds
28 |
29 | # testing the model on all forecast lengths
30 | for test_pred_len in 96 192 336 720
31 | do
32 | python -u run.py \
33 | --task_name long_term_forecast \
34 | --is_training 0 \
35 | --root_path ./dataset/weather/ \
36 | --data_path weather.csv \
37 | --model_id weather_672_96 \
38 | --model $model_name \
39 | --data custom \
40 | --seq_len 672 \
41 | --label_len 576 \
42 | --token_len 96 \
43 | --test_seq_len 672 \
44 | --test_label_len 576 \
45 | --test_pred_len $test_pred_len \
46 | --batch_size 384 \
47 | --learning_rate 0.0005 \
48 | --train_epochs 10 \
49 | --use_amp \
50 | --lradj type2 \
51 | --des 'Exp' \
52 | --mlp_hidden_dim 512 \
53 | --mlp_activation relu \
54 | --mix_embeds \
55 | --test_dir long_term_forecast_weather_672_96_AutoTimes_Llama_custom_sl672_ll576_tl96_lr0.0005_bt384_wd0_hd512_hl2_cosFalse_mixTrue_Exp_0
56 | done
--------------------------------------------------------------------------------
/scripts/time_series_forecasting/short_term/AutoTimes_M4.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=AutoTimes_Llama
4 | python -u run.py \
5 | --task_name short_term_forecast \
6 | --is_training 1 \
7 | --root_path ./dataset/m4 \
8 | --seasonal_patterns 'Yearly' \
9 | --model_id m4_Yearly \
10 | --model $model_name \
11 | --data m4 \
12 | --batch_size 16 \
13 | --des 'Exp' \
14 | --itr 1 \
15 | --learning_rate 0.0001 \
16 | --loss 'SMAPE' \
17 | --use_amp \
18 | --mlp_hidden_dim 512 \
19 | --cosine \
20 | --tmax 10 \
21 | --weight_decay 0.00001
22 |
23 | python -u run.py \
24 | --task_name short_term_forecast \
25 | --is_training 1 \
26 | --root_path ./dataset/m4 \
27 | --seasonal_patterns 'Quarterly' \
28 | --model_id m4_Quarterly \
29 | --model $model_name \
30 | --data m4 \
31 | --batch_size 16 \
32 | --des 'Exp' \
33 | --itr 1 \
34 | --learning_rate 0.00005 \
35 | --loss 'SMAPE' \
36 | --use_amp \
37 | --mlp_hidden_dim 512 \
38 | --cosine \
39 | --tmax 10 \
40 | --weight_decay 0.000005
41 |
42 | python -u run.py \
43 | --task_name short_term_forecast \
44 | --is_training 1 \
45 | --root_path ./dataset/m4 \
46 | --seasonal_patterns 'Monthly' \
47 | --model_id m4_Monthly \
48 | --model $model_name \
49 | --data m4 \
50 | --batch_size 16 \
51 | --learning_rate 0.00005 \
52 | --des 'Exp' \
53 | --itr 1 \
54 | --loss 'SMAPE' \
55 | --use_amp \
56 | --mlp_hidden_dim 1024 \
57 | --cosine \
58 | --tmax 10 \
59 | --weight_decay 0.000001
60 |
61 | python -u run.py \
62 | --task_name short_term_forecast \
63 | --is_training 1 \
64 | --root_path ./dataset/m4 \
65 | --seasonal_patterns 'Weekly' \
66 | --model_id m4_Weekly \
67 | --model $model_name \
68 | --data m4 \
69 | --batch_size 16 \
70 | --des 'Exp' \
71 | --itr 1 \
72 | --learning_rate 0.0001 \
73 | --loss 'SMAPE' \
74 | --use_amp \
75 | --mlp_hidden_dim 1024 \
76 | --cosine \
77 | --tmax 10
78 |
79 | python -u run.py \
80 | --task_name short_term_forecast \
81 | --is_training 1 \
82 | --root_path ./dataset/m4 \
83 | --seasonal_patterns 'Daily' \
84 | --model_id m4_Daily \
85 | --model $model_name \
86 | --data m4 \
87 | --batch_size 16 \
88 | --des 'Exp' \
89 | --itr 1 \
90 | --learning_rate 0.0005 \
91 | --loss 'SMAPE' \
92 | --use_amp \
93 | --mlp_hidden_dim 1024 \
94 | --weight_decay 0.000005
95 |
96 | python -u run.py \
97 | --task_name short_term_forecast \
98 | --is_training 1 \
99 | --root_path ./dataset/m4 \
100 | --seasonal_patterns 'Hourly' \
101 | --model_id m4_Hourly \
102 | --model $model_name \
103 | --data m4 \
104 | --batch_size 16 \
105 | --des 'Exp' \
106 | --itr 1 \
107 | --learning_rate 0.0001 \
108 | --loss 'SMAPE' \
109 | --use_amp \
110 | --mlp_hidden_dim 1024 \
111 | --cosine \
112 | --tmax 10
--------------------------------------------------------------------------------
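Unlike the long-term scripts, the M4 runs above pass no --seq_len/--token_len; the per-group lengths are presumably taken from the M4 metadata in data_provider/m4.py (M4Meta). The forecast horizons are the standard M4 ones, which also line up with the test_pred_len values used in the zero-shot scripts below; the mapping here is only a cross-check, not the source of truth.

# Standard M4 forecast horizons per seasonal pattern (cross-check against M4Meta).
M4_HORIZONS = {
    'Yearly': 6, 'Quarterly': 8, 'Monthly': 18,
    'Weekly': 13, 'Daily': 14, 'Hourly': 48,
}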
/scripts/zero_shot_forecasting/sM3_tM4.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | python -u run.py \
4 | --task_name zero_shot_forecast \
5 | --is_training 1 \
6 | --root_path ./dataset/tsf \
7 | --data_path m3_yearly_dataset.tsf \
8 | --test_data_path m4_yearly_dataset.tsf \
9 | --seasonal_patterns 'Yearly' \
10 | --model_id m3_Yearly \
11 | --model $model_name \
12 | --data tsf \
13 | --seq_len 6 \
14 | --label_len 0 \
15 | --token_len 6 \
16 | --test_seq_len 6 \
17 | --test_label_len 0 \
18 | --test_pred_len 6 \
19 | --learning_rate 0.0001 \
20 | --mlp_hidden_dim 256 \
21 | --mlp_hidden_layers 3 \
22 | --batch_size 16 \
23 | --des 'Exp' \
24 | --itr 1 \
25 | --loss 'SMAPE' \
26 | --use_amp \
27 | --cosine \
28 | --tmax 10 \
29 | --val_set_shuffle
30 |
31 | python -u run.py \
32 | --task_name zero_shot_forecast \
33 | --is_training 1 \
34 | --root_path ./dataset/tsf \
35 | --data_path m3_quarterly_dataset.tsf \
36 | --test_data_path m4_quarterly_dataset.tsf \
37 | --seasonal_patterns 'Quarterly' \
38 | --model_id m3_Quarterly \
39 | --model $model_name \
40 | --data tsf \
41 | --seq_len 8 \
42 | --label_len 0 \
43 | --token_len 8 \
44 | --test_seq_len 8 \
45 | --test_label_len 0 \
46 | --test_pred_len 8 \
47 | --learning_rate 0.000005 \
48 | --mlp_hidden_dim 1024 \
49 | --batch_size 16 \
50 | --des 'Exp' \
51 | --itr 1 \
52 | --loss 'SMAPE' \
53 | --use_amp \
54 | --cosine \
55 | --tmax 10 \
56 | --val_set_shuffle
57 |
58 | python -u run.py \
59 | --task_name zero_shot_forecast \
60 | --is_training 1 \
61 | --root_path ./dataset/tsf \
62 | --data_path m3_monthly_dataset.tsf \
63 | --test_data_path m4_monthly_dataset.tsf \
64 | --seasonal_patterns 'Monthly' \
65 | --model_id m3_Monthly \
66 | --model $model_name \
67 | --data tsf \
68 | --seq_len 24 \
69 | --label_len 0 \
70 | --token_len 24 \
71 | --test_seq_len 24 \
72 | --test_label_len 0 \
73 | --test_pred_len 24 \
74 | --learning_rate 0.00001 \
75 | --mlp_hidden_dim 512 \
76 | --batch_size 16 \
77 | --des 'Exp' \
78 | --itr 1 \
79 | --loss 'SMAPE' \
80 | --use_amp \
81 | --cosine \
82 | --tmax 10 \
83 | --val_set_shuffle
84 |
85 | python -u run.py \
86 | --task_name zero_shot_forecast \
87 | --is_training 1 \
88 | --root_path ./dataset/tsf \
89 | --data_path m3_monthly_dataset.tsf \
90 | --test_data_path m4_weekly_dataset.tsf \
91 | --seasonal_patterns 'Monthly' \
92 | --model_id m3_Monthly \
93 | --model $model_name \
94 | --data tsf \
95 | --seq_len 26 \
96 | --label_len 13 \
97 | --token_len 13 \
98 | --test_seq_len 26 \
99 | --test_label_len 13 \
100 | --test_pred_len 13 \
101 | --learning_rate 0.001 \
102 | --mlp_hidden_dim 256 \
103 | --batch_size 16 \
104 | --des 'Exp' \
105 | --itr 1 \
106 | --loss 'SMAPE' \
107 | --use_amp \
108 | --cosine \
109 | --tmax 10 \
110 | --val_set_shuffle
111 |
112 | python -u run.py \
113 | --task_name zero_shot_forecast \
114 | --is_training 1 \
115 | --root_path ./dataset/tsf \
116 | --data_path m3_monthly_dataset.tsf \
117 | --test_data_path m4_daily_dataset.tsf \
118 | --seasonal_patterns 'Monthly' \
119 | --model_id m3_Monthly \
120 | --model $model_name \
121 | --data tsf \
122 | --seq_len 28 \
123 | --label_len 14 \
124 | --token_len 14 \
125 | --test_seq_len 28 \
126 | --test_label_len 14 \
127 | --test_pred_len 14 \
128 | --learning_rate 0.0001 \
129 | --mlp_hidden_dim 256 \
130 | --batch_size 16 \
131 | --des 'Exp' \
132 | --itr 1 \
133 | --loss 'SMAPE' \
134 | --use_amp \
135 | --cosine \
136 | --tmax 10 \
137 | --val_set_shuffle
138 |
139 | python -u run.py \
140 | --task_name zero_shot_forecast \
141 | --is_training 1 \
142 | --root_path ./dataset/tsf \
143 | --data_path m3_monthly_dataset.tsf \
144 | --test_data_path m4_hourly_dataset.tsf \
145 | --seasonal_patterns 'Monthly' \
146 | --model_id m3_Monthly \
147 | --model $model_name \
148 | --data tsf \
149 | --seq_len 48 \
150 | --label_len 24 \
151 | --token_len 24 \
152 | --test_seq_len 48 \
153 | --test_label_len 24 \
154 | --test_pred_len 48 \
155 | --learning_rate 0.001 \
156 | --mlp_hidden_dim 128 \
157 | --mlp_hidden_layers 3 \
158 | --batch_size 16 \
159 | --des 'Exp' \
160 | --itr 1 \
161 | --loss 'SMAPE' \
162 | --use_amp \
163 | --cosine \
164 | --tmax 10 \
165 | --val_set_shuffle
--------------------------------------------------------------------------------
/scripts/zero_shot_forecasting/sM4_tM3.sh:
--------------------------------------------------------------------------------
1 | model_name=AutoTimes_Llama
2 |
3 | python -u run.py \
4 | --task_name zero_shot_forecast \
5 | --is_training 0 \
6 | --root_path ./dataset/tsf \
7 | --test_data_path m3_yearly_dataset.tsf \
8 | --seasonal_patterns 'Yearly' \
9 | --model_id m4_Yearly \
10 | --model $model_name \
11 | --data tsf \
12 | --seq_len 12 \
13 | --label_len 6 \
14 | --token_len 6 \
15 | --test_seq_len 12 \
16 | --test_label_len 6 \
17 | --test_pred_len 6 \
18 | --batch_size 16 \
19 | --des 'Exp' \
20 | --itr 1 \
21 | --loss 'SMAPE' \
22 | --use_amp \
23 | --mlp_hidden_dim 512 \
24 | --test_dir short_term_forecast_m4_Yearly_AutoTimes_Llama_m4_sl12_ll6_tl6_lr0.0001_bt16_wd1e-05_hd512_hl2_cosTrue_mixFalse_Exp_0
25 |
26 | python -u run.py \
27 | --task_name zero_shot_forecast \
28 | --is_training 0 \
29 | --root_path ./dataset/tsf \
30 | --test_data_path m3_quarterly_dataset.tsf \
31 | --seasonal_patterns 'Quarterly' \
32 | --model_id m4_Quarterly \
33 | --model $model_name \
34 | --data tsf \
35 | --seq_len 16 \
36 | --label_len 8 \
37 | --token_len 8 \
38 | --test_seq_len 16 \
39 | --test_label_len 8 \
40 | --test_pred_len 8 \
41 | --batch_size 16 \
42 | --des 'Exp' \
43 | --itr 1 \
44 | --loss 'SMAPE' \
45 | --use_amp \
46 | --mlp_hidden_dim 512 \
47 | --test_dir short_term_forecast_m4_Quarterly_AutoTimes_Llama_m4_sl16_ll8_tl8_lr5e-05_bt16_wd5e-06_hd512_hl2_cosTrue_mixFalse_Exp_0
48 |
49 | python -u run.py \
50 | --task_name zero_shot_forecast \
51 | --is_training 0 \
52 | --root_path ./dataset/tsf \
53 | --test_data_path m3_monthly_dataset.tsf \
54 | --seasonal_patterns 'Monthly' \
55 | --model_id m4_Monthly \
56 | --model $model_name \
57 | --data tsf \
58 | --seq_len 36 \
59 | --label_len 18 \
60 | --token_len 18 \
61 | --test_seq_len 36 \
62 | --test_label_len 18 \
63 | --test_pred_len 18 \
64 | --batch_size 16 \
65 | --des 'Exp' \
66 | --itr 1 \
67 | --loss 'SMAPE' \
68 | --use_amp \
69 | --mlp_hidden_dim 1024 \
70 | --test_dir short_term_forecast_m4_Monthly_AutoTimes_Llama_m4_sl36_ll18_tl18_lr5e-05_bt16_wd1e-06_hd1024_hl2_cosTrue_mixFalse_Exp_0
71 |
72 | python -u run.py \
73 | --task_name zero_shot_forecast \
74 | --is_training 0 \
75 | --root_path ./dataset/tsf \
76 | --test_data_path m3_other_dataset.tsf \
77 | --seasonal_patterns 'Quarterly' \
78 | --model_id m4_Quarterly \
79 | --model $model_name \
80 | --data tsf \
81 | --seq_len 16 \
82 | --label_len 8 \
83 | --token_len 8 \
84 | --test_seq_len 16 \
85 | --test_label_len 8 \
86 | --test_pred_len 8 \
87 | --batch_size 16 \
88 | --des 'Exp' \
89 | --itr 1 \
90 | --loss 'SMAPE' \
91 | --use_amp \
92 | --mlp_hidden_dim 512 \
93 | --test_dir short_term_forecast_m4_Quarterly_AutoTimes_Llama_m4_sl16_ll8_tl8_lr5e-05_bt16_wd5e-06_hd512_hl2_cosTrue_mixFalse_Exp_0
--------------------------------------------------------------------------------
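These sM4_tM3 runs use --is_training 0 throughout: they only load checkpoints produced by the short-term M4 script (note the short_term_forecast_ prefixes in --test_dir) and evaluate them zero-shot on the M3 .tsf files. A small pre-flight check along these lines can save a failed run; the ./checkpoints root is an assumption here and should be verified against run.py.

import os

# Hypothetical sanity check: confirm the M4-trained checkpoint exists before evaluating on M3.
ckpt = os.path.join(
    "./checkpoints",
    "short_term_forecast_m4_Yearly_AutoTimes_Llama_m4_sl12_ll6_tl6_lr0.0001_bt16_"
    "wd1e-05_hd512_hl2_cosTrue_mixFalse_Exp_0",
    "checkpoint.pth",
)
if not os.path.isfile(ckpt):
    print("Train the M4 short-term model first: scripts/time_series_forecasting/short_term/AutoTimes_M4.sh")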
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/utils/__init__.py
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | # This source code is provided for the purposes of scientific reproducibility
2 | # under the following limited license from Element AI Inc. The code is an
3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
4 | # expansion analysis for interpretable time series forecasting,
5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is
6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0
7 | # International license (CC BY-NC 4.0):
8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
9 | # for the benefit of third parties or internally in production) requires an
10 | # explicit license. The subject-matter of the N-BEATS model and associated
11 | # materials are the property of Element AI Inc. and may be subject to patent
12 | # protection. No license to patents is granted hereunder (whether express or
13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved.
14 |
15 | """
16 | Loss functions for PyTorch.
17 | """
18 |
19 | import torch as t
20 | import torch.nn as nn
21 | import numpy as np
22 | import pdb
23 |
24 |
25 | def divide_no_nan(a, b):
26 | """
27 | a/b where any resulting NaN or Inf values are replaced by 0.
28 | """
29 | result = a / b
30 | result[result != result] = .0
31 | result[result == np.inf] = .0
32 | return result
33 |
34 |
35 | class mape_loss(nn.Module):
36 | def __init__(self):
37 | super(mape_loss, self).__init__()
38 |
39 | def forward(self, insample: t.Tensor, freq: int,
40 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
41 | """
42 | MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error
43 |
44 | :param forecast: Forecast values. Shape: batch, time
45 | :param target: Target values. Shape: batch, time
46 | :param mask: 0/1 mask. Shape: batch, time
47 | :return: Loss value
48 | """
49 | weights = divide_no_nan(mask, target)
50 | return t.mean(t.abs((forecast - target) * weights))
51 |
52 |
53 | class smape_loss(nn.Module):
54 | def __init__(self):
55 | super(smape_loss, self).__init__()
56 |
57 | def forward(self, insample: t.Tensor, freq: int,
58 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
59 | """
60 | sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993)
61 |
62 | :param forecast: Forecast values. Shape: batch, time
63 | :param target: Target values. Shape: batch, time
64 | :param mask: 0/1 mask. Shape: batch, time
65 | :return: Loss value
66 | """
67 | return 200 * t.mean(divide_no_nan(t.abs(forecast - target),
68 | t.abs(forecast.data) + t.abs(target.data)) * mask)
69 |
70 |
71 | class mase_loss(nn.Module):
72 | def __init__(self):
73 | super(mase_loss, self).__init__()
74 |
75 | def forward(self, insample: t.Tensor, freq: int,
76 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
77 | """
78 | MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf
79 |
80 | :param insample: Insample values. Shape: batch, time_i
81 | :param freq: Frequency value
82 | :param forecast: Forecast values. Shape: batch, time_o
83 | :param target: Target values. Shape: batch, time_o
84 | :param mask: 0/1 mask. Shape: batch, time_o
85 | :return: Loss value
86 | """
87 | masep = t.mean(t.abs(insample[:, freq:] - insample[:, :-freq]), dim=1)
88 | masked_masep_inv = divide_no_nan(mask, masep[:, None])
89 | return t.mean(t.abs(target - forecast) * masked_masep_inv)
90 |
91 | class zero_shot_smape_loss(nn.Module):
92 | def __init__(self):
93 | super(zero_shot_smape_loss, self).__init__()
94 | def forward(self, pred, true):
95 | return t.mean(200 * t.abs(pred - true) / (t.abs(pred) + t.abs(true) + 1e-8))
--------------------------------------------------------------------------------
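A quick numerical sanity check of the two sMAPE variants above: the masked training loss and the simpler zero-shot version. The toy tensors are made up; run this from the repository root so utils.losses is importable. On this example both give roughly the same value, since they differ only in masking, in detaching the denominator (.data), and in how a zero denominator is handled (divide_no_nan vs. the 1e-8 term).

import torch as t
from utils.losses import smape_loss, zero_shot_smape_loss

forecast = t.tensor([[110.0, 95.0, 0.0]])
target   = t.tensor([[100.0, 100.0, 0.0]])
mask     = t.ones_like(target)

# insample/freq are unused by the sMAPE formula, so placeholders are fine here
print(smape_loss()(None, 1, forecast, target, mask).item())   # ~4.88 on the 0-200 scale
print(zero_shot_smape_loss()(forecast, target).item())        # ~4.88 as well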
/utils/m4_summary.py:
--------------------------------------------------------------------------------
1 | # This source code is provided for the purposes of scientific reproducibility
2 | # under the following limited license from Element AI Inc. The code is an
3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
4 | # expansion analysis for interpretable time series forecasting,
5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is
6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0
7 | # International license (CC BY-NC 4.0):
8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
9 | # for the benefit of third parties or internally in production) requires an
10 | # explicit license. The subject-matter of the N-BEATS model and associated
11 | # materials are the property of Element AI Inc. and may be subject to patent
12 | # protection. No license to patents is granted hereunder (whether express or
13 | # implied). Copyright 2020 Element AI Inc. All rights reserved.
14 |
15 | """
16 | M4 Summary
17 | """
18 | from collections import OrderedDict
19 |
20 | import numpy as np
21 | import pandas as pd
22 |
23 | from data_provider.m4 import M4Dataset
24 | from data_provider.m4 import M4Meta
25 | import os
26 |
27 |
28 | def group_values(values, groups, group_name):
29 | return np.array([v[~np.isnan(v)] for v in values[groups == group_name]], dtype=object)
30 |
31 |
32 | def mase(forecast, insample, outsample, frequency):
33 | return np.mean(np.abs(forecast - outsample)) / np.mean(np.abs(insample[:-frequency] - insample[frequency:]))
34 |
35 |
36 | def smape_2(forecast, target):
37 | denom = np.abs(target) + np.abs(forecast)
38 | # divide by 1.0 instead of 0.0; when the denominator is zero the numerator will be 0.0 anyway.
39 | denom[denom == 0.0] = 1.0
40 | return 200 * np.abs(forecast - target) / denom
41 |
42 |
43 | def mape(forecast, target):
44 | denom = np.abs(target)
45 | # divide by 1.0 instead of 0.0; when the denominator is zero the numerator will be 0.0 anyway.
46 | denom[denom == 0.0] = 1.0
47 | return 100 * np.abs(forecast - target) / denom
48 |
49 |
50 | class M4Summary:
51 | def __init__(self, file_path, root_path):
52 | self.file_path = file_path
53 | self.training_set = M4Dataset.load(training=True, dataset_file=root_path)
54 | self.test_set = M4Dataset.load(training=False, dataset_file=root_path)
55 | self.naive_path = os.path.join(root_path, 'submission-Naive2.csv')
56 |
57 | def evaluate(self):
58 | """
59 | Evaluate forecasts using M4 test dataset.
60 |
61 | Model forecasts are read from the per-group CSV files under file_path.
62 | :return: sMAPE, OWA, MAPE and MASE grouped by seasonal patterns.
63 | """
64 | grouped_owa = OrderedDict()
65 |
66 | naive2_forecasts = pd.read_csv(self.naive_path).values[:, 1:].astype(np.float32)
67 | naive2_forecasts = np.array([v[~np.isnan(v)] for v in naive2_forecasts], dtype=object)
68 |
69 | model_mases = {}
70 | naive2_smapes = {}
71 | naive2_mases = {}
72 | grouped_smapes = {}
73 | grouped_mapes = {}
74 | for group_name in M4Meta.seasonal_patterns:
75 | file_name = self.file_path + group_name + "_forecast.csv"
76 | if os.path.exists(file_name):
77 | model_forecast = pd.read_csv(file_name).values
78 |
79 | naive2_forecast = group_values(naive2_forecasts, self.test_set.groups, group_name)
80 | target = group_values(self.test_set.values, self.test_set.groups, group_name)
81 | # all time series within a group share the same frequency
82 | frequency = self.training_set.frequencies[self.test_set.groups == group_name][0]
83 | insample = group_values(self.training_set.values, self.test_set.groups, group_name)
84 |
85 | model_mases[group_name] = np.mean([mase(forecast=model_forecast[i],
86 | insample=insample[i],
87 | outsample=target[i],
88 | frequency=frequency) for i in range(len(model_forecast))])
89 | naive2_mases[group_name] = np.mean([mase(forecast=naive2_forecast[i],
90 | insample=insample[i],
91 | outsample=target[i],
92 | frequency=frequency) for i in range(len(model_forecast))])
93 |
94 | naive2_smapes[group_name] = np.mean(smape_2(naive2_forecast, target))
95 | grouped_smapes[group_name] = np.mean(smape_2(forecast=model_forecast, target=target))
96 | grouped_mapes[group_name] = np.mean(mape(forecast=model_forecast, target=target))
97 |
98 | grouped_smapes = self.summarize_groups(grouped_smapes)
99 | grouped_mapes = self.summarize_groups(grouped_mapes)
100 | grouped_model_mases = self.summarize_groups(model_mases)
101 | grouped_naive2_smapes = self.summarize_groups(naive2_smapes)
102 | grouped_naive2_mases = self.summarize_groups(naive2_mases)
103 | for k in grouped_model_mases.keys():
104 | grouped_owa[k] = (grouped_model_mases[k] / grouped_naive2_mases[k] +
105 | grouped_smapes[k] / grouped_naive2_smapes[k]) / 2
106 |
107 | def round_all(d):
108 | return dict(map(lambda kv: (kv[0], np.round(kv[1], 3)), d.items()))
109 |
110 | return round_all(grouped_smapes), round_all(grouped_owa), round_all(grouped_mapes), round_all(
111 | grouped_model_mases)
112 |
113 | def summarize_groups(self, scores):
114 | """
115 | Re-group scores respecting M4 rules.
116 | :param scores: Scores per group.
117 | :return: Grouped scores.
118 | """
119 | scores_summary = OrderedDict()
120 |
121 | def group_count(group_name):
122 | return len(np.where(self.test_set.groups == group_name)[0])
123 |
124 | weighted_score = {}
125 | for g in ['Yearly', 'Quarterly', 'Monthly']:
126 | weighted_score[g] = scores[g] * group_count(g)
127 | scores_summary[g] = scores[g]
128 |
129 | others_score = 0
130 | others_count = 0
131 | for g in ['Weekly', 'Daily', 'Hourly']:
132 | others_score += scores[g] * group_count(g)
133 | others_count += group_count(g)
134 | weighted_score['Others'] = others_score
135 | scores_summary['Others'] = others_score / others_count
136 |
137 | average = np.sum(list(weighted_score.values())) / len(self.test_set.groups)
138 | scores_summary['Average'] = average
139 |
140 | return scores_summary
141 |
--------------------------------------------------------------------------------
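For reference, the OWA assembled in the grouped_owa loop above is the model's sMAPE and MASE, each normalised by the Naive2 baseline and then averaged. A toy recomputation (the numbers are made up, not results from this repository):

def owa(smape_model, mase_model, smape_naive2, mase_naive2):
    # Overall Weighted Average, matching the grouped_owa computation in M4Summary.evaluate
    return 0.5 * (smape_model / smape_naive2 + mase_model / mase_naive2)

print(owa(11.9, 1.6, 14.0, 2.0))   # 0.5 * (0.85 + 0.80) = 0.825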
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def RSE(pred, true):
5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
6 |
7 |
8 | def CORR(pred, true):
9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
11 | return (u / d).mean(-1)
12 |
13 |
14 | def MAE(pred, true):
15 | return np.mean(np.abs(pred - true))
16 |
17 |
18 | def MSE(pred, true):
19 | return np.mean((pred - true) ** 2)
20 |
21 |
22 | def RMSE(pred, true):
23 | return np.sqrt(MSE(pred, true))
24 |
25 |
26 | def MAPE(pred, true):
27 | return np.mean(np.abs((pred - true) / true))
28 |
29 |
30 | def MSPE(pred, true):
31 | return np.mean(np.square((pred - true) / true))
32 |
33 |
34 | def metric(pred, true):
35 | mae = MAE(pred, true)
36 | mse = MSE(pred, true)
37 | rmse = RMSE(pred, true)
38 | mape = MAPE(pred, true)
39 | mspe = MSPE(pred, true)
40 |
41 | return mae, mse, rmse, mape, mspe
42 |
--------------------------------------------------------------------------------
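A minimal usage example for metric() above, with made-up arrays; run from the repository root so utils.metrics is importable.

import numpy as np
from utils.metrics import metric

pred = np.array([[0.9, 2.1, 3.2]])
true = np.array([[1.0, 2.0, 3.0]])

mae, mse, rmse, mape, mspe = metric(pred, true)
# MAE = mean(0.1, 0.1, 0.2) = 0.133..., MSE = mean(0.01, 0.01, 0.04) = 0.02
print(mae, mse, rmse, mape, mspe)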
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | import pandas as pd
5 | import math
6 | import torch.distributed as dist
7 | from distutils.util import strtobool
8 | from datetime import datetime
9 | plt.switch_backend('agg')
10 |
11 |
12 | def adjust_learning_rate(optimizer, epoch, args):
13 | if args.lradj == 'type1':
14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** (epoch - 1))}
15 | elif args.lradj == 'type2':
16 | lr_adjust = {epoch: args.learning_rate * (0.6 ** epoch)}
17 | elif args.lradj == "cosine":
18 | lr_adjust = {epoch: args.learning_rate /2 * (1 + math.cos(epoch / args.train_epochs * math.pi))}
19 |
20 | if epoch in lr_adjust.keys():
21 | lr = lr_adjust[epoch]
22 | for param_group in optimizer.param_groups:
23 | param_group['lr'] = lr
24 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu:
25 | print('next learning rate is {}'.format(lr))
26 |
27 | class EarlyStopping:
28 | def __init__(self, args, verbose=False, delta=0):
29 | self.patience = args.patience
30 | self.verbose = verbose
31 | self.counter = 0
32 | self.best_score = None
33 | self.early_stop = False
34 | self.val_loss_min = np.inf
35 | self.delta = delta
36 | self.use_multi_gpu = args.use_multi_gpu
37 | if self.use_multi_gpu:
38 | self.local_rank = args.local_rank
39 | else:
40 | self.local_rank = None
41 |
42 | def __call__(self, val_loss, model, path):
43 | score = -val_loss
44 | if self.best_score is None:
45 | self.best_score = score
46 | if self.verbose:
47 | if (self.use_multi_gpu and self.local_rank == 0) or not self.use_multi_gpu:
48 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
49 | self.val_loss_min = val_loss
50 | if self.use_multi_gpu:
51 | if self.local_rank == 0:
52 | self.save_checkpoint(val_loss, model, path)
53 | dist.barrier()
54 | else:
55 | self.save_checkpoint(val_loss, model, path)
56 | elif score < self.best_score + self.delta:
57 | self.counter += 1
58 | if (self.use_multi_gpu and self.local_rank == 0) or not self.use_multi_gpu:
59 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
60 | if self.counter >= self.patience:
61 | self.early_stop = True
62 | else:
63 | self.best_score = score
64 | if self.use_multi_gpu:
65 | if self.local_rank == 0:
66 | self.save_checkpoint(val_loss, model, path)
67 | dist.barrier()
68 | else:
69 | self.save_checkpoint(val_loss, model, path)
70 | if self.verbose:
71 | if (self.use_multi_gpu and self.local_rank == 0) or not self.use_multi_gpu:
72 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
73 | self.val_loss_min = val_loss
74 | self.counter = 0
75 |
76 |
77 | def save_checkpoint(self, val_loss, model, path):
78 | param_grad_dic = {
79 | k: v.requires_grad for (k, v) in model.named_parameters()
80 | }
81 | state_dict = model.state_dict()
82 | for k in list(state_dict.keys()):
83 | if k in param_grad_dic.keys() and not param_grad_dic[k]:
84 | # delete parameters that do not require gradient
85 | del state_dict[k]
86 | torch.save(state_dict, path + '/checkpoint.pth')
87 |
88 |
89 |
90 | class dotdict(dict):
91 | """dot.notation access to dictionary attributes"""
92 | __getattr__ = dict.get
93 | __setattr__ = dict.__setitem__
94 | __delattr__ = dict.__delitem__
95 |
96 |
97 | class StandardScaler():
98 | def __init__(self, mean, std):
99 | self.mean = mean
100 | self.std = std
101 |
102 | def transform(self, data):
103 | return (data - self.mean) / self.std
104 |
105 | def inverse_transform(self, data):
106 | return (data * self.std) + self.mean
107 |
108 |
109 | def visual(true, preds=None, name='./pic/test.pdf'):
110 | """
111 | Results visualization
112 | """
113 | plt.figure()
114 | plt.plot(true, label='GroundTruth', linewidth=2)
115 | if preds is not None:
116 | plt.plot(preds, label='Prediction', linewidth=2)
117 | plt.legend()
118 | plt.savefig(name, bbox_inches='tight')
119 |
120 |
121 | def adjustment(gt, pred):
122 | anomaly_state = False
123 | for i in range(len(gt)):
124 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
125 | anomaly_state = True
126 | for j in range(i, 0, -1):
127 | if gt[j] == 0:
128 | break
129 | else:
130 | if pred[j] == 0:
131 | pred[j] = 1
132 | for j in range(i, len(gt)):
133 | if gt[j] == 0:
134 | break
135 | else:
136 | if pred[j] == 0:
137 | pred[j] = 1
138 | elif gt[i] == 0:
139 | anomaly_state = False
140 | if anomaly_state:
141 | pred[i] = 1
142 | return gt, pred
143 |
144 |
145 | def cal_accuracy(y_pred, y_true):
146 | return np.mean(y_pred == y_true)
147 |
148 | def convert_tsf_to_dataframe(
149 | full_file_path_and_name,
150 | replace_missing_vals_with="NaN",
151 | value_column_name="series_value",
152 | ):
153 | col_names = []
154 | col_types = []
155 | all_data = {}
156 | line_count = 0
157 | frequency = None
158 | forecast_horizon = None
159 | contain_missing_values = None
160 | contain_equal_length = None
161 | found_data_tag = False
162 | found_data_section = False
163 | started_reading_data_section = False
164 |
165 | with open(full_file_path_and_name, "r", encoding="cp1252") as file:
166 | for line in file:
167 | # Strip white space from start/end of line
168 | line = line.strip()
169 |
170 | if line:
171 | if line.startswith("@"): # Read meta-data
172 | if not line.startswith("@data"):
173 | line_content = line.split(" ")
174 | if line.startswith("@attribute"):
175 | if (
176 | len(line_content) != 3
177 | ): # Attributes have both name and type
178 | raise Exception("Invalid meta-data specification.")
179 |
180 | col_names.append(line_content[1])
181 | col_types.append(line_content[2])
182 | else:
183 | if (
184 | len(line_content) != 2
185 | ): # Other meta-data have only values
186 | raise Exception("Invalid meta-data specification.")
187 |
188 | if line.startswith("@frequency"):
189 | frequency = line_content[1]
190 | elif line.startswith("@horizon"):
191 | forecast_horizon = int(line_content[1])
192 | elif line.startswith("@missing"):
193 | contain_missing_values = bool(
194 | strtobool(line_content[1])
195 | )
196 | elif line.startswith("@equallength"):
197 | contain_equal_length = bool(strtobool(line_content[1]))
198 |
199 | else:
200 | if len(col_names) == 0:
201 | raise Exception(
202 | "Missing attribute section. Attribute section must come before data."
203 | )
204 |
205 | found_data_tag = True
206 | elif not line.startswith("#"):
207 | if len(col_names) == 0:
208 | raise Exception(
209 | "Missing attribute section. Attribute section must come before data."
210 | )
211 | elif not found_data_tag:
212 | raise Exception("Missing @data tag.")
213 | else:
214 | if not started_reading_data_section:
215 | started_reading_data_section = True
216 | found_data_section = True
217 | all_series = []
218 |
219 | for col in col_names:
220 | all_data[col] = []
221 |
222 | full_info = line.split(":")
223 |
224 | if len(full_info) != (len(col_names) + 1):
225 | raise Exception("Missing attributes/values in series.")
226 |
227 | series = full_info[len(full_info) - 1]
228 | series = series.split(",")
229 |
230 | if len(series) == 0:
231 | raise Exception(
232 | "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol"
233 | )
234 |
235 | numeric_series = []
236 |
237 | for val in series:
238 | if val == "?":
239 | numeric_series.append(replace_missing_vals_with)
240 | else:
241 | numeric_series.append(float(val))
242 |
243 | if numeric_series.count(replace_missing_vals_with) == len(
244 | numeric_series
245 | ):
246 | raise Exception(
247 | "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series."
248 | )
249 |
250 | all_series.append(pd.Series(numeric_series).array)
251 |
252 | for i in range(len(col_names)):
253 | att_val = None
254 | if col_types[i] == "numeric":
255 | att_val = int(full_info[i])
256 | elif col_types[i] == "string":
257 | att_val = str(full_info[i])
258 | elif col_types[i] == "date":
259 | att_val = datetime.strptime(
260 | full_info[i], "%Y-%m-%d %H-%M-%S"
261 | )
262 | else:
263 | raise Exception(
264 | "Invalid attribute type."
265 | ) # Currently, the code supports only numeric, string and date types. Extend this as required.
266 |
267 | if att_val is None:
268 | raise Exception("Invalid attribute value.")
269 | else:
270 | all_data[col_names[i]].append(att_val)
271 |
272 | line_count = line_count + 1
273 |
274 | if line_count == 0:
275 | raise Exception("Empty file.")
276 | if len(col_names) == 0:
277 | raise Exception("Missing attribute section.")
278 | if not found_data_section:
279 | raise Exception("Missing series information under data section.")
280 |
281 | all_data[value_column_name] = all_series
282 | loaded_data = pd.DataFrame(all_data)
283 |
284 | return (
285 | loaded_data,
286 | frequency,
287 | forecast_horizon,
288 | contain_missing_values,
289 | contain_equal_length,
290 | )
291 |
292 | class dotdict(dict):
293 | """dot.notation access to dictionary attributes"""
294 | __getattr__ = dict.get
295 | __setattr__ = dict.__setitem__
296 | __delattr__ = dict.__delitem__
--------------------------------------------------------------------------------
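The .tsf files referenced by the zero-shot and in-context scripts (m3_yearly_dataset.tsf and friends) are parsed with convert_tsf_to_dataframe defined above. A minimal usage sketch; the path is a placeholder for wherever the Monash-format file actually lives.

from utils.tools import convert_tsf_to_dataframe

df, frequency, horizon, has_missing, equal_length = convert_tsf_to_dataframe(
    "./dataset/tsf/m3_yearly_dataset.tsf"   # placeholder path
)
print(frequency, horizon, has_missing, equal_length)
print(len(df), "series; values are stored per row in df['series_value']")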