├── .gitignore ├── LICENSE ├── README.md ├── data_provider ├── __init__.py ├── data_factory.py ├── data_loader.py └── m4.py ├── exp ├── exp_basic.py ├── exp_in_context_forecasting.py ├── exp_long_term_forecasting.py ├── exp_short_term_forecasting.py └── exp_zero_shot_forecasting.py ├── figures ├── ablation.png ├── ablation_llm.png ├── adaption_efficiency.png ├── comparison.png ├── formulation.png ├── icon.png ├── illustration.png ├── in-context.png ├── lora.png ├── method.png ├── motivation.png ├── one-for-all_results.png ├── param.png ├── showcases.png ├── subway_icf.png └── zeroshot_results.png ├── layers ├── __init__.py └── mlp.py ├── models ├── AutoTimes_Gpt2.py ├── AutoTimes_Llama.py ├── AutoTimes_Opt_1b.py ├── Preprocess_Llama.py └── __init__.py ├── predict.ipynb ├── preprocess.py ├── requirements.txt ├── run.py ├── scripts ├── in_context_forecasting │ └── M3.sh ├── method_generality │ ├── gpt2.sh │ └── opt.sh ├── time_series_forecasting │ ├── long_term │ │ ├── AutoTimes_ECL.sh │ │ ├── AutoTimes_ETTh1.sh │ │ ├── AutoTimes_Solar.sh │ │ ├── AutoTimes_Traffic.sh │ │ └── AutoTimes_Weather.sh │ └── short_term │ │ └── AutoTimes_M4.sh └── zero_shot_forecasting │ ├── sM3_tM4.sh │ └── sM4_tM3.sh └── utils ├── __init__.py ├── losses.py ├── m4_summary.py ├── metrics.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | dataset/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | */.DS_Store 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | /dataset/ 137 | /llama/ 138 | /checkpoints/ 139 | /test_results/ 140 | /m4_results/ 141 | /models--gpt2/ 142 | /models--facebook--opt-1.3b/ 143 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoTimes (Large Language Models for Time Series Forecasting) 2 | 3 | Official implementation: [AutoTimes: Autoregressive Time Series Forecasters via Large Language Models](https://arxiv.org/abs/2402.02370). [[Slides]](https://cloud.tsinghua.edu.cn/f/7689d30f92594ded84f0/), [[Poster]](https://cloud.tsinghua.edu.cn/f/f2c18ae34fef4e74ad46/) 4 | 5 | 6 | > **[Time Series Forecasting](./scripts/time_series_forecasting/)**: AutoTimes converts LLMs into autoregressive time series forecasters. Unlike previous methods, the forecaster accommodates lookback and prediction windows of arbitrary length. 7 | 8 | > **[Zero-Shot Forecasting](./scripts/zero_shot_forecasting/)**: AutoTimes takes advantage of the LLM's general-purpose token transition as the future extrapolation of time series, demonstrating good performance without downstream samples. 9 | 10 | > **[In-Context Forecasting](./scripts/in_context_forecasting/)**: We propose in-context forecasting for the first time, where time series prompts can be incorporated into the input context to enhance forecasting. 11 | 12 | > **[Easy-to-Use](scripts/method_generality)**: AutoTimes is compatible with any decoder-only large language model, demonstrating generality and proper scaling behavior.
13 | 14 | ## Updates 15 | 16 | :triangular_flag_on_post: News (2024.10): An introduction of our work is available [[Slides]](https://cloud.tsinghua.edu.cn/f/7689d30f92594ded84f0/). See you at **NeurIPS 2024**! 17 | 18 | :triangular_flag_on_post: News (2024.10): AutoTimes has been accepted by **NeurIPS 2024**. [A revised version](https://arxiv.org/pdf/2402.02370) (**25 pages**) is now available, including prompt engineering for in-context forecasting, adaptation cost evaluations, textual embeddings of metadata, and the low-rank adaptation technique. 19 | 20 | :triangular_flag_on_post: News (2024.08): [Recent work](https://arxiv.org/abs/2406.16964) [(code)](https://github.com/bennytmt/ts_models) has also raised questions about previous non-autoregressive LLM4TS methods. We conduct ablations [here](./figures/ablation_llm.png), highlighting that AutoTimes can truly utilize LLMs: instead of adopting LLMs in a BERT style, **the general-purpose token transition is transferable between time series and natural language**.

23 | 24 |

25 | 26 | :triangular_flag_on_post: **News** (2024.2): Scripts for all of the above tasks in our [paper](https://arxiv.org/pdf/2402.02370.pdf) are available. 27 | 28 | ## Introduction 29 | 30 | 🌟 While prevalent LLM4TS methods adapt LLMs as encoder-only, non-autoregressive forecasters, we propose to **stay consistent with the inherent autoregressive property and model architecture of LLMs**. 31 | 32 |

33 | 34 |

35 | 36 | 💪 We aim to **fully revitalize LLMs as foundation models for time series forecasting**, including multi-step forecasting, zero-shot capability, **in-context forecasting**, and multimodal utilization. 37 | 38 | 🏆 AutoTimes achieves **state-of-the-art performance** with **0.1% trainable parameters and over 5× training/inference speedup** compared to advanced LLM-based forecasters. 39 | 40 | ## Usage 41 | 42 | 1. Install PyTorch and the necessary dependencies. 43 | 44 | ``` 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | 2. Put the datasets [[Google Drive]](https://drive.google.com/file/d/1t7jOkctNJ0rt3VMwZaqmxSuA75TFEo96/view?usp=sharing) 49 | [[Tsinghua Cloud]](https://cloud.tsinghua.edu.cn/f/0a758154e0d44de890e3/) under the folder ```./dataset/```. 50 | 51 | 3. Download the large language models from [Hugging Face](https://huggingface.co/). The default LLM is LLaMA-7B; you can change `llm_ckp_dir` in `run.py` to use other LLMs. 52 | * [LLaMA-7B](https://huggingface.co/meta-llama/Llama-2-7b) 53 | * [OPT Family](https://huggingface.co/facebook/opt-125m) 54 | * [GPT2](https://huggingface.co/openai-community/gpt2) 55 | 56 | For example, after downloading LLaMA and placing it correctly, the directory structure looks as follows: 57 | - data_provider 58 | - dataset 59 | - llama 60 | - config.json 61 | - pytorch_model-00001-of-00002.bin 62 | - pytorch_model-00002-of-00002.bin 63 | - ... 64 | - ... 65 | - run.py 66 | 67 | 4. Generate the position embeddings from textual timestamps. Note that we have provided the embeddings of the given datasets in the download links; they are generated by LLaMA and suffixed by `{dataset_name}.pt`. If you want to generate embeddings for your own datasets, please refer to the following command: 68 | ``` 69 | # preprocess timestamps to generate text embeddings 70 | python ./preprocess.py --gpu 0 --dataset ETTh1 71 | ``` 72 | 73 | 5. Train and evaluate the model. We provide scripts for all of the above tasks under the folder ```./scripts/```. 74 | 75 | ``` 76 | # the default large language model is LLaMA-7B 77 | 78 | # long-term forecasting 79 | bash ./scripts/time_series_forecasting/long_term/AutoTimes_ETTh1.sh 80 | 81 | # short-term forecasting 82 | bash ./scripts/time_series_forecasting/short_term/AutoTimes_M4.sh 83 | 84 | # zero-shot forecasting 85 | # note that sM4_tM3 uses models trained on the short-term task, 86 | # so you should run AutoTimes_M4 first 87 | bash ./scripts/zero_shot_forecasting/sM4_tM3.sh 88 | bash ./scripts/zero_shot_forecasting/sM3_tM4.sh 89 | 90 | # in-context forecasting 91 | bash ./scripts/in_context_forecasting/M3.sh 92 | 93 | # try other large language models 94 | bash ./scripts/method_generality/opt.sh 95 | ``` 96 | 97 | > Thanks to the simple tokenization and the frozen LLM blocks, AutoTimes is highly compatible with LLMs. For example, it takes only **15 min** for AutoTimes to repurpose LLaMA-7B on ETTh1 on a single RTX 3090-24G. 98 | 99 | ### A Usage Example 100 | See ```predict.ipynb``` for a simple training and inference workflow. 101 | 102 | ## Overall Approach 103 | 104 |
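AutoTimes regards each non-overlapping segment of a univariate series as one token: segments are embedded into the LLM's hidden space, combined with the precomputed textual-timestamp embeddings as position embeddings, passed through the frozen decoder-only LLM, and projected back to the next segment. The snippet below is only a rough sketch of this pipeline with hypothetical module names (`AutoTimesSketch`, `embed`, `project`) and a Hugging Face-style backbone; see `models/AutoTimes_Llama.py` for the actual implementation.

```python
import torch
import torch.nn as nn

class AutoTimesSketch(nn.Module):
    """Sketch only: segment-wise tokenization + frozen decoder-only LLM."""
    def __init__(self, llm, hidden_dim, token_len):
        super().__init__()
        self.token_len = token_len                        # one segment = one LLM "token"
        self.llm = llm                                    # e.g. a Hugging Face LlamaModel, kept frozen
        for p in self.llm.parameters():
            p.requires_grad = False
        self.embed = nn.Linear(token_len, hidden_dim)     # segment -> LLM embedding space
        self.project = nn.Linear(hidden_dim, token_len)   # hidden state -> next segment

    def forward(self, x, timestamp_embedding=None):
        # x: [batch, seq_len, 1], with seq_len divisible by token_len
        B, L, _ = x.shape
        tokens = x.reshape(B, L // self.token_len, self.token_len)
        h = self.embed(tokens)
        if timestamp_embedding is not None:               # precomputed from textual timestamps (step 4)
            h = h + timestamp_embedding                   # serves as the position embedding
        h = self.llm(inputs_embeds=h).last_hidden_state   # frozen general-purpose token transition
        return self.project(h).reshape(B, L, 1)           # position i predicts segment i+1
```

Only the lightweight embedding and projection layers are trained (the 0.1% trainable parameters mentioned above), with a next-token objective that compares each projected segment against the following ground-truth segment.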

105 | 106 |

107 | 108 | ## Comparison 109 | 110 |

111 | 112 |

113 | 114 | 115 | ## Time Series Forecasting 116 | 117 | **One-for-all** benchmark: a single forecaster is trained on one dataset and subsequently used for all prediction lengths. 118 | 119 | 120 |
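Because generation is autoregressive, one trained model covers any horizon by iteratively feeding its own predictions back as new tokens. A rough sketch of this rolling inference is below (the repository's `Exp_Long_Term_Forecast.test` does the same, additionally sliding the timestamp marks; the `model` signature here is hypothetical):

```python
import torch

def rolling_forecast(model, lookback, token_len, pred_len):
    """Sketch: extend a fixed lookback window to an arbitrary prediction length.
    `model` maps a [B, seq_len, C] window to per-position predictions."""
    steps = -(-pred_len // token_len)            # ceil(pred_len / token_len) generated segments
    window, preds = lookback, []
    for _ in range(steps):
        out = model(window)                      # last token_len steps hold the next segment
        next_segment = out[:, -token_len:, :]
        preds.append(next_segment)
        window = torch.cat([window[:, token_len:, :], next_segment], dim=1)  # slide the context
    return torch.cat(preds, dim=1)[:, :pred_len, :]
```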

121 | 122 |

123 | 124 | ## In-Context Forecasting 125 | 126 |
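During inference, a short demonstration from the target domain (a "time series prompt") is simply concatenated in front of the lookback window. The construction used by `Dataset_TSF_ICL` in `data_provider/data_loader.py` amounts to the following sketch:

```python
import numpy as np

def build_in_context_sample(series, token_len):
    """Mirror of Dataset_TSF_ICL.__getitem__: the earliest 2*token_len points serve
    as the time series prompt, the most recent observed token is the lookback,
    and the final token is the forecasting target."""
    prompt   = series[: 2 * token_len]                    # demonstration from the target domain
    lookback = series[-2 * token_len: -token_len]         # latest observed segment
    target   = series[-token_len:]                        # segment to predict
    data_x = np.concatenate([prompt, lookback])[:, None]  # model input: prompt then lookback
    data_y = target[:, None]
    return data_x, data_y
```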

127 | 128 |

129 | 130 | Benefiting from time series prompts from the target domain, AutoTimes achieves an average SMAPE reduction of **13.3%** compared with zero-shot forecasting. 131 | 132 | ## Method Efficiency 133 | 134 |

135 | 136 |

137 | 138 | ## Showcases 139 | We investigate different prompt retrieval strategies and provide showcases that reveal how time series prompts influence interactive prediction. 140 | 141 |

142 | 143 |

144 | 145 |

146 | 147 |

148 | 149 | ## Citation 150 | 151 | If you find this repo helpful, please cite our paper. 152 | 153 | ``` 154 | @article{liu2024autotimes, 155 | title={AutoTimes: Autoregressive Time Series Forecasters via Large Language Models}, 156 | author={Liu, Yong and Qin, Guo and Huang, Xiangdong and Wang, Jianmin and Long, Mingsheng}, 157 | journal={arXiv preprint arXiv:2402.02370}, 158 | year={2024} 159 | } 160 | ``` 161 | 162 | ## Acknowledgement 163 | 164 | We appreciate the following GitHub repos a lot for their valuable code and efforts. 165 | - Time-Series-Library (https://github.com/thuml/Time-Series-Library) 166 | - FPT (https://github.com/DAMO-DI-ML/NeurIPS2023-One-Fits-All) 167 | 168 | ## Contact 169 | 170 | If you have any questions or want to use the code, feel free to contact: 171 | * Yong Liu (liuyong21@mails.tsinghua.edu.cn) 172 | * Guo Qin (qinguo24@mails.tsinghua.edu.cn) 173 | -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/data_provider/__init__.py -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_Custom, Dataset_M4, Dataset_Solar, Dataset_TSF, Dataset_TSF_ICL 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.distributed import DistributedSampler 4 | 5 | data_dict = { 6 | 'ETTh1': Dataset_ETT_hour, 7 | 'custom': Dataset_Custom, 8 | 'm4': Dataset_M4, 9 | 'Solar': Dataset_Solar, 10 | 'tsf': Dataset_TSF, 11 | 'tsf_icl': Dataset_TSF_ICL 12 | } 13 | 14 | 15 | def data_provider(args, flag): 16 | Data = data_dict[args.data] 17 | 18 | if flag == 'test': 19 | shuffle_flag = False 20 | drop_last = False 21 | batch_size = args.batch_size 22 | elif flag == 'val': 23 | shuffle_flag = args.val_set_shuffle 24 | drop_last = False 25 | batch_size = args.batch_size 26 | else: 27 | shuffle_flag = True 28 | drop_last = args.drop_last 29 | batch_size = args.batch_size 30 | 31 | if flag in ['train', 'val']: 32 | data_set = Data( 33 | root_path=args.root_path, 34 | data_path=args.data_path, 35 | flag=flag, 36 | size=[args.seq_len, args.label_len, args.token_len], 37 | seasonal_patterns=args.seasonal_patterns, 38 | drop_short=args.drop_short, 39 | ) 40 | else: 41 | data_set = Data( 42 | root_path=args.root_path, 43 | data_path=args.data_path, 44 | flag=flag, 45 | size=[args.test_seq_len, args.test_label_len, args.test_pred_len], 46 | seasonal_patterns=args.seasonal_patterns, 47 | drop_short=args.drop_short, 48 | ) 49 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu: 50 | print(flag, len(data_set)) 51 | if args.use_multi_gpu: 52 | train_datasampler = DistributedSampler(data_set, shuffle=shuffle_flag) 53 | data_loader = DataLoader(data_set, 54 | batch_size=batch_size, 55 | sampler=train_datasampler, 56 | num_workers=args.num_workers, 57 | persistent_workers=True, 58 | pin_memory=True, 59 | drop_last=drop_last, 60 | ) 61 | else: 62 | data_loader = DataLoader( 63 | data_set, 64 | batch_size=batch_size, 65 | shuffle=shuffle_flag, 66 | num_workers=args.num_workers, 67 | drop_last=drop_last) 68 | return data_set, data_loader -------------------------------------------------------------------------------- 
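A minimal sketch of how the `data_provider` factory above is typically driven: the real entry point is `run.py` (its full argument parser is not reproduced here), so the namespace below only mocks the fields read by the factory and datasets, and the paths and window sizes are illustrative.

```python
from types import SimpleNamespace
from data_provider.data_factory import data_provider

# Assumes ETTh1.csv and the precomputed timestamp embeddings ETTh1.pt sit under root_path.
args = SimpleNamespace(
    data='ETTh1', root_path='./dataset/ETT-small/', data_path='ETTh1.csv',  # illustrative paths
    seq_len=672, label_len=576, token_len=96,                 # train-time window sizes
    test_seq_len=672, test_label_len=576, test_pred_len=96,   # test-time window sizes
    seasonal_patterns=None, drop_short=False, drop_last=True,
    batch_size=256, num_workers=4, val_set_shuffle=False,
    use_multi_gpu=False, local_rank=0,
)
train_set, train_loader = data_provider(args, flag='train')
for seq_x, seq_y, seq_x_mark, seq_y_mark in train_loader:
    # seq_x: lookback window, seq_y: label + horizon, *_mark: timestamp embeddings
    break
```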
/data_provider/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from torch.utils.data import Dataset 7 | from data_provider.m4 import M4Dataset, M4Meta 8 | from sklearn.preprocessing import StandardScaler 9 | from utils.tools import convert_tsf_to_dataframe 10 | import warnings 11 | 12 | warnings.filterwarnings('ignore') 13 | 14 | 15 | class Dataset_ETT_hour(Dataset): 16 | def __init__(self, root_path, flag='train', size=None, data_path='ETTh1.csv', 17 | scale=True, seasonal_patterns=None, drop_short=False): 18 | self.seq_len = size[0] 19 | self.label_len = size[1] 20 | self.pred_len = size[2] 21 | self.token_len = self.seq_len - self.label_len 22 | self.token_num = self.seq_len // self.token_len 23 | self.flag = flag 24 | # init 25 | assert flag in ['train', 'test', 'val'] 26 | type_map = {'train': 0, 'val': 1, 'test': 2} 27 | self.set_type = type_map[flag] 28 | 29 | self.scale = scale 30 | 31 | self.root_path = root_path 32 | self.data_path = data_path 33 | self.__read_data__() 34 | self.enc_in = self.data_x.shape[-1] 35 | self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1 36 | 37 | def __read_data__(self): 38 | self.scaler = StandardScaler() 39 | df_raw = pd.read_csv(os.path.join(self.root_path, 40 | self.data_path)) 41 | 42 | border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len] 43 | border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24] 44 | border1 = border1s[self.set_type] 45 | border2 = border2s[self.set_type] 46 | 47 | cols_data = df_raw.columns[1:] 48 | df_data = df_raw[cols_data] 49 | 50 | if self.scale: 51 | train_data = df_data[border1s[0]:border2s[0]] 52 | self.scaler.fit(train_data.values) 53 | data = self.scaler.transform(df_data.values) 54 | else: 55 | data = df_data.values 56 | 57 | data_name = self.data_path.split('.')[0] 58 | self.data_stamp = torch.load(os.path.join(self.root_path, f'{data_name}.pt')) 59 | self.data_stamp = self.data_stamp[border1:border2] 60 | self.data_x = data[border1:border2] 61 | self.data_y = data[border1:border2] 62 | 63 | def __getitem__(self, index): 64 | feat_id = index // self.tot_len 65 | s_begin = index % self.tot_len 66 | 67 | s_end = s_begin + self.seq_len 68 | r_begin = s_end - self.label_len 69 | r_end = r_begin + self.label_len + self.pred_len 70 | seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1] 71 | seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1] 72 | seq_x_mark = self.data_stamp[s_begin:s_end:self.token_len] 73 | seq_y_mark = self.data_stamp[s_end:r_end:self.token_len] 74 | 75 | return seq_x, seq_y, seq_x_mark, seq_y_mark 76 | 77 | def __len__(self): 78 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) * self.enc_in 79 | 80 | def inverse_transform(self, data): 81 | return self.scaler.inverse_transform(data) 82 | 83 | class Dataset_Custom(Dataset): 84 | def __init__(self, root_path, flag='train', size=None, data_path='ETTh1.csv', 85 | scale=True, seasonal_patterns=None, drop_short=False): 86 | self.seq_len = size[0] 87 | self.label_len = size[1] 88 | self.pred_len = size[2] 89 | self.token_len = self.seq_len - self.label_len 90 | self.token_num = self.seq_len // self.token_len 91 | self.flag = flag 92 | # init 93 | assert flag in ['train', 'test', 'val'] 94 | type_map = {'train': 0, 'val': 1, 'test': 2} 95 | self.set_type = type_map[flag] 96 | 97 | self.scale = scale 98 | 99 | 
self.root_path = root_path 100 | self.data_path = data_path 101 | self.__read_data__() 102 | self.enc_in = self.data_x.shape[-1] 103 | self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1 104 | 105 | def __read_data__(self): 106 | self.scaler = StandardScaler() 107 | df_raw = pd.read_csv(os.path.join(self.root_path, 108 | self.data_path)) 109 | num_train = int(len(df_raw) * 0.7) 110 | num_test = int(len(df_raw) * 0.2) 111 | num_vali = len(df_raw) - num_train - num_test 112 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 113 | border2s = [num_train, num_train + num_vali, len(df_raw)] 114 | border1 = border1s[self.set_type] 115 | border2 = border2s[self.set_type] 116 | 117 | 118 | cols_data = df_raw.columns[1:] 119 | df_data = df_raw[cols_data] 120 | 121 | if self.scale: 122 | train_data = df_data[border1s[0]:border2s[0]] 123 | self.scaler.fit(train_data.values) 124 | data = self.scaler.transform(df_data.values) 125 | else: 126 | data = df_data.values 127 | data_name = self.data_path.split('.')[0] 128 | self.data_stamp = torch.load(os.path.join(self.root_path, f'{data_name}.pt')) 129 | self.data_stamp = self.data_stamp[border1:border2] 130 | self.data_x = data[border1:border2] 131 | self.data_y = data[border1:border2] 132 | 133 | 134 | def __getitem__(self, index): 135 | feat_id = index // self.tot_len 136 | s_begin = index % self.tot_len 137 | 138 | s_end = s_begin + self.seq_len 139 | r_begin = s_end - self.label_len 140 | r_end = r_begin + self.label_len + self.pred_len 141 | seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1] 142 | seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1] 143 | seq_x_mark = self.data_stamp[s_begin:s_end:self.token_len] 144 | seq_y_mark = self.data_stamp[s_end:r_end:self.token_len] 145 | return seq_x, seq_y, seq_x_mark, seq_y_mark 146 | 147 | def __len__(self): 148 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) * self.enc_in 149 | 150 | def inverse_transform(self, data): 151 | return self.scaler.inverse_transform(data) 152 | 153 | 154 | class Dataset_Solar(Dataset): 155 | def __init__(self, root_path, flag='train', size=None, data_path='ETTh1.csv', 156 | seasonal_patterns=None, scale=True, drop_short=False): 157 | # size [seq_len, label_len, pred_len] 158 | # info 159 | self.seq_len = size[0] 160 | self.label_len = size[1] 161 | self.pred_len = size[2] 162 | 163 | self.token_len = self.seq_len - self.label_len 164 | self.token_num = self.seq_len // self.token_len 165 | self.flag = flag 166 | # init 167 | assert flag in ['train', 'test', 'val'] 168 | type_map = {'train': 0, 'val': 1, 'test': 2} 169 | self.set_type = type_map[flag] 170 | 171 | self.scale = scale 172 | 173 | self.root_path = root_path 174 | self.data_path = data_path 175 | self.__read_data__() 176 | self.enc_in = self.data_x.shape[-1] 177 | self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1 178 | 179 | def __read_data__(self): 180 | self.scaler = StandardScaler() 181 | df_raw = [] 182 | with open(os.path.join(self.root_path, self.data_path), "r", encoding='utf-8') as f: 183 | for line in f.readlines(): 184 | line = line.strip('\n').split(',') 185 | data_line = np.stack([float(i) for i in line]) 186 | df_raw.append(data_line) 187 | df_raw = np.stack(df_raw, 0) 188 | df_raw = pd.DataFrame(df_raw) 189 | 190 | num_train = int(len(df_raw) * 0.7) 191 | num_test = int(len(df_raw) * 0.2) 192 | num_valid = int(len(df_raw) * 0.1) 193 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 194 | 
border2s = [num_train, num_train + num_valid, len(df_raw)] 195 | border1 = border1s[self.set_type] 196 | border2 = border2s[self.set_type] 197 | 198 | df_data = df_raw.values 199 | 200 | if self.scale: 201 | train_data = df_data[border1s[0]:border2s[0]] 202 | self.scaler.fit(train_data) 203 | data = self.scaler.transform(df_data) 204 | else: 205 | data = df_data 206 | 207 | self.data_x = data[border1:border2] 208 | self.data_y = data[border1:border2] 209 | 210 | def __getitem__(self, index): 211 | feat_id = index // self.tot_len 212 | s_begin = index % self.tot_len 213 | 214 | s_end = s_begin + self.seq_len 215 | r_begin = s_end - self.label_len 216 | r_end = r_begin + self.label_len + self.pred_len 217 | seq_x = self.data_x[s_begin:s_end, feat_id:feat_id+1] 218 | seq_y = self.data_y[r_begin:r_end, feat_id:feat_id+1] 219 | seq_x_mark = torch.zeros((seq_x.shape[0], 1)) 220 | seq_y_mark = torch.zeros((seq_x.shape[0], 1)) 221 | 222 | return seq_x, seq_y, seq_x_mark, seq_y_mark 223 | 224 | def __len__(self): 225 | return (len(self.data_x) - self.seq_len - self.pred_len + 1) * self.enc_in 226 | 227 | def inverse_transform(self, data): 228 | return self.scaler.inverse_transform(data) 229 | 230 | 231 | class Dataset_M4(Dataset): 232 | def __init__(self, root_path, flag='pred', size=None, data_path='ETTh1.csv', 233 | scale=False, inverse=False, seasonal_patterns='Yearly', drop_short=False): 234 | self.scale = scale 235 | self.inverse = inverse 236 | self.root_path = root_path 237 | 238 | self.seq_len = size[0] 239 | self.label_len = size[1] 240 | self.pred_len = size[2] 241 | 242 | self.seasonal_patterns = seasonal_patterns 243 | self.history_size = M4Meta.history_size[seasonal_patterns] 244 | self.window_sampling_limit = int(self.history_size * self.pred_len) 245 | self.flag = flag 246 | 247 | self.__read_data__() 248 | 249 | def __read_data__(self): 250 | # M4Dataset.initialize() 251 | if self.flag == 'train': 252 | dataset = M4Dataset.load(training=True, dataset_file=self.root_path) 253 | else: 254 | dataset = M4Dataset.load(training=False, dataset_file=self.root_path) 255 | training_values = np.array( 256 | [v[~np.isnan(v)] for v in 257 | dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies 258 | self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]]) 259 | self.timeseries = [ts for ts in training_values] 260 | 261 | def __getitem__(self, index): 262 | insample = np.zeros((self.seq_len, 1)) 263 | insample_mask = np.zeros((self.seq_len, 1)) 264 | outsample = np.zeros((self.pred_len + self.label_len, 1)) 265 | outsample_mask = np.zeros((self.pred_len + self.label_len, 1)) # m4 dataset 266 | 267 | sampled_timeseries = self.timeseries[index] 268 | cut_point = np.random.randint(low=max(1, len(sampled_timeseries) - self.window_sampling_limit), 269 | high=len(sampled_timeseries), 270 | size=1)[0] 271 | 272 | insample_window = sampled_timeseries[max(0, cut_point - self.seq_len):cut_point] 273 | insample[-len(insample_window):, 0] = insample_window 274 | insample_mask[-len(insample_window):, 0] = 1.0 275 | outsample_window = sampled_timeseries[ 276 | cut_point - self.label_len:min(len(sampled_timeseries), cut_point + self.pred_len)] 277 | outsample[:len(outsample_window), 0] = outsample_window 278 | outsample_mask[:len(outsample_window), 0] = 1.0 279 | return insample, outsample, insample_mask, outsample_mask 280 | 281 | def __len__(self): 282 | return len(self.timeseries) 283 | 284 | def inverse_transform(self, data): 285 | return 
self.scaler.inverse_transform(data) 286 | 287 | def last_insample_window(self): 288 | """ 289 | The last window of insample size of all timeseries. 290 | This function does not support batching and does not reshuffle timeseries. 291 | 292 | :return: Last insample window of all timeseries. Shape "timeseries, insample size" 293 | """ 294 | insample = np.zeros((len(self.timeseries), self.seq_len)) 295 | insample_mask = np.zeros((len(self.timeseries), self.seq_len)) 296 | for i, ts in enumerate(self.timeseries): 297 | ts_last_window = ts[-self.seq_len:] 298 | insample[i, -len(ts):] = ts_last_window 299 | insample_mask[i, -len(ts):] = 1.0 300 | return insample, insample_mask 301 | 302 | 303 | class Dataset_TSF(Dataset): 304 | def __init__(self, root_path, flag='train', size=None, data_path=None, 305 | scale=True, seasonal_patterns=None, drop_short=False): 306 | 307 | self.seq_len = size[0] 308 | self.label_len = size[1] 309 | self.pred_len = size[2] 310 | self.token_len = self.pred_len 311 | self.context_len = 4 * self.token_len 312 | print(self.seq_len, self.label_len, self.pred_len) 313 | type_map = {'train': 0, 'val': 1, 'test': 2} 314 | self.set_type = type_map[flag] 315 | 316 | self.root_path = root_path 317 | self.data_path = data_path 318 | self.drop_short = drop_short 319 | self.timeseries = self.__read_data__() 320 | 321 | 322 | def __read_data__(self): 323 | df, _, _, _, _ = convert_tsf_to_dataframe(os.path.join(self.root_path, self.data_path)) 324 | def dropna(x): 325 | return x[~np.isnan(x)] 326 | timeseries = [dropna(ts).astype(np.float32) for ts in df.series_value] 327 | if self.drop_short: 328 | timeseries = [ts for ts in timeseries if ts.shape[0] > self.context_len] 329 | self.tot_len = 0 330 | self.len_seq = [] 331 | self.seq_id = [] 332 | for i in range(len(timeseries)): 333 | res_len = max(self.pred_len + self.seq_len - timeseries[i].shape[0], 0) 334 | pad_zeros = np.zeros(res_len) 335 | timeseries[i] = np.hstack([pad_zeros, timeseries[i]]) 336 | 337 | _len = timeseries[i].shape[0] 338 | train_len = _len-self.pred_len 339 | border1s = [0, train_len - self.seq_len - self.pred_len, train_len-self.seq_len] 340 | border2s = [train_len - self.pred_len, train_len, _len] 341 | 342 | curr_len = border2s[self.set_type] - max(border1s[self.set_type], 0) - self.pred_len - self.seq_len + 1 343 | curr_len = max(0, curr_len) 344 | 345 | self.len_seq.append(np.zeros(curr_len) + self.tot_len) 346 | self.seq_id.append(np.zeros(curr_len) + i) 347 | self.tot_len += curr_len 348 | 349 | self.len_seq = np.hstack(self.len_seq) 350 | self.seq_id = np.hstack(self.seq_id) 351 | 352 | return timeseries 353 | 354 | def __getitem__(self, index): 355 | len_seq = self.len_seq[index] 356 | seq_id = int(self.seq_id[index]) 357 | index = index - int(len_seq) 358 | 359 | _len = self.timeseries[seq_id].shape[0] 360 | train_len = _len - self.pred_len 361 | border1s = [0, train_len - self.seq_len - self.pred_len, train_len-self.seq_len] 362 | 363 | s_begin = index + border1s[self.set_type] 364 | s_end = s_begin + self.seq_len 365 | r_begin = s_end - self.label_len 366 | r_end = s_end + self.pred_len 367 | 368 | data_x = self.timeseries[seq_id][s_begin:s_end] 369 | data_y = self.timeseries[seq_id][r_begin:r_end] 370 | data_x = np.expand_dims(data_x, axis=-1) 371 | data_y = np.expand_dims(data_y, axis=-1) 372 | 373 | return data_x, data_y, data_x, data_y 374 | 375 | def __len__(self): 376 | return self.tot_len 377 | 378 | class Dataset_TSF_ICL(Dataset): 379 | def __init__(self, root_path, flag='train', size=None, 
data_path=None, 380 | scale=True, seasonal_patterns=None, drop_short=True): 381 | 382 | self.pred_len = size[2] 383 | self.token_len = self.pred_len 384 | self.context_len = 4 * self.token_len 385 | 386 | self.root_path = root_path 387 | self.data_path = data_path 388 | self.timeseries = self.__read_data__() 389 | 390 | def __read_data__(self): 391 | df, _, _, _, _ = convert_tsf_to_dataframe(os.path.join(self.root_path, self.data_path)) 392 | def dropna(x): 393 | return x[~np.isnan(x)] 394 | timeseries = [dropna(ts).astype(np.float32) for ts in df.series_value] 395 | timeseries = [ts for ts in timeseries if ts.shape[0] > self.context_len] 396 | return timeseries 397 | 398 | # we uniformly adopting the first time points of the time series as the corresponding prompt. 399 | def __getitem__(self, index): 400 | data_x1 = self.timeseries[index][:2*self.token_len] 401 | data_x2 = self.timeseries[index][-2*self.token_len:-1*self.token_len] 402 | data_x = np.concatenate((data_x1, data_x2)) 403 | data_y = self.timeseries[index][-1*self.token_len:] 404 | data_x = np.expand_dims(data_x, axis=-1) 405 | data_y = np.expand_dims(data_y, axis=-1) 406 | return data_x, data_y, data_x, data_y 407 | 408 | def __len__(self): 409 | return len(self.timeseries) 410 | 411 | class Dataset_Preprocess(Dataset): 412 | def __init__(self, root_path, flag='train', size=None, 413 | data_path='ETTh1.csv', scale=True, seasonal_patterns=None): 414 | self.seq_len = size[0] 415 | self.label_len = size[1] 416 | self.pred_len = size[2] 417 | self.token_len = self.seq_len - self.label_len 418 | self.token_num = self.seq_len // self.token_len 419 | self.flag = flag 420 | self.data_set_type = data_path.split('.')[0] 421 | # init 422 | assert flag in ['train', 'test', 'val'] 423 | type_map = {'train': 0, 'val': 1, 'test': 2} 424 | self.set_type = type_map[flag] 425 | 426 | self.scale = scale 427 | 428 | self.root_path = root_path 429 | self.data_path = data_path 430 | self.__read_data__() 431 | self.tot_len = len(self.data_stamp) 432 | 433 | def __read_data__(self): 434 | df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path)) 435 | df_stamp = df_raw[['date']] 436 | df_stamp['date'] = pd.to_datetime(df_stamp.date).apply(str) 437 | self.data_stamp = df_stamp['date'].values 438 | self.data_stamp = [str(x) for x in self.data_stamp] 439 | 440 | 441 | def __getitem__(self, index): 442 | s_begin = index % self.tot_len 443 | s_end = s_begin + self.token_len 444 | start = datetime.datetime.strptime(self.data_stamp[s_begin], "%Y-%m-%d %H:%M:%S") 445 | if self.data_set_type in ['traffic', 'electricity', 'ETTh1', 'ETTh2']: 446 | end = (start + datetime.timedelta(hours=self.token_len-1)).strftime("%Y-%m-%d %H:%M:%S") 447 | elif self.data_set_type == 'weather': 448 | end = (start + datetime.timedelta(minutes=10*(self.token_len-1))).strftime("%Y-%m-%d %H:%M:%S") 449 | elif self.data_set_type in ['ETTm1', 'ETTm2']: 450 | end = (start + datetime.timedelta(minutes=15*(self.token_len-1))).strftime("%Y-%m-%d %H:%M:%S") 451 | seq_x_mark = f"This is Time Series from {self.data_stamp[s_begin]} to {end}" 452 | return seq_x_mark 453 | 454 | def __len__(self): 455 | return len(self.data_stamp) 456 | -------------------------------------------------------------------------------- /data_provider/m4.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. 
The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | M4 Dataset 17 | """ 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | from dataclasses import dataclass 22 | from glob import glob 23 | 24 | import numpy as np 25 | import pandas as pd 26 | import patoolib 27 | from tqdm import tqdm 28 | import logging 29 | import os 30 | import pathlib 31 | import sys 32 | from urllib import request 33 | 34 | 35 | def url_file_name(url: str) -> str: 36 | """ 37 | Extract file name from url. 38 | 39 | :param url: URL to extract file name from. 40 | :return: File name. 41 | """ 42 | return url.split('/')[-1] if len(url) > 0 else '' 43 | 44 | 45 | def download(url: str, file_path: str) -> None: 46 | """ 47 | Download a file to the given path. 48 | 49 | :param url: URL to download 50 | :param file_path: Where to download the content. 51 | """ 52 | 53 | def progress(count, block_size, total_size): 54 | progress_pct = float(count * block_size) / float(total_size) * 100.0 55 | sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct)) 56 | sys.stdout.flush() 57 | 58 | if not os.path.isfile(file_path): 59 | opener = request.build_opener() 60 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 61 | request.install_opener(opener) 62 | pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) 63 | f, _ = request.urlretrieve(url, file_path, progress) 64 | sys.stdout.write('\n') 65 | sys.stdout.flush() 66 | file_info = os.stat(f) 67 | logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.') 68 | else: 69 | file_info = os.stat(file_path) 70 | logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.') 71 | 72 | 73 | @dataclass() 74 | class M4Dataset: 75 | ids: np.ndarray 76 | groups: np.ndarray 77 | frequencies: np.ndarray 78 | horizons: np.ndarray 79 | values: np.ndarray 80 | 81 | @staticmethod 82 | def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset': 83 | """ 84 | Load cached dataset. 85 | 86 | :param training: Load training part if training is True, test part otherwise. 
87 | """ 88 | info_file = os.path.join(dataset_file, 'M4-info.csv') 89 | train_cache_file = os.path.join(dataset_file, 'training.npz') 90 | test_cache_file = os.path.join(dataset_file, 'test.npz') 91 | m4_info = pd.read_csv(info_file) 92 | return M4Dataset(ids=m4_info.M4id.values, 93 | groups=m4_info.SP.values, 94 | frequencies=m4_info.Frequency.values, 95 | horizons=m4_info.Horizon.values, 96 | values=np.load( 97 | train_cache_file if training else test_cache_file, 98 | allow_pickle=True)) 99 | 100 | 101 | @dataclass() 102 | class M4Meta: 103 | seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly'] 104 | horizons = [6, 8, 18, 13, 14, 48] 105 | frequencies = [1, 4, 12, 1, 1, 24] 106 | horizons_map = { 107 | 'Yearly': 6, 108 | 'Quarterly': 8, 109 | 'Monthly': 18, 110 | 'Weekly': 13, 111 | 'Daily': 14, 112 | 'Hourly': 48 113 | } # different predict length 114 | frequency_map = { 115 | 'Yearly': 1, 116 | 'Quarterly': 4, 117 | 'Monthly': 12, 118 | 'Weekly': 1, 119 | 'Daily': 1, 120 | 'Hourly': 24 121 | } 122 | history_size = { 123 | 'Yearly': 1.5, 124 | 'Quarterly': 1.5, 125 | 'Monthly': 1.5, 126 | 'Weekly': 10, 127 | 'Daily': 10, 128 | 'Hourly': 10 129 | } # from interpretable.gin 130 | 131 | 132 | def load_m4_info() -> pd.DataFrame: 133 | """ 134 | Load M4Info file. 135 | 136 | :return: Pandas DataFrame of M4Info. 137 | """ 138 | return pd.read_csv(INFO_FILE_PATH) -------------------------------------------------------------------------------- /exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | from models import AutoTimes_Llama, AutoTimes_Gpt2, AutoTimes_Opt_1b 2 | 3 | 4 | class Exp_Basic(object): 5 | def __init__(self, args): 6 | self.args = args 7 | self.model_dict = { 8 | 'AutoTimes_Llama': AutoTimes_Llama, 9 | 'AutoTimes_Gpt2': AutoTimes_Gpt2, 10 | 'AutoTimes_Opt_1b': AutoTimes_Opt_1b 11 | } 12 | self.model = self._build_model() 13 | 14 | def _build_model(self): 15 | raise NotImplementedError 16 | 17 | def _get_data(self): 18 | pass 19 | 20 | def vali(self): 21 | pass 22 | 23 | def train(self): 24 | pass 25 | 26 | def test(self): 27 | pass 28 | -------------------------------------------------------------------------------- /exp/exp_in_context_forecasting.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_factory import data_provider 2 | from data_provider.m4 import M4Meta 3 | from exp.exp_basic import Exp_Basic 4 | from utils.tools import EarlyStopping, adjust_learning_rate 5 | from utils.losses import mape_loss, mase_loss, smape_loss, zero_shot_smape_loss 6 | import torch 7 | import torch.nn as nn 8 | from torch import optim 9 | import os 10 | import time 11 | import warnings 12 | import numpy as np 13 | 14 | # In our in-context learning setting 15 | # the task is to apply a forecaster, trained on a source dataset, to an unseen target dataset 16 | # Additionally, several task demonstrations from the target domain, 17 | # referred to as time series prompts are available during inference 18 | # Concretely, AutoTimes trains LLMs on the source domain with a larger context length to place the additional time series prompt. 
19 | # See ```Dataset_TSF_ICL``` in ```data_loader.py``` for the construction of time series prompts 20 | 21 | warnings.filterwarnings('ignore') 22 | 23 | def SMAPE(pred, true): 24 | return np.mean(200 * np.abs(pred - true) / (np.abs(pred) + np.abs(true) + 1e-8)) 25 | def MAPE(pred, true): 26 | return np.mean(np.abs(100 * (pred - true) / (true +1e-8))) 27 | 28 | class Exp_In_Context_Forecast(Exp_Basic): 29 | def __init__(self, args): 30 | super(Exp_In_Context_Forecast, self).__init__(args) 31 | 32 | def _build_model(self): 33 | if self.args.data == 'm4': 34 | self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns] 35 | self.device = self.args.gpu 36 | model = self.model_dict[self.args.model].Model(self.args).to(self.device) 37 | return model 38 | 39 | def _get_data(self, flag): 40 | data_set, data_loader = data_provider(self.args, flag) 41 | return data_set, data_loader 42 | 43 | def _select_optimizer(self): 44 | p_list = [] 45 | for n, p in self.model.named_parameters(): 46 | if not p.requires_grad: 47 | continue 48 | else: 49 | p_list.append(p) 50 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 51 | print(n, p.dtype, p.shape) 52 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay) 53 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 54 | print('next learning rate is {}'.format(self.args.learning_rate)) 55 | return model_optim 56 | 57 | def _select_criterion(self, loss_name='MSE'): 58 | if loss_name == 'MSE': 59 | return nn.MSELoss() 60 | elif loss_name == 'MAPE': 61 | return mape_loss() 62 | elif loss_name == 'MASE': 63 | return mase_loss() 64 | elif loss_name == 'SMAPE': 65 | return smape_loss() 66 | 67 | def train(self, setting): 68 | train_data, train_loader = self._get_data(flag='train') 69 | vali_data, vali_loader = self._get_data(flag='val') 70 | 71 | self.args.root_path = './dataset/tsf' 72 | self.args.data_path = self.args.test_data_path 73 | self.args.data = 'tsf' 74 | test_data2, test_loader2 = self._get_data(flag='test') 75 | 76 | self.args.data = 'tsf_icl' 77 | test_data3, test_loader3 = self._get_data(flag="test") 78 | path = os.path.join(self.args.checkpoints, setting) 79 | if not os.path.exists(path): 80 | os.makedirs(path) 81 | 82 | time_now = time.time() 83 | 84 | train_steps = len(train_loader) 85 | early_stopping = EarlyStopping(self.args, verbose=True) 86 | 87 | model_optim = self._select_optimizer() 88 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8) 89 | criterion = self._select_criterion(self.args.loss) 90 | if self.args.use_amp: 91 | scaler = torch.cuda.amp.GradScaler() 92 | 93 | for epoch in range(self.args.train_epochs): 94 | iter_count = 0 95 | train_loss = [] 96 | 97 | self.model.train() 98 | epoch_time = time.time() 99 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): 100 | iter_count += 1 101 | model_optim.zero_grad() 102 | batch_x = batch_x.float().to(self.device) 103 | 104 | batch_y = batch_y.float().to(self.device) 105 | batch_y_mark = batch_y_mark.float().to(self.device) 106 | 107 | if self.args.use_amp: 108 | with torch.cuda.amp.autocast(): 109 | outputs = self.model(batch_x, None, None, None) 110 | else: 111 | outputs = self.model(batch_x, None, None, None) 112 | 113 | loss = criterion(batch_x, self.args.frequency_map, outputs, batch_y, batch_y_mark) 114 | train_loss.append(loss.item()) 115 | 116 
| if (i + 1) % 100 == 0: 117 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) 118 | speed = (time.time() - time_now) / iter_count 119 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i) 120 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) 121 | iter_count = 0 122 | time_now = time.time() 123 | 124 | if self.args.use_amp: 125 | scaler.scale(loss).backward() 126 | scaler.step(model_optim) 127 | scaler.update() 128 | else: 129 | loss.backward() 130 | model_optim.step() 131 | 132 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) 133 | train_loss = np.average(train_loss) 134 | vali_loss = self.vali(train_loader, vali_loader, criterion) # test_loss indicates the result on the source datasets 135 | test_loss = vali_loss 136 | test_loss2 = self.vali2(test_data2, test_loader2, zero_shot_smape_loss()) # test_loss2 indicates the result on the target datasets 137 | test_loss3 = self.vali2(test_data3, test_loader3, zero_shot_smape_loss()) # test_loss3 indicates the result on the target datasets with time series prompts 138 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Zero Shot Test Loss: {4:.7f} In Context Test Loss: {5:.7f}".format( 139 | epoch + 1, train_steps, train_loss, vali_loss, test_loss2, test_loss3)) 140 | early_stopping(vali_loss, self.model, path) 141 | if early_stopping.early_stop: 142 | print("Early stopping") 143 | break 144 | if self.args.cosine: 145 | scheduler.step() 146 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 147 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr'])) 148 | else: 149 | adjust_learning_rate(model_optim, epoch + 1, self.args) 150 | 151 | best_model_path = path + '/' + f'checkpoint.pth' 152 | 153 | self.model.load_state_dict(torch.load(best_model_path), strict=False) 154 | 155 | return self.model 156 | 157 | def vali(self, train_loader, vali_loader, criterion): 158 | x, _ = train_loader.dataset.last_insample_window() 159 | y = vali_loader.dataset.timeseries 160 | x = torch.tensor(x, dtype=torch.float32).to(self.device) 161 | x = x.unsqueeze(-1) 162 | 163 | self.model.eval() 164 | with torch.no_grad(): 165 | # decoder input 166 | B, _, C = x.shape 167 | 168 | outputs = torch.zeros((B, self.args.seq_len, C)).float() # .to(self.device) 169 | id_list = np.arange(0, B, 500) # validation set size 170 | id_list = np.append(id_list, B) 171 | if self.args.use_amp: 172 | with torch.cuda.amp.autocast(): 173 | for i in range(len(id_list) - 1): 174 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, None, None).detach().cpu() 175 | else: 176 | for i in range(len(id_list) - 1): 177 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, None, None).detach().cpu() 178 | pred = outputs[:, -self.args.token_len:, :] 179 | true = torch.from_numpy(np.array(y)) 180 | batch_y_mark = torch.ones(true.shape) 181 | loss = criterion(x.detach().cpu()[:, :, 0], self.args.frequency_map, pred[:, :, 0], true, batch_y_mark) 182 | 183 | self.model.train() 184 | return loss 185 | 186 | def vali2(self, vali_data, vali_loader, criterion): 187 | total_loss = [] 188 | count= [] 189 | self.model.eval() 190 | with torch.no_grad(): 191 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader): 192 | batch_x = batch_x.float().to(self.device) 193 | batch_y = batch_y.float() 194 | 195 | batch_x_mark = 
batch_x_mark.float().to(self.device) 196 | batch_y_mark = batch_y_mark.float().to(self.device) 197 | 198 | if self.args.use_amp: 199 | with torch.cuda.amp.autocast(): 200 | outputs = self.model(batch_x, None, None, None) 201 | else: 202 | outputs = self.model(batch_x, None, None, None) 203 | 204 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device) 205 | 206 | pred = outputs[:, -self.args.test_pred_len:, :].detach().cpu() 207 | true = batch_y.detach().cpu() 208 | 209 | loss = criterion(pred, true) 210 | 211 | total_loss.append(loss) 212 | count.append(batch_x.shape[0]) 213 | total_loss = np.average(total_loss, weights=count) 214 | self.model.train() 215 | 216 | return total_loss 217 | 218 | def test_(self, test_loader): 219 | preds = [] 220 | trues = [] 221 | 222 | self.model.eval() 223 | with torch.no_grad(): 224 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader): 225 | batch_x = batch_x.float().to(self.device) 226 | batch_y = batch_y.float() 227 | 228 | if self.args.use_amp: 229 | with torch.cuda.amp.autocast(): 230 | outputs = self.model(batch_x, None, None, None) 231 | else: 232 | outputs = self.model(batch_x, None, None, None) 233 | 234 | outputs = outputs[:, -self.args.test_pred_len:, :] 235 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device) 236 | 237 | pred = outputs.detach().cpu().numpy() 238 | true = batch_y.detach().cpu().numpy() 239 | 240 | preds.append(pred) 241 | trues.append(true) 242 | 243 | preds = np.concatenate(preds, axis=0) 244 | trues = np.concatenate(trues, axis=0) 245 | print('test shape:', preds.shape, trues.shape) 246 | 247 | smape = SMAPE(preds, trues) 248 | mape = MAPE(preds, trues) 249 | print('mape:{:4f}, smape:{:.4f}'.format(mape, smape)) 250 | 251 | def test(self, setting, test=0): 252 | if test: 253 | print('loading model') 254 | setting = self.args.test_dir 255 | best_model_path = self.args.test_file_name 256 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 257 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path))) 258 | self.model.load_state_dict(torch.load(os.path.join(self.args.checkpoints, setting, best_model_path)), strict=False) 259 | 260 | self.args.data_path = self.args.test_data_path 261 | 262 | self.args.root_path = './dataset/tsf' 263 | self.args.data_path = self.args.test_data_path 264 | self.args.data = 'tsf' 265 | test_data, test_loader = self._get_data('test') 266 | self.args.data = 'tsf_icl' 267 | test_data2, test_loader2 = self._get_data('test') 268 | 269 | print("zero shot forecasting") 270 | self.test_(test_loader) 271 | print("in context forecasting") 272 | self.test_(test_loader2) 273 | 274 | -------------------------------------------------------------------------------- /exp/exp_long_term_forecasting.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_factory import data_provider 2 | from exp.exp_basic import Exp_Basic 3 | from utils.tools import EarlyStopping, adjust_learning_rate, visual 4 | from utils.metrics import metric 5 | import torch 6 | import torch.nn as nn 7 | from torch import optim 8 | import os 9 | import time 10 | import warnings 11 | import numpy as np 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | import torch.distributed as dist 14 | 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | class Exp_Long_Term_Forecast(Exp_Basic): 19 | def __init__(self, args): 20 | 
super(Exp_Long_Term_Forecast, self).__init__(args) 21 | 22 | def _build_model(self): 23 | model = self.model_dict[self.args.model].Model(self.args) 24 | if self.args.use_multi_gpu: 25 | self.device = torch.device('cuda:{}'.format(self.args.local_rank)) 26 | model = DDP(model.cuda(), device_ids=[self.args.local_rank]) 27 | else: 28 | self.device = self.args.gpu 29 | model = model.to(self.device) 30 | return model 31 | 32 | def _get_data(self, flag): 33 | data_set, data_loader = data_provider(self.args, flag) 34 | return data_set, data_loader 35 | 36 | def _select_optimizer(self): 37 | p_list = [] 38 | for n, p in self.model.named_parameters(): 39 | if not p.requires_grad: 40 | continue 41 | else: 42 | p_list.append(p) 43 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 44 | print(n, p.dtype, p.shape) 45 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay) 46 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 47 | print('next learning rate is {}'.format(self.args.learning_rate)) 48 | return model_optim 49 | 50 | def _select_criterion(self): 51 | criterion = nn.MSELoss() 52 | return criterion 53 | 54 | def vali(self, vali_data, vali_loader, criterion, is_test=False): 55 | total_loss = [] 56 | total_count = [] 57 | time_now = time.time() 58 | test_steps = len(vali_loader) 59 | iter_count = 0 60 | self.model.eval() 61 | with torch.no_grad(): 62 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader): 63 | iter_count += 1 64 | batch_x = batch_x.float().to(self.device) 65 | batch_y = batch_y.float() 66 | batch_x_mark = batch_x_mark.float().to(self.device) 67 | batch_y_mark = batch_y_mark.float().to(self.device) 68 | 69 | if self.args.use_amp: 70 | with torch.cuda.amp.autocast(): 71 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark) 72 | else: 73 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark) 74 | if is_test: 75 | outputs = outputs[:, -self.args.token_len:, :] 76 | batch_y = batch_y[:, -self.args.token_len:, :].to(self.device) 77 | else: 78 | outputs = outputs[:, :, :] 79 | batch_y = batch_y[:, :, :].to(self.device) 80 | 81 | loss = criterion(outputs, batch_y) 82 | 83 | loss = loss.detach().cpu() 84 | total_loss.append(loss) 85 | total_count.append(batch_x.shape[0]) 86 | if (i + 1) % 100 == 0: 87 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 88 | speed = (time.time() - time_now) / iter_count 89 | left_time = speed * (test_steps - i) 90 | print("\titers: {}, speed: {:.4f}s/iter, left time: {:.4f}s".format(i + 1, speed, left_time)) 91 | iter_count = 0 92 | time_now = time.time() 93 | if self.args.use_multi_gpu: 94 | total_loss = torch.tensor(np.average(total_loss, weights=total_count)).to(self.device) 95 | dist.barrier() 96 | dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) 97 | total_loss = total_loss.item() / dist.get_world_size() 98 | else: 99 | total_loss = np.average(total_loss, weights=total_count) 100 | self.model.train() 101 | return total_loss 102 | 103 | def train(self, setting): 104 | train_data, train_loader = self._get_data(flag='train') 105 | vali_data, vali_loader = self._get_data(flag='val') 106 | test_data, test_loader = self._get_data(flag='test') 107 | 108 | path = os.path.join(self.args.checkpoints, setting) 109 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 110 | if not 
os.path.exists(path): 111 | os.makedirs(path) 112 | 113 | time_now = time.time() 114 | 115 | train_steps = len(train_loader) 116 | early_stopping = EarlyStopping(self.args, verbose=True) 117 | 118 | model_optim = self._select_optimizer() 119 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8) 120 | criterion = self._select_criterion() 121 | if self.args.use_amp: 122 | scaler = torch.cuda.amp.GradScaler() 123 | 124 | for epoch in range(self.args.train_epochs): 125 | iter_count = 0 126 | 127 | loss_val = torch.tensor(0., device="cuda") 128 | count = torch.tensor(0., device="cuda") 129 | 130 | self.model.train() 131 | epoch_time = time.time() 132 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): 133 | iter_count += 1 134 | model_optim.zero_grad() 135 | batch_x = batch_x.float().to(self.device) 136 | batch_y = batch_y.float().to(self.device) 137 | batch_x_mark = batch_x_mark.float().to(self.device) 138 | batch_y_mark = batch_y_mark.float().to(self.device) 139 | 140 | if self.args.use_amp: 141 | with torch.cuda.amp.autocast(): 142 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark) 143 | loss = criterion(outputs, batch_y) 144 | loss_val += loss 145 | count += 1 146 | else: 147 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark) 148 | loss = criterion(outputs, batch_y) 149 | loss_val += loss 150 | count += 1 151 | 152 | if (i + 1) % 100 == 0: 153 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 154 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) 155 | speed = (time.time() - time_now) / iter_count 156 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i) 157 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) 158 | iter_count = 0 159 | time_now = time.time() 160 | 161 | if self.args.use_amp: 162 | scaler.scale(loss).backward() 163 | scaler.step(model_optim) 164 | scaler.update() 165 | else: 166 | loss.backward() 167 | model_optim.step() 168 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 169 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) 170 | if self.args.use_multi_gpu: 171 | dist.barrier() 172 | dist.all_reduce(loss_val, op=dist.ReduceOp.SUM) 173 | dist.all_reduce(count, op=dist.ReduceOp.SUM) 174 | train_loss = loss_val.item() / count.item() 175 | 176 | vali_loss = self.vali(vali_data, vali_loader, criterion) 177 | test_loss = self.vali(test_data, test_loader, criterion, is_test=True) 178 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 179 | print("Epoch: {}, Steps: {} | Train Loss: {:.7f} Vali Loss: {:.7f} Test Loss: {:.7f}".format( 180 | epoch + 1, train_steps, train_loss, vali_loss, test_loss)) 181 | early_stopping(vali_loss, self.model, path) 182 | if early_stopping.early_stop: 183 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 184 | print("Early stopping") 185 | break 186 | if self.args.cosine: 187 | scheduler.step() 188 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 189 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr'])) 190 | else: 191 | adjust_learning_rate(model_optim, epoch + 1, self.args) 192 | if self.args.use_multi_gpu: 193 | train_loader.sampler.set_epoch(epoch + 1) 194 | 195 | best_model_path = path + '/' + 
'checkpoint.pth' 196 | if self.args.use_multi_gpu: 197 | dist.barrier() 198 | self.model.load_state_dict(torch.load(best_model_path), strict=False) 199 | else: 200 | self.model.load_state_dict(torch.load(best_model_path), strict=False) 201 | return self.model 202 | 203 | def test(self, setting, test=0): 204 | test_data, test_loader = self._get_data(flag='test') 205 | 206 | print("info:", self.args.test_seq_len, self.args.test_label_len, self.args.token_len, self.args.test_pred_len) 207 | if test: 208 | print('loading model') 209 | setting = self.args.test_dir 210 | best_model_path = self.args.test_file_name 211 | 212 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path))) 213 | load_item = torch.load(os.path.join(self.args.checkpoints, setting, best_model_path)) 214 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in load_item.items()}, strict=False) 215 | 216 | preds = [] 217 | trues = [] 218 | folder_path = './test_results/' + setting + '/' 219 | if not os.path.exists(folder_path): 220 | os.makedirs(folder_path) 221 | time_now = time.time() 222 | test_steps = len(test_loader) 223 | iter_count = 0 224 | self.model.eval() 225 | with torch.no_grad(): 226 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader): 227 | iter_count += 1 228 | batch_x = batch_x.float().to(self.device) 229 | batch_y = batch_y.float().to(self.device) 230 | batch_x_mark = batch_x_mark.float().to(self.device) 231 | batch_y_mark = batch_y_mark.float().to(self.device) 232 | 233 | inference_steps = self.args.test_pred_len // self.args.token_len 234 | dis = self.args.test_pred_len - inference_steps * self.args.token_len 235 | if dis != 0: 236 | inference_steps += 1 237 | pred_y = [] 238 | for j in range(inference_steps): 239 | if len(pred_y) != 0: 240 | batch_x = torch.cat([batch_x[:, self.args.token_len:, :], pred_y[-1]], dim=1) 241 | tmp = batch_y_mark[:, j-1:j, :] 242 | batch_x_mark = torch.cat([batch_x_mark[:, 1:, :], tmp], dim=1) 243 | 244 | if self.args.use_amp: 245 | with torch.cuda.amp.autocast(): 246 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark) 247 | else: 248 | outputs = self.model(batch_x, batch_x_mark, None, batch_y_mark) 249 | pred_y.append(outputs[:, -self.args.token_len:, :]) 250 | pred_y = torch.cat(pred_y, dim=1) 251 | if dis != 0: 252 | pred_y = pred_y[:, :-(self.args.token_len - dis), :] 253 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device) 254 | outputs = pred_y.detach().cpu() 255 | batch_y = batch_y.detach().cpu() 256 | 257 | pred = outputs 258 | true = batch_y 259 | 260 | preds.append(pred) 261 | trues.append(true) 262 | if (i + 1) % 100 == 0: 263 | if (self.args.use_multi_gpu and self.args.local_rank == 0) or not self.args.use_multi_gpu: 264 | speed = (time.time() - time_now) / iter_count 265 | left_time = speed * (test_steps - i) 266 | print("\titers: {}, speed: {:.4f}s/iter, left time: {:.4f}s".format(i + 1, speed, left_time)) 267 | iter_count = 0 268 | time_now = time.time() 269 | 270 | if self.args.visualize and i == 0: 271 | gt = np.array(true[0, :, -1]) 272 | pd = np.array(pred[0, :, -1]) 273 | lookback = batch_x[0, :, -1].detach().cpu().numpy() 274 | gt = np.concatenate([lookback, gt], axis=0) 275 | pd = np.concatenate([lookback, pd], axis=0) 276 | dir_path = folder_path + f'{self.args.test_pred_len}/' 277 | if not os.path.exists(dir_path): 278 | os.makedirs(dir_path) 279 | visual(gt, pd, os.path.join(dir_path, f'{i}.png')) 280 | 281 | preds = torch.cat(preds, 
dim=0).numpy() 282 | trues = torch.cat(trues, dim=0).numpy() 283 | 284 | mae, mse, rmse, mape, mspe = metric(preds, trues) 285 | print('mse:{}, mae:{}'.format(mse, mae)) 286 | f = open("result_long_term_forecast.txt", 'a') 287 | f.write(setting + " \n") 288 | f.write('mse:{}, mae:{}'.format(mse, mae)) 289 | f.write('\n') 290 | f.write('\n') 291 | f.close() 292 | return 293 | -------------------------------------------------------------------------------- /exp/exp_short_term_forecasting.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_factory import data_provider 2 | from data_provider.m4 import M4Meta 3 | from exp.exp_basic import Exp_Basic 4 | from utils.tools import EarlyStopping, adjust_learning_rate, visual 5 | from utils.losses import mape_loss, mase_loss, smape_loss 6 | from utils.m4_summary import M4Summary 7 | import torch 8 | import torch.nn as nn 9 | from torch import optim 10 | import os 11 | import time 12 | import warnings 13 | import numpy as np 14 | import pandas 15 | 16 | 17 | warnings.filterwarnings('ignore') 18 | 19 | class Exp_Short_Term_Forecast(Exp_Basic): 20 | def __init__(self, args): 21 | super(Exp_Short_Term_Forecast, self).__init__(args) 22 | 23 | def _build_model(self): 24 | if self.args.data == 'm4': 25 | self.args.token_len = M4Meta.horizons_map[self.args.seasonal_patterns] # Up to M4 config 26 | self.args.seq_len = 2 * self.args.token_len # input_len = 2*token_len 27 | self.args.label_len = self.args.token_len 28 | self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns] 29 | self.device = self.args.gpu 30 | model = self.model_dict[self.args.model].Model(self.args).to(self.device) 31 | return model 32 | 33 | def _get_data(self, flag): 34 | data_set, data_loader = data_provider(self.args, flag) 35 | return data_set, data_loader 36 | 37 | def _select_optimizer(self): 38 | p_list = [] 39 | for n, p in self.model.named_parameters(): 40 | if not p.requires_grad: 41 | continue 42 | else: 43 | p_list.append(p) 44 | print(n, p.dtype, p.shape) 45 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay) 46 | print('next learning rate is {}'.format(self.args.learning_rate)) 47 | return model_optim 48 | 49 | def _select_criterion(self, loss_name='MSE'): 50 | if loss_name == 'MSE': 51 | return nn.MSELoss() 52 | elif loss_name == 'MAPE': 53 | return mape_loss() 54 | elif loss_name == 'MASE': 55 | return mase_loss() 56 | elif loss_name == 'SMAPE': 57 | return smape_loss() 58 | 59 | def train(self, setting): 60 | train_data, train_loader = self._get_data(flag='train') 61 | vali_data, vali_loader = self._get_data(flag='val') 62 | 63 | path = os.path.join(self.args.checkpoints, setting) 64 | if not os.path.exists(path): 65 | os.makedirs(path) 66 | 67 | time_now = time.time() 68 | 69 | train_steps = len(train_loader) 70 | early_stopping = EarlyStopping(self.args, verbose=True) 71 | 72 | model_optim = self._select_optimizer() 73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8) 74 | criterion = self._select_criterion(self.args.loss) 75 | if self.args.use_amp: 76 | scaler = torch.cuda.amp.GradScaler() 77 | 78 | for epoch in range(self.args.train_epochs): 79 | iter_count = 0 80 | train_loss = [] 81 | 82 | self.model.train() 83 | epoch_time = time.time() 84 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): 85 | iter_count += 1 86 | model_optim.zero_grad() 
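                # Note: the M4 criteria selected via --loss (mape_loss / mase_loss / smape_loss in
                # utils/losses.py) take the in-sample window and the seasonal frequency in addition
                # to forecast and target, which is why the call below is
                # criterion(batch_x, frequency_map, outputs, batch_y, batch_y_mark); MASE in
                # particular needs the in-sample history for its seasonal-naive scaling term.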
87 | batch_x = batch_x.float().to(self.device) 88 | 89 | batch_y = batch_y.float().to(self.device) 90 | batch_y_mark = batch_y_mark.float().to(self.device) 91 | 92 | if self.args.use_amp: 93 | with torch.cuda.amp.autocast(): 94 | outputs = self.model(batch_x, None, None, None) 95 | else: 96 | outputs = self.model(batch_x, None, None, None) 97 | 98 | loss = criterion(batch_x, self.args.frequency_map, outputs, batch_y, batch_y_mark) 99 | train_loss.append(loss.item()) 100 | 101 | if (i + 1) % 100 == 0: 102 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) 103 | speed = (time.time() - time_now) / iter_count 104 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i) 105 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) 106 | iter_count = 0 107 | time_now = time.time() 108 | 109 | if self.args.use_amp: 110 | scaler.scale(loss).backward() 111 | scaler.step(model_optim) 112 | scaler.update() 113 | else: 114 | loss.backward() 115 | model_optim.step() 116 | 117 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) 118 | train_loss = np.average(train_loss) 119 | vali_loss = self.vali(train_loader, vali_loader, criterion) 120 | test_loss = vali_loss 121 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format( 122 | epoch + 1, train_steps, train_loss, vali_loss, test_loss)) 123 | early_stopping(vali_loss, self.model, path) 124 | if early_stopping.early_stop: 125 | print("Early stopping") 126 | break 127 | if self.args.cosine: 128 | scheduler.step() 129 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr'])) 130 | else: 131 | adjust_learning_rate(model_optim, epoch + 1, self.args) 132 | 133 | best_model_path = path + '/' + f'checkpoint.pth' 134 | self.model.load_state_dict(torch.load(best_model_path), strict=False) 135 | return self.model 136 | 137 | def vali(self, train_loader, vali_loader, criterion): 138 | x, _ = train_loader.dataset.last_insample_window() 139 | y = vali_loader.dataset.timeseries 140 | x = torch.tensor(x, dtype=torch.float32).to(self.device) 141 | x = x.unsqueeze(-1) 142 | 143 | self.model.eval() 144 | with torch.no_grad(): 145 | B, _, C = x.shape 146 | 147 | outputs = torch.zeros((B, self.args.seq_len, C)).float() # .to(self.device) 148 | id_list = np.arange(0, B, 500) # validation set size 149 | id_list = np.append(id_list, B) 150 | if self.args.use_amp: 151 | with torch.cuda.amp.autocast(): 152 | for i in range(len(id_list) - 1): 153 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, 154 | None, 155 | None).detach().cpu() 156 | else: 157 | for i in range(len(id_list) - 1): 158 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, 159 | None, 160 | None).detach().cpu() 161 | pred = outputs[:, -self.args.token_len:, :] 162 | true = torch.from_numpy(np.array(y)) 163 | batch_y_mark = torch.ones(true.shape) 164 | 165 | loss = criterion(x.detach().cpu()[:, :, 0], self.args.frequency_map, pred[:, :, 0], true, batch_y_mark) 166 | 167 | self.model.train() 168 | return loss 169 | 170 | def test(self, setting, test=0): 171 | _, train_loader = self._get_data(flag='train') 172 | _, test_loader = self._get_data(flag='test') 173 | x, _ = train_loader.dataset.last_insample_window() 174 | y = test_loader.dataset.timeseries 175 | x = torch.tensor(x, dtype=torch.float32).to(self.device) 176 | x = x.unsqueeze(-1) 177 | 178 | if test: 179 | print('loading 
model') 180 | setting = self.args.test_dir 181 | best_model_path = self.args.test_file_name 182 | 183 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path))) 184 | load_item = torch.load(os.path.join(self.args.checkpoints, setting, best_model_path)) 185 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in load_item.items()}, strict=False) 186 | 187 | folder_path = './test_results/' + setting + '/' 188 | if not os.path.exists(folder_path): 189 | os.makedirs(folder_path) 190 | 191 | self.model.eval() 192 | with torch.no_grad(): 193 | B, _, C = x.shape 194 | 195 | outputs = torch.zeros((B, self.args.seq_len, C)).float().to(self.device) 196 | id_list = np.arange(0, B, 1) 197 | id_list = np.append(id_list, B) 198 | if self.args.use_amp: 199 | with torch.cuda.amp.autocast(): 200 | for i in range(len(id_list) - 1): 201 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, 202 | None, None) 203 | else: 204 | for i in range(len(id_list) - 1): 205 | outputs[id_list[i]:id_list[i + 1], :, :] = self.model(x[id_list[i]:id_list[i + 1]], None, 206 | None, None) 207 | outputs = outputs[:, -self.args.token_len:, :] 208 | preds = outputs.detach().cpu().numpy() 209 | trues = y 210 | x = x.detach().cpu().numpy() 211 | 212 | if self.args.visualize and i % 2 == 0: 213 | gt = np.concatenate((x[i, :, 0], trues[i]), axis=0) 214 | pd = np.concatenate((x[i, :, 0], preds[i, :, 0]), axis=0) 215 | visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf')) 216 | 217 | print('test shape:', preds.shape) 218 | 219 | # result save 220 | folder_path = './m4_results/' + self.args.model + '/' 221 | if not os.path.exists(folder_path): 222 | os.makedirs(folder_path) 223 | 224 | forecasts_df = pandas.DataFrame(preds[:, :, 0], columns=[f'V{i + 1}' for i in range(self.args.token_len)]) 225 | forecasts_df.index = test_loader.dataset.ids[:preds.shape[0]] 226 | forecasts_df.index.name = 'id' 227 | forecasts_df.set_index(forecasts_df.columns[0], inplace=True) 228 | forecasts_df.to_csv(folder_path + self.args.seasonal_patterns + '_forecast.csv') 229 | 230 | print(self.args.model) 231 | file_path = './m4_results/' + self.args.model + '/' 232 | if 'Weekly_forecast.csv' in os.listdir(file_path) \ 233 | and 'Monthly_forecast.csv' in os.listdir(file_path) \ 234 | and 'Yearly_forecast.csv' in os.listdir(file_path) \ 235 | and 'Daily_forecast.csv' in os.listdir(file_path) \ 236 | and 'Hourly_forecast.csv' in os.listdir(file_path) \ 237 | and 'Quarterly_forecast.csv' in os.listdir(file_path): 238 | m4_summary = M4Summary(file_path, self.args.root_path) 239 | # m4_forecast.set_index(m4_winner_forecast.columns[0], inplace=True) 240 | smape_results, owa_results, mape, mase = m4_summary.evaluate() 241 | print('smape:', smape_results) 242 | print('mape:', mape) 243 | print('mase:', mase) 244 | print('owa:', owa_results) 245 | else: 246 | print('After all 6 tasks are finished, you can calculate the averaged index') 247 | return 248 | -------------------------------------------------------------------------------- /exp/exp_zero_shot_forecasting.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_factory import data_provider 2 | from data_provider.m4 import M4Meta 3 | from exp.exp_basic import Exp_Basic 4 | from utils.tools import EarlyStopping, adjust_learning_rate 5 | from utils.losses import zero_shot_smape_loss 6 | import torch 7 | import torch.nn as nn 8 | from torch import optim 9 | 
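# Zero-shot setup: the model is fitted on the source dataset (--data_path) and evaluated on the
# target dataset swapped in via --test_data_path (see train() and test() below). Since the
# forecaster emits one token of token_len points per step, longer horizons are generated
# autoregressively in vali2()/test(): the lookback window is rolled forward with each newly
# predicted token until test_pred_len points are covered, and any overshoot is trimmed.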
import os 10 | import time 11 | import warnings 12 | import numpy as np 13 | 14 | 15 | warnings.filterwarnings('ignore') 16 | 17 | def SMAPE(pred, true): 18 | return np.mean(200 * np.abs(pred - true) / (np.abs(pred) + np.abs(true) + 1e-8)) 19 | def MAPE(pred, true): 20 | return np.mean(np.abs(100 * (pred - true) / (true +1e-8))) 21 | 22 | class Exp_Zero_Shot_Forecast(Exp_Basic): 23 | def __init__(self, args): 24 | super(Exp_Zero_Shot_Forecast, self).__init__(args) 25 | 26 | def _build_model(self): 27 | if self.args.data == 'tsf': 28 | self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns] 29 | self.device = self.args.gpu 30 | model = self.model_dict[self.args.model].Model(self.args).to(self.device) 31 | return model 32 | 33 | def _get_data(self, flag): 34 | data_set, data_loader = data_provider(self.args, flag) 35 | return data_set, data_loader 36 | 37 | def _select_optimizer(self): 38 | p_list = [] 39 | for n, p in self.model.named_parameters(): 40 | if not p.requires_grad: 41 | continue 42 | else: 43 | p_list.append(p) 44 | print(n, p.dtype, p.shape) 45 | model_optim = optim.Adam([{'params': p_list}], lr=self.args.learning_rate, weight_decay=self.args.weight_decay) 46 | print('next learning rate is {}'.format(self.args.learning_rate)) 47 | return model_optim 48 | 49 | def _select_criterion(self, loss_name='MSE'): 50 | if loss_name == 'MSE': 51 | return nn.MSELoss() 52 | elif loss_name == 'SMAPE': 53 | return zero_shot_smape_loss() 54 | 55 | def train(self, setting): 56 | train_data, train_loader = self._get_data(flag='train') 57 | vali_data, vali_loader = self._get_data(flag='val') 58 | test_data, test_loader = self._get_data(flag='test') 59 | 60 | self.args.data_path = self.args.test_data_path 61 | test_data2, test_loader2 = self._get_data(flag='test') 62 | 63 | path = os.path.join(self.args.checkpoints, setting) 64 | if not os.path.exists(path): 65 | os.makedirs(path) 66 | 67 | time_now = time.time() 68 | 69 | train_steps = len(train_loader) 70 | early_stopping = EarlyStopping(self.args, verbose=True) 71 | 72 | model_optim = self._select_optimizer() 73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, T_max=self.args.tmax, eta_min=1e-8) 74 | criterion = self._select_criterion(self.args.loss) 75 | if self.args.use_amp: 76 | scaler = torch.cuda.amp.GradScaler() 77 | 78 | for epoch in range(self.args.train_epochs): 79 | iter_count = 0 80 | train_loss = [] 81 | 82 | self.model.train() 83 | epoch_time = time.time() 84 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): 85 | iter_count += 1 86 | model_optim.zero_grad() 87 | batch_x = batch_x.float().to(self.device) 88 | batch_y = batch_y.float().to(self.device) 89 | 90 | if self.args.use_amp: 91 | with torch.cuda.amp.autocast(): 92 | outputs = self.model(batch_x, None, None, None) 93 | else: 94 | outputs = self.model(batch_x, None, None, None) 95 | 96 | loss = criterion(outputs, batch_y) 97 | train_loss.append(loss.item()) 98 | 99 | if (i + 1) % 100 == 0: 100 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) 101 | speed = (time.time() - time_now) / iter_count 102 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i) 103 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) 104 | iter_count = 0 105 | time_now = time.time() 106 | 107 | if self.args.use_amp: 108 | scaler.scale(loss).backward() 109 | scaler.step(model_optim) 110 | scaler.update() 111 | else: 112 | loss.backward() 113 | 
model_optim.step() 114 | 115 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) 116 | train_loss = np.average(train_loss) 117 | vali_loss = self.vali(vali_data, vali_loader, criterion) 118 | test_loss = self.vali2(test_data, test_loader, criterion) # test_loss indicates the result on the source datasets, 119 | test_loss2 = self.vali2(test_data2, test_loader2, criterion) # test_loss2 indicates the result on the taregt datasets. The latter is what we concerned in zero-shot forecasting​. 120 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f} Test Loss2: {5:.7f}".format( 121 | epoch + 1, train_steps, train_loss, vali_loss, test_loss, test_loss2)) 122 | early_stopping(vali_loss, self.model, path) 123 | if early_stopping.early_stop: 124 | print("Early stopping") 125 | break 126 | if self.args.cosine: 127 | scheduler.step() 128 | print("lr = {:.10f}".format(model_optim.param_groups[0]['lr'])) 129 | else: 130 | adjust_learning_rate(model_optim, epoch + 1, self.args) 131 | 132 | best_model_path = path + '/' + f'checkpoint.pth' 133 | self.model.load_state_dict(torch.load(best_model_path), strict=False) 134 | 135 | return self.model 136 | 137 | def vali(self, vali_data, vali_loader, criterion): 138 | total_loss = [] 139 | self.model.eval() 140 | with torch.no_grad(): 141 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader): 142 | batch_x = batch_x.float().to(self.device) 143 | batch_y = batch_y.float() 144 | 145 | if self.args.use_amp: 146 | with torch.cuda.amp.autocast(): 147 | outputs = self.model(batch_x, None, None, None) 148 | else: 149 | outputs = self.model(batch_x, None, None, None) 150 | 151 | pred = outputs.detach().cpu() 152 | true = batch_y.detach().cpu() 153 | 154 | loss = criterion(pred, true) 155 | total_loss.append(loss) 156 | total_loss = np.average(total_loss) 157 | self.model.train() 158 | 159 | return total_loss 160 | 161 | def vali2(self, vali_data, vali_loader, criterion): 162 | total_loss = [] 163 | count= [] 164 | self.model.eval() 165 | with torch.no_grad(): 166 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader): 167 | batch_x = batch_x.float().to(self.device) 168 | batch_y = batch_y.float() 169 | 170 | inference_steps = self.args.test_pred_len // self.args.token_len 171 | dis = self.args.test_pred_len - inference_steps * self.args.token_len 172 | if dis != 0: 173 | inference_steps += 1 174 | pred_y = [] 175 | 176 | for j in range(inference_steps): 177 | if len(pred_y) != 0: 178 | batch_x = torch.cat([batch_x[:, self.args.token_len:, :], pred_y[-1]], dim=1) 179 | if self.args.use_amp: 180 | with torch.cuda.amp.autocast(): 181 | outputs = self.model(batch_x, None, None, None) 182 | else: 183 | outputs = self.model(batch_x, None, None, None) 184 | pred_y.append(outputs[:, -self.args.token_len:, :]) 185 | pred_y = torch.cat(pred_y, dim=1) 186 | if dis != 0: 187 | pred_y = pred_y[:, :-dis, :] 188 | 189 | outputs = pred_y 190 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device) 191 | 192 | pred = outputs.detach().cpu() 193 | true = batch_y.detach().cpu() 194 | 195 | loss = criterion(pred, true) 196 | total_loss.append(loss) 197 | count.append(batch_x.shape[0]) 198 | total_loss = np.average(total_loss, weights=count) 199 | self.model.train() 200 | 201 | return total_loss 202 | 203 | def test(self, setting, test=0): 204 | if test: 205 | print('loading model') 206 | setting = self.args.test_dir 207 | best_model_path = 
self.args.test_file_name 208 | print("loading model from {}".format(os.path.join(self.args.checkpoints, setting, best_model_path))) 209 | self.model.load_state_dict(torch.load(os.path.join(self.args.checkpoints, setting, best_model_path)), strict=False) 210 | 211 | self.args.data_path = self.args.test_data_path 212 | test_data, test_loader = self._get_data('test') 213 | 214 | preds = [] 215 | trues = [] 216 | 217 | self.model.eval() 218 | with torch.no_grad(): 219 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader): 220 | batch_x = batch_x.float().to(self.device) 221 | batch_y = batch_y.float() 222 | 223 | inference_steps = self.args.test_pred_len // self.args.token_len 224 | dis = self.args.test_pred_len - inference_steps * self.args.token_len 225 | if dis != 0: 226 | inference_steps += 1 227 | pred_y = [] 228 | 229 | for j in range(inference_steps): 230 | if len(pred_y) != 0: 231 | batch_x = torch.cat([batch_x[:, self.args.token_len:, :], pred_y[-1]], dim=1) 232 | if self.args.use_amp: 233 | with torch.cuda.amp.autocast(): 234 | outputs = self.model(batch_x, None, None, None) 235 | else: 236 | outputs = self.model(batch_x, None, None, None) 237 | pred_y.append(outputs[:, -self.args.token_len:, :]) 238 | pred_y = torch.cat(pred_y, dim=1) 239 | if dis != 0: 240 | pred_y = pred_y[:, :-dis, :] 241 | 242 | outputs = pred_y 243 | batch_y = batch_y[:, -self.args.test_pred_len:, :].to(self.device) 244 | 245 | pred = outputs.detach().cpu().numpy() 246 | true = batch_y.detach().cpu().numpy() 247 | 248 | preds.append(pred) 249 | trues.append(true) 250 | 251 | preds = np.concatenate(preds, axis=0) 252 | trues = np.concatenate(trues, axis=0) 253 | print('test shape:', preds.shape, trues.shape) 254 | 255 | smape = SMAPE(preds, trues) 256 | mape = MAPE(preds, trues) 257 | print('mape:{:4f}, smape:{:.4f}'.format(mape, smape)) 258 | -------------------------------------------------------------------------------- /figures/ablation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/ablation.png -------------------------------------------------------------------------------- /figures/ablation_llm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/ablation_llm.png -------------------------------------------------------------------------------- /figures/adaption_efficiency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/adaption_efficiency.png -------------------------------------------------------------------------------- /figures/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/comparison.png -------------------------------------------------------------------------------- /figures/formulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/formulation.png -------------------------------------------------------------------------------- /figures/icon.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/icon.png -------------------------------------------------------------------------------- /figures/illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/illustration.png -------------------------------------------------------------------------------- /figures/in-context.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/in-context.png -------------------------------------------------------------------------------- /figures/lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/lora.png -------------------------------------------------------------------------------- /figures/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/method.png -------------------------------------------------------------------------------- /figures/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/motivation.png -------------------------------------------------------------------------------- /figures/one-for-all_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/one-for-all_results.png -------------------------------------------------------------------------------- /figures/param.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/param.png -------------------------------------------------------------------------------- /figures/showcases.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/showcases.png -------------------------------------------------------------------------------- /figures/subway_icf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/subway_icf.png -------------------------------------------------------------------------------- /figures/zeroshot_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/figures/zeroshot_results.png -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/layers/__init__.py 
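The `layers/mlp.py` module below is the segment tokenizer/detokenizer trained on top of the frozen LLM (a plain `nn.Linear` pair is used instead when `--mlp_hidden_layers` is 0). A minimal sketch of how the models in `models/` use it, following the shape comments in `models/AutoTimes_Llama.py`; the 4096 hidden size is LLaMA-7B's, while the batch and variable counts here are illustrative assumptions, not values from the repo:

```python
import torch
from layers.mlp import MLP

token_len, llm_hidden = 96, 4096   # segment length and LLaMA hidden size
encoder = MLP(f_in=token_len, f_out=llm_hidden,
              hidden_dim=256, hidden_layers=2, dropout=0.1, activation='tanh')
decoder = MLP(f_in=llm_hidden, f_out=token_len,
              hidden_dim=256, hidden_layers=2, dropout=0.1, activation='tanh')

x = torch.randn(8, 672)                         # [bs * n_vars, seq_len] (illustrative sizes)
segments = x.unfold(-1, token_len, token_len)   # [8, 7, 96]: non-overlapping tokens
embeds = encoder(segments)                      # [8, 7, 4096]: inputs_embeds for the frozen LLM
decoded = decoder(embeds)                       # [8, 7, 96]: in the real model the decoder is
                                                # applied to the LLM's output hidden states
```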
-------------------------------------------------------------------------------- /layers/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class MLP(nn.Module): 4 | ''' 5 | Multilayer perceptron to encode/decode high dimension representation of sequential data 6 | ''' 7 | def __init__(self, 8 | f_in, 9 | f_out, 10 | hidden_dim=256, 11 | hidden_layers=2, 12 | dropout=0.1, 13 | activation='tanh'): 14 | super(MLP, self).__init__() 15 | self.f_in = f_in 16 | self.f_out = f_out 17 | self.hidden_dim = hidden_dim 18 | self.hidden_layers = hidden_layers 19 | self.dropout = dropout 20 | if activation == 'relu': 21 | self.activation = nn.ReLU() 22 | elif activation == 'tanh': 23 | self.activation = nn.Tanh() 24 | elif activation == 'gelu': 25 | self.activation = nn.GELU() 26 | else: 27 | raise NotImplementedError 28 | 29 | layers = [nn.Linear(self.f_in, self.hidden_dim), 30 | self.activation, nn.Dropout(self.dropout)] 31 | for i in range(self.hidden_layers-2): 32 | layers += [nn.Linear(self.hidden_dim, self.hidden_dim), 33 | self.activation, nn.Dropout(dropout)] 34 | 35 | layers += [nn.Linear(hidden_dim, f_out)] 36 | self.layers = nn.Sequential(*layers) 37 | 38 | def forward(self, x): 39 | # x: B x S x f_in 40 | # y: B x S x f_out 41 | y = self.layers(x) 42 | return y 43 | -------------------------------------------------------------------------------- /models/AutoTimes_Gpt2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers.models.gpt2.modeling_gpt2 import GPT2Model 4 | from layers.mlp import MLP 5 | 6 | class Model(nn.Module): 7 | def __init__(self, configs): 8 | super(Model, self).__init__() 9 | self.token_len = configs.token_len 10 | if configs.use_multi_gpu: 11 | self.device = f"cuda:{configs.local_rank}" 12 | else: 13 | self.device = f"cuda:{configs.gpu}" 14 | print(self.device) 15 | 16 | self.gpt2 = GPT2Model.from_pretrained(configs.llm_ckp_dir) 17 | self.hidden_dim_of_gpt2 = 768 18 | self.mix = configs.mix_embeds 19 | 20 | if self.mix: 21 | self.add_scale = nn.Parameter(torch.ones([])) 22 | 23 | for name, param in self.gpt2.named_parameters(): 24 | param.requires_grad = False 25 | 26 | if configs.mlp_hidden_layers == 0: 27 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0): 28 | print("use linear as tokenizer and detokenizer") 29 | self.encoder = nn.Linear(self.token_len, self.hidden_dim_of_gpt2) 30 | self.decoder = nn.Linear(self.hidden_dim_of_gpt2, self.token_len) 31 | else: 32 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0): 33 | print("use mlp as tokenizer and detokenizer") 34 | self.encoder = MLP(self.token_len, self.hidden_dim_of_gpt2, 35 | configs.mlp_hidden_dim, configs.mlp_hidden_layers, 36 | configs.dropout, configs.mlp_activation) 37 | self.decoder = MLP(self.hidden_dim_of_gpt2, self.token_len, 38 | configs.mlp_hidden_dim, configs.mlp_hidden_layers, 39 | configs.dropout, configs.mlp_activation) 40 | 41 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 42 | means = x_enc.mean(1, keepdim=True).detach() 43 | x_enc = x_enc - means 44 | stdev = torch.sqrt( 45 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 46 | x_enc /= stdev 47 | 48 | bs, _, n_vars = x_enc.shape 49 | # x_enc: [bs x nvars x seq_len] 50 | x_enc = x_enc.permute(0, 2, 1) 51 | # x_enc: [bs * nvars x seq_len] 52 | x_enc = x_enc.reshape(x_enc.shape[0] * 
x_enc.shape[1], -1) 53 | # fold_out: [bs * n_vars x token_num x token_len] 54 | fold_out = x_enc.unfold(dimension=-1, size=self.token_len, step=self.token_len) 55 | token_num = fold_out.shape[1] 56 | # times_embeds: [bs * n_vars x token_num x hidden_dim_of_gpt2] 57 | times_embeds = self.encoder(fold_out) 58 | if self.mix: 59 | times_embeds = times_embeds / times_embeds.norm(dim=2, keepdim=True) 60 | x_mark_enc = x_mark_enc / x_mark_enc.norm(dim=2, keepdim=True) 61 | times_embeds = times_embeds + self.add_scale * x_mark_enc 62 | # outputs: [bs * n_vars x token_num x hidden_dim_of_gpt2] 63 | outputs = self.gpt2( 64 | inputs_embeds=times_embeds).last_hidden_state 65 | # dec_out: [bs * n_vars x token_num x token_len] 66 | dec_out = self.decoder(outputs) 67 | dec_out = dec_out.reshape(bs, n_vars, -1) 68 | # dec_out: [bs x token_num * token_len x n_vars] 69 | dec_out = dec_out.permute(0, 2, 1) 70 | 71 | dec_out = dec_out * \ 72 | (stdev[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1)) 73 | dec_out = dec_out + \ 74 | (means[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1)) 75 | 76 | return dec_out 77 | 78 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 79 | return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) -------------------------------------------------------------------------------- /models/AutoTimes_Llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import LlamaForCausalLM 4 | from layers.mlp import MLP 5 | 6 | class Model(nn.Module): 7 | def __init__(self, configs): 8 | super(Model, self).__init__() 9 | self.token_len = configs.token_len 10 | if configs.use_multi_gpu: 11 | self.device = f"cuda:{configs.local_rank}" 12 | else: 13 | self.device = f"cuda:{configs.gpu}" 14 | print(self.device) 15 | 16 | self.llama = LlamaForCausalLM.from_pretrained( 17 | configs.llm_ckp_dir, 18 | device_map=self.device, 19 | torch_dtype=torch.float16 if configs.use_amp else torch.float32, 20 | ) 21 | self.hidden_dim_of_llama = 4096 22 | self.mix = configs.mix_embeds 23 | if self.mix: 24 | self.add_scale = nn.Parameter(torch.ones([])) 25 | 26 | for name, param in self.llama.named_parameters(): 27 | param.requires_grad = False 28 | 29 | if configs.mlp_hidden_layers == 0: 30 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0): 31 | print("use linear as tokenizer and detokenizer") 32 | self.encoder = nn.Linear(self.token_len, self.hidden_dim_of_llama) 33 | self.decoder = nn.Linear(self.hidden_dim_of_llama, self.token_len) 34 | else: 35 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0): 36 | print("use mlp as tokenizer and detokenizer") 37 | self.encoder = MLP(self.token_len, self.hidden_dim_of_llama, 38 | configs.mlp_hidden_dim, configs.mlp_hidden_layers, 39 | configs.dropout, configs.mlp_activation) 40 | self.decoder = MLP(self.hidden_dim_of_llama, self.token_len, 41 | configs.mlp_hidden_dim, configs.mlp_hidden_layers, 42 | configs.dropout, configs.mlp_activation) 43 | 44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 45 | means = x_enc.mean(1, keepdim=True).detach() 46 | x_enc = x_enc - means 47 | stdev = torch.sqrt( 48 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 49 | x_enc /= stdev 50 | 51 | bs, _, n_vars = x_enc.shape 52 | # x_enc: [bs x nvars x seq_len] 53 | x_enc = x_enc.permute(0, 2, 1) 54 | # x_enc: [bs * nvars x seq_len] 55 | x_enc = 
x_enc.reshape(x_enc.shape[0] * x_enc.shape[1], -1) 56 | # fold_out: [bs * n_vars x token_num x token_len] 57 | fold_out = x_enc.unfold(dimension=-1, size=self.token_len, step=self.token_len) 58 | token_num = fold_out.shape[1] 59 | # times_embeds: [bs * n_vars x token_num x hidden_dim_of_llama] 60 | times_embeds = self.encoder(fold_out) 61 | if self.mix: 62 | times_embeds = times_embeds / times_embeds.norm(dim=2, keepdim=True) 63 | x_mark_enc = x_mark_enc / x_mark_enc.norm(dim=2, keepdim=True) 64 | times_embeds = times_embeds + self.add_scale * x_mark_enc 65 | # outputs: [bs * n_vars x token_num x hidden_dim_of_llama] 66 | outputs = self.llama.model( 67 | inputs_embeds=times_embeds)[0] 68 | # dec_out: [bs * n_vars x token_num x token_len] 69 | dec_out = self.decoder(outputs) 70 | dec_out = dec_out.reshape(bs, n_vars, -1) 71 | # dec_out: [bs x token_num * token_len x n_vars] 72 | dec_out = dec_out.permute(0, 2, 1) 73 | 74 | dec_out = dec_out * \ 75 | (stdev[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1)) 76 | dec_out = dec_out + \ 77 | (means[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1)) 78 | 79 | return dec_out 80 | 81 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 82 | return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) -------------------------------------------------------------------------------- /models/AutoTimes_Opt_1b.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import OPTForCausalLM 4 | from layers.mlp import MLP 5 | 6 | class Model(nn.Module): 7 | def __init__(self, configs): 8 | super(Model, self).__init__() 9 | self.token_len = configs.token_len 10 | if configs.use_multi_gpu: 11 | self.device = f"cuda:{configs.local_rank}" 12 | else: 13 | self.device = f"cuda:{configs.gpu}" 14 | print(self.device) 15 | 16 | self.opt = OPTForCausalLM.from_pretrained(configs.llm_ckp_dir, torch_dtype=torch.float16) 17 | self.opt.model.decoder.project_in = None 18 | self.opt.model.decoder.project_out = None 19 | 20 | self.hidden_dim_of_opt1b = 2048 21 | self.mix = configs.mix_embeds 22 | 23 | if self.mix: 24 | self.add_scale = nn.Parameter(torch.ones([])) 25 | 26 | for name, param in self.opt.named_parameters(): 27 | param.requires_grad = False 28 | 29 | if configs.mlp_hidden_layers == 0: 30 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0): 31 | print("use linear as tokenizer and detokenizer") 32 | self.encoder = nn.Linear(self.token_len, self.hidden_dim_of_opt1b) 33 | self.decoder = nn.Linear(self.hidden_dim_of_opt1b, self.token_len) 34 | else: 35 | if not configs.use_multi_gpu or (configs.use_multi_gpu and configs.local_rank == 0): 36 | print("use mlp as tokenizer and detokenizer") 37 | self.encoder = MLP(self.token_len, self.hidden_dim_of_opt1b, 38 | configs.mlp_hidden_dim, configs.mlp_hidden_layers, 39 | configs.dropout, configs.mlp_activation) 40 | self.decoder = MLP(self.hidden_dim_of_opt1b, self.token_len, 41 | configs.mlp_hidden_dim, configs.mlp_hidden_layers, 42 | configs.dropout, configs.mlp_activation) 43 | 44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 45 | means = x_enc.mean(1, keepdim=True).detach() 46 | x_enc = x_enc - means 47 | stdev = torch.sqrt( 48 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 49 | x_enc /= stdev 50 | 51 | bs, _, n_vars = x_enc.shape 52 | # x_enc: [bs x nvars x seq_len] 53 | x_enc = x_enc.permute(0, 2, 1) 54 | # x_enc: [bs * nvars x 
seq_len] 55 | x_enc = x_enc.reshape(x_enc.shape[0] * x_enc.shape[1], -1) 56 | # fold_out: [bs * n_vars x token_num x token_len] 57 | fold_out = x_enc.unfold(dimension=-1, size=self.token_len, step=self.token_len) 58 | token_num = fold_out.shape[1] 59 | # times_embeds: [bs * n_vars x token_num x hidden_dim_of_opt1b] 60 | times_embeds = self.encoder(fold_out) 61 | if self.mix: 62 | times_embeds = times_embeds / times_embeds.norm(dim=2, keepdim=True) 63 | x_mark_enc = x_mark_enc / x_mark_enc.norm(dim=2, keepdim=True) 64 | times_embeds = times_embeds + self.add_scale * x_mark_enc 65 | # outputs: [bs * n_vars x token_num x hidden_dim_of_opt1b] 66 | outputs = self.opt.model( 67 | inputs_embeds=times_embeds).last_hidden_state 68 | # dec_out: [bs * n_vars x token_num x token_len] 69 | dec_out = self.decoder(outputs) 70 | dec_out = dec_out.reshape(bs, n_vars, -1) 71 | # dec_out: [bs x token_num * token_len x n_vars] 72 | dec_out = dec_out.permute(0, 2, 1) 73 | 74 | dec_out = dec_out * \ 75 | (stdev[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1)) 76 | dec_out = dec_out + \ 77 | (means[:, 0, :].unsqueeze(1).repeat(1, token_num * self.token_len, 1)) 78 | 79 | return dec_out 80 | 81 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 82 | return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) -------------------------------------------------------------------------------- /models/Preprocess_Llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import ( 4 | LlamaForCausalLM, 5 | LlamaTokenizer, 6 | ) 7 | 8 | class Model(nn.Module): 9 | def __init__(self, configs): 10 | super(Model, self).__init__() 11 | self.device = configs.gpu 12 | print(self.device) 13 | 14 | self.llama = LlamaForCausalLM.from_pretrained( 15 | configs.llm_ckp_dir, 16 | device_map=self.device, 17 | torch_dtype=torch.float16, 18 | ) 19 | self.llama_tokenizer = LlamaTokenizer.from_pretrained(configs.llm_ckp_dir) 20 | self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token 21 | self.vocab_size = self.llama_tokenizer.vocab_size 22 | self.hidden_dim_of_llama = 4096 23 | 24 | for name, param in self.llama.named_parameters(): 25 | param.requires_grad = False 26 | 27 | def tokenizer(self, x): 28 | output = self.llama_tokenizer(x, return_tensors="pt")['input_ids'].to(self.device) 29 | result = self.llama.get_input_embeddings()(output) 30 | return result 31 | 32 | def forecast(self, x_mark_enc): 33 | # x_mark_enc: [bs x T x hidden_dim_of_llama] 34 | x_mark_enc = torch.cat([self.tokenizer(x_mark_enc[i]) for i in range(len(x_mark_enc))], 0) 35 | text_outputs = self.llama.model(inputs_embeds=x_mark_enc)[0] 36 | text_outputs = text_outputs[:, -1, :] 37 | return text_outputs 38 | 39 | def forward(self, x_mark_enc): 40 | return self.forecast(x_mark_enc) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/models/__init__.py -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from models.Preprocess_Llama import Model 4 | 5 | from data_provider.data_loader import Dataset_Preprocess 6 | from torch.utils.data import DataLoader 7 | 8 | if 
__name__ == '__main__': 9 | parser = argparse.ArgumentParser(description='AutoTimes Preprocess') 10 | parser.add_argument('--gpu', type=int, default=0, help='gpu id') 11 | parser.add_argument('--llm_ckp_dir', type=str, default='./llama', help='llm checkpoints dir') 12 | parser.add_argument('--dataset', type=str, default='ETTh1', 13 | help='dataset to preprocess, options:[ETTh1, electricity, weather, traffic]') 14 | args = parser.parse_args() 15 | print(args.dataset) 16 | 17 | model = Model(args) 18 | 19 | seq_len = 672 20 | label_len = 576 21 | pred_len = 96 22 | 23 | assert args.dataset in ['ETTh1', 'electricity', 'weather', 'traffic'] 24 | if args.dataset == 'ETTh1': 25 | data_set = Dataset_Preprocess( 26 | root_path='./dataset/ETT-small/', 27 | data_path='ETTh1.csv', 28 | size=[seq_len, label_len, pred_len]) 29 | elif args.dataset == 'electricity': 30 | data_set = Dataset_Preprocess( 31 | root_path='./dataset/electricity/', 32 | data_path='electricity.csv', 33 | size=[seq_len, label_len, pred_len]) 34 | elif args.dataset == 'weather': 35 | data_set = Dataset_Preprocess( 36 | root_path='./dataset/weather/', 37 | data_path='weather.csv', 38 | size=[seq_len, label_len, pred_len]) 39 | elif args.dataset == 'traffic': 40 | data_set = Dataset_Preprocess( 41 | root_path='./dataset/traffic/', 42 | data_path='traffic.csv', 43 | size=[seq_len, label_len, pred_len]) 44 | 45 | data_loader = DataLoader( 46 | data_set, 47 | batch_size=128, 48 | shuffle=False, 49 | ) 50 | 51 | from tqdm import tqdm 52 | print(len(data_set.data_stamp)) 53 | print(data_set.tot_len) 54 | save_dir_path = './dataset/' 55 | output_list = [] 56 | for idx, data in tqdm(enumerate(data_loader)): 57 | output = model(data) 58 | output_list.append(output.detach().cpu()) 59 | result = torch.cat(output_list, dim=0) 60 | print(result.shape) 61 | torch.save(result, save_dir_path + f'/{args.dataset}.pt') 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.0 2 | matplotlib==3.7.0 3 | numpy==1.23.5 4 | pandas==1.5.3 5 | patool==1.12 6 | reformer-pytorch==1.4.4 7 | scikit-learn==1.2.2 8 | scipy==1.10.1 9 | sktime==0.16.1 10 | sympy==1.11.1 11 | tqdm==4.64.1 12 | torch==2.0.1 13 | transformers==4.35.2 14 | accelerate 15 | sentencepiece 16 | protobuf -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast 8 | from exp.exp_short_term_forecasting import Exp_Short_Term_Forecast 9 | from exp.exp_zero_shot_forecasting import Exp_Zero_Shot_Forecast 10 | from exp.exp_in_context_forecasting import Exp_In_Context_Forecast 11 | 12 | if __name__ == '__main__': 13 | fix_seed = 2021 14 | random.seed(fix_seed) 15 | torch.manual_seed(fix_seed) 16 | np.random.seed(fix_seed) 17 | 18 | parser = argparse.ArgumentParser(description='AutoTimes') 19 | 20 | # basic config 21 | parser.add_argument('--task_name', type=str, required=True, default='long_term_forecast', 22 | help='task name, options:[long_term_forecast, short_term_forecast, zero_shot_forecasting, in_context_forecasting]') 23 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status') 24 | 
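    # The Exp dispatch at the bottom of this file matches the exact strings 'long_term_forecast',
    # 'short_term_forecast', 'zero_shot_forecast' and 'in_context_forecast' (e.g.
    # scripts/in_context_forecasting/M3.sh passes in_context_forecast); any other task name
    # falls back to Exp_Long_Term_Forecast.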
parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') 25 | parser.add_argument('--model', type=str, required=True, default='AutoTimes_Llama', 26 | help='model name, options: [AutoTimes_Llama, AutoTimes_Gpt2, AutoTimes_Opt1b]') 27 | 28 | # data loader 29 | parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type') 30 | parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file') 31 | parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file') 32 | parser.add_argument('--test_data_path', type=str, default='ETTh1.csv', help='test data file used in zero shot forecasting') 33 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints') 34 | parser.add_argument('--drop_last', action='store_true', default=False, help='drop last batch in data loader') 35 | parser.add_argument('--val_set_shuffle', action='store_false', default=True, help='shuffle validation set') 36 | parser.add_argument('--drop_short', action='store_true', default=False, help='drop too short sequences in dataset') 37 | 38 | # forecasting task 39 | parser.add_argument('--seq_len', type=int, default=672, help='input sequence length') 40 | parser.add_argument('--label_len', type=int, default=576, help='label length') 41 | parser.add_argument('--token_len', type=int, default=96, help='token length') 42 | parser.add_argument('--test_seq_len', type=int, default=672, help='test seq len') 43 | parser.add_argument('--test_label_len', type=int, default=576, help='test label len') 44 | parser.add_argument('--test_pred_len', type=int, default=96, help='test pred len') 45 | parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4') 46 | 47 | # model define 48 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout') 49 | parser.add_argument('--llm_ckp_dir', type=str, default='./llama', help='llm checkpoints dir') 50 | parser.add_argument('--mlp_hidden_dim', type=int, default=256, help='mlp hidden dim') 51 | parser.add_argument('--mlp_hidden_layers', type=int, default=2, help='mlp hidden layers') 52 | parser.add_argument('--mlp_activation', type=str, default='tanh', help='mlp activation') 53 | 54 | # optimization 55 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') 56 | parser.add_argument('--itr', type=int, default=1, help='experiments times') 57 | parser.add_argument('--train_epochs', type=int, default=10, help='train epochs') 58 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data') 59 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience') 60 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate') 61 | parser.add_argument('--des', type=str, default='test', help='exp description') 62 | parser.add_argument('--loss', type=str, default='MSE', help='loss function') 63 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate') 64 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False) 65 | parser.add_argument('--cosine', action='store_true', help='use cosine annealing lr', default=False) 66 | parser.add_argument('--tmax', type=int, default=10, help='tmax in cosine anealing lr') 67 | parser.add_argument('--weight_decay', type=float, default=0) 68 | 
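    # --mix_embeds adds the pre-computed textual timestamp embeddings (generated by preprocess.py
    # and passed to the model as x_mark_enc) onto the segment embeddings via the learnable
    # add_scale parameter in models/AutoTimes_*.py; both are L2-normalized along the hidden
    # dimension before mixing.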
parser.add_argument('--mix_embeds', action='store_true', help='mix embeds', default=False) 69 | parser.add_argument('--test_dir', type=str, default='./test', help='test dir') 70 | parser.add_argument('--test_file_name', type=str, default='checkpoint.pth', help='test file') 71 | 72 | # GPU 73 | parser.add_argument('--gpu', type=int, default=0, help='gpu') 74 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False) 75 | parser.add_argument('--visualize', action='store_true', help='visualize', default=False) 76 | args = parser.parse_args() 77 | 78 | if args.use_multi_gpu: 79 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1") 80 | port = os.environ.get("MASTER_PORT", "64209") 81 | hosts = int(os.environ.get("WORLD_SIZE", "8")) 82 | rank = int(os.environ.get("RANK", "0")) 83 | local_rank = int(os.environ.get("LOCAL_RANK", "0")) 84 | gpus = torch.cuda.device_count() 85 | args.local_rank = local_rank 86 | print(ip, port, hosts, rank, local_rank, gpus) 87 | dist.init_process_group(backend="nccl", init_method=f"tcp://{ip}:{port}", world_size=hosts, 88 | rank=rank) 89 | torch.cuda.set_device(local_rank) 90 | 91 | if args.task_name == 'long_term_forecast': 92 | Exp = Exp_Long_Term_Forecast 93 | elif args.task_name == 'short_term_forecast': 94 | Exp = Exp_Short_Term_Forecast 95 | elif args.task_name == 'zero_shot_forecast': 96 | Exp = Exp_Zero_Shot_Forecast 97 | elif args.task_name == 'in_context_forecast': 98 | Exp = Exp_In_Context_Forecast 99 | else: 100 | Exp = Exp_Long_Term_Forecast 101 | 102 | if args.is_training: 103 | for ii in range(args.itr): 104 | # setting record of experiments 105 | exp = Exp(args) # set experiments 106 | setting = '{}_{}_{}_{}_sl{}_ll{}_tl{}_lr{}_bt{}_wd{}_hd{}_hl{}_cos{}_mix{}_{}_{}'.format( 107 | args.task_name, 108 | args.model_id, 109 | args.model, 110 | args.data, 111 | args.seq_len, 112 | args.label_len, 113 | args.token_len, 114 | args.learning_rate, 115 | args.batch_size, 116 | args.weight_decay, 117 | args.mlp_hidden_dim, 118 | args.mlp_hidden_layers, 119 | args.cosine, 120 | args.mix_embeds, 121 | args.des, ii) 122 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu: 123 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) 124 | exp.train(setting) 125 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu: 126 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 127 | exp.test(setting) 128 | torch.cuda.empty_cache() 129 | else: 130 | ii = 0 131 | setting = '{}_{}_{}_{}_sl{}_ll{}_tl{}_lr{}_bt{}_wd{}_hd{}_hl{}_cos{}_mix{}_{}_{}'.format( 132 | args.task_name, 133 | args.model_id, 134 | args.model, 135 | args.data, 136 | args.seq_len, 137 | args.label_len, 138 | args.token_len, 139 | args.learning_rate, 140 | args.batch_size, 141 | args.weight_decay, 142 | args.mlp_hidden_dim, 143 | args.mlp_hidden_layers, 144 | args.cosine, 145 | args.mix_embeds, 146 | args.des, ii) 147 | exp = Exp(args) # set experiments 148 | exp.test(setting, test=1) 149 | torch.cuda.empty_cache() 150 | -------------------------------------------------------------------------------- /scripts/in_context_forecasting/M3.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | python -u run.py \ 4 | --task_name in_context_forecast \ 5 | --is_training 1 \ 6 | --root_path ./dataset/m4 \ 7 | --test_data_path m3_yearly_dataset.tsf \ 8 | --seasonal_patterns 'Yearly' \ 9 | --model_id m4_Yearly \ 
10 | --model $model_name \ 11 | --data m4 \ 12 | --seq_len 18 \ 13 | --label_len 12 \ 14 | --token_len 6 \ 15 | --test_seq_len 6 \ 16 | --test_label_len 0 \ 17 | --test_pred_len 6 \ 18 | --batch_size 16 \ 19 | --des 'Exp' \ 20 | --itr 1 \ 21 | --learning_rate 0.0005 \ 22 | --loss 'SMAPE' \ 23 | --use_amp \ 24 | --mlp_hidden_dim 128 \ 25 | --cosine \ 26 | --tmax 10 \ 27 | --drop_short 28 | 29 | python -u run.py \ 30 | --task_name in_context_forecast \ 31 | --is_training 1 \ 32 | --root_path ./dataset/m4 \ 33 | --test_data_path m3_quarterly_dataset.tsf \ 34 | --seasonal_patterns 'Quarterly' \ 35 | --model_id m4_Quarterly \ 36 | --model $model_name \ 37 | --data m4 \ 38 | --seq_len 24 \ 39 | --label_len 16 \ 40 | --token_len 8 \ 41 | --test_seq_len 8 \ 42 | --test_label_len 0 \ 43 | --test_pred_len 8 \ 44 | --batch_size 16 \ 45 | --des 'Exp' \ 46 | --itr 1 \ 47 | --learning_rate 0.0001 \ 48 | --loss 'SMAPE' \ 49 | --use_amp \ 50 | --mlp_hidden_dim 1024 \ 51 | --cosine \ 52 | --tmax 10 \ 53 | --drop_short 54 | 55 | python -u run.py \ 56 | --task_name in_context_forecast \ 57 | --is_training 1 \ 58 | --root_path ./dataset/m4 \ 59 | --test_data_path m3_monthly_dataset.tsf \ 60 | --seasonal_patterns 'Monthly' \ 61 | --model_id m4_Monthly \ 62 | --model $model_name \ 63 | --data m4 \ 64 | --seq_len 54 \ 65 | --label_len 36 \ 66 | --token_len 18 \ 67 | --test_seq_len 18 \ 68 | --test_label_len 0 \ 69 | --test_pred_len 18 \ 70 | --batch_size 16 \ 71 | --des 'Exp' \ 72 | --itr 1 \ 73 | --learning_rate 0.0005 \ 74 | --loss 'SMAPE' \ 75 | --use_amp \ 76 | --mlp_hidden_dim 256 \ 77 | --cosine \ 78 | --tmax 10 \ 79 | --drop_short 80 | 81 | python -u run.py \ 82 | --task_name in_context_forecast \ 83 | --is_training 1 \ 84 | --root_path ./dataset/m4 \ 85 | --test_data_path m3_other_dataset.tsf \ 86 | --seasonal_patterns 'Quarterly' \ 87 | --model_id m4_Quarterly \ 88 | --model $model_name \ 89 | --data m4 \ 90 | --seq_len 24 \ 91 | --label_len 16 \ 92 | --token_len 8 \ 93 | --test_seq_len 8 \ 94 | --test_label_len 0 \ 95 | --test_pred_len 8 \ 96 | --batch_size 16 \ 97 | --des 'Exp' \ 98 | --itr 1 \ 99 | --learning_rate 0.0005 \ 100 | --loss 'SMAPE' \ 101 | --use_amp \ 102 | --mlp_hidden_dim 512 \ 103 | --cosine \ 104 | --tmax 10 \ 105 | --drop_short -------------------------------------------------------------------------------- /scripts/method_generality/gpt2.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Gpt2 2 | 3 | python -u run.py \ 4 | --task_name long_term_forecast \ 5 | --is_training 1 \ 6 | --root_path ./dataset/ETT-small/ \ 7 | --data_path ETTh1.csv \ 8 | --model_id ETTh1_672_96 \ 9 | --model $model_name \ 10 | --data ETTh1 \ 11 | --seq_len 672 \ 12 | --label_len 576 \ 13 | --token_len 96 \ 14 | --test_seq_len 672 \ 15 | --test_label_len 576 \ 16 | --test_pred_len 96 \ 17 | --batch_size 2048 \ 18 | --learning_rate 0.002 \ 19 | --itr 1 \ 20 | --train_epochs 10 \ 21 | --use_amp \ 22 | --llm_ckp_dir ./models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10 \ 23 | --gpu 0 \ 24 | --des 'Gpt2' \ 25 | --cosine \ 26 | --tmax 10 \ 27 | --mlp_hidden_dim 512 28 | 29 | 30 | for test_pred_len in 96 192 336 720 31 | do 32 | python -u run.py \ 33 | --task_name long_term_forecast \ 34 | --is_training 0 \ 35 | --root_path ./dataset/ETT-small/ \ 36 | --data_path ETTh1.csv \ 37 | --model_id ETTh1_672_96 \ 38 | --model $model_name \ 39 | --data ETTh1 \ 40 | --seq_len 672 \ 41 | --label_len 576 \ 42 | --token_len 96 \ 43 | --test_seq_len 672 \ 44 
| --test_label_len 576 \ 45 | --test_pred_len $test_pred_len \ 46 | --batch_size 2048 \ 47 | --learning_rate 0.002 \ 48 | --itr 1 \ 49 | --train_epochs 10 \ 50 | --use_amp \ 51 | --llm_ckp_dir ./models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10 \ 52 | --gpu 0 \ 53 | --des 'Gpt2' \ 54 | --cosine \ 55 | --tmax 10 \ 56 | --mlp_hidden_dim 512 \ 57 | --test_dir long_term_forecast_ETTh1_672_96_AutoTimes_Gpt2_ETTh1_sl672_ll576_tl96_lr0.002_bt2048_wd0_hd512_hl2_cosTrue_mixFalse_Gpt2_0 58 | done -------------------------------------------------------------------------------- /scripts/method_generality/opt.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Opt_1b 2 | 3 | python -u run.py \ 4 | --task_name long_term_forecast \ 5 | --is_training 1 \ 6 | --root_path ./dataset/ETT-small/ \ 7 | --data_path ETTh1.csv \ 8 | --model_id ETTh1_672_96 \ 9 | --model $model_name \ 10 | --data ETTh1 \ 11 | --seq_len 672 \ 12 | --label_len 576 \ 13 | --token_len 96 \ 14 | --test_seq_len 672 \ 15 | --test_label_len 576 \ 16 | --test_pred_len 96 \ 17 | --batch_size 2048 \ 18 | --learning_rate 0.001 \ 19 | --itr 1 \ 20 | --train_epochs 10 \ 21 | --use_amp \ 22 | --llm_ckp_dir ./models--facebook--opt-1.3b/snapshots/3f5c25d0bc631cb57ac65913f76e22c2dfb61d62 \ 23 | --gpu 0 \ 24 | --des 'Opt1b' \ 25 | --cosine \ 26 | --tmax 10 \ 27 | --mlp_hidden_dim 256 28 | 29 | for test_pred_len in 96 192 336 720 30 | do 31 | python -u run.py \ 32 | --task_name long_term_forecast \ 33 | --is_training 0 \ 34 | --root_path ./dataset/ETT-small/ \ 35 | --data_path ETTh1.csv \ 36 | --model_id ETTh1_672_96 \ 37 | --model $model_name \ 38 | --data ETTh1 \ 39 | --seq_len 672 \ 40 | --label_len 576 \ 41 | --token_len 96 \ 42 | --test_seq_len 672 \ 43 | --test_label_len 576 \ 44 | --test_pred_len $test_pred_len \ 45 | --batch_size 2048 \ 46 | --learning_rate 0.001 \ 47 | --itr 1 \ 48 | --train_epochs 10 \ 49 | --use_amp \ 50 | --llm_ckp_dir ./models--facebook--opt-1.3b/snapshots/3f5c25d0bc631cb57ac65913f76e22c2dfb61d62 \ 51 | --gpu 0 \ 52 | --des 'Opt1b' \ 53 | --cosine \ 54 | --tmax 10 \ 55 | --mlp_hidden_dim 256 \ 56 | --test_dir long_term_forecast_ETTh1_672_96_AutoTimes_Opt_1b_ETTh1_sl672_ll576_tl96_lr0.001_bt2048_wd0_hd256_hl2_cosTrue_mixFalse_Opt1b_0 57 | done -------------------------------------------------------------------------------- /scripts/time_series_forecasting/long_term/AutoTimes_ECL.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | # training one model with a context length 4 | torchrun --nnodes 1 --nproc-per-node 8 run.py \ 5 | --task_name long_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/electricity/ \ 8 | --data_path electricity.csv \ 9 | --model_id ECL_672_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --seq_len 672 \ 13 | --label_len 576 \ 14 | --token_len 96 \ 15 | --test_seq_len 672 \ 16 | --test_label_len 576 \ 17 | --test_pred_len 96 \ 18 | --batch_size 256 \ 19 | --learning_rate 0.001 \ 20 | --weight_decay 0.00001 \ 21 | --mlp_hidden_dim 1024 \ 22 | --train_epochs 10 \ 23 | --use_amp \ 24 | --mix_embeds \ 25 | --use_multi_gpu \ 26 | --tmax 10 \ 27 | --cosine 28 | 29 | # testing the model on all forecast lengths 30 | for test_pred_len in 96 192 336 720 31 | do 32 | python -u run.py \ 33 | --task_name long_term_forecast \ 34 | --is_training 0 \ 35 | --root_path ./dataset/electricity/ \ 36 | --data_path electricity.csv \ 37 | --model_id ECL_672_96 
\ 38 | --model $model_name \ 39 | --data custom \ 40 | --seq_len 672 \ 41 | --label_len 576 \ 42 | --token_len 96 \ 43 | --test_seq_len 672 \ 44 | --test_label_len 576 \ 45 | --test_pred_len $test_pred_len \ 46 | --batch_size 256 \ 47 | --learning_rate 0.001 \ 48 | --weight_decay 0.00001 \ 49 | --mlp_hidden_dim 1024 \ 50 | --train_epochs 10 \ 51 | --use_amp \ 52 | --mix_embeds \ 53 | --tmax 10 \ 54 | --cosine \ 55 | --test_dir long_term_forecast_ECL_672_96_AutoTimes_Llama_custom_sl672_ll576_tl96_lr0.001_bt256_wd1e-05_hd1024_hl2_cosTrue_mixTrue_test_0 56 | done -------------------------------------------------------------------------------- /scripts/time_series_forecasting/long_term/AutoTimes_ETTh1.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | # training one model with a context length 4 | python -u run.py \ 5 | --task_name long_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/ETT-small/ \ 8 | --data_path ETTh1.csv \ 9 | --model_id ETTh1_672_96 \ 10 | --model $model_name \ 11 | --data ETTh1 \ 12 | --seq_len 672 \ 13 | --label_len 576 \ 14 | --token_len 96 \ 15 | --test_seq_len 672 \ 16 | --test_label_len 576 \ 17 | --test_pred_len 96 \ 18 | --batch_size 256 \ 19 | --learning_rate 0.0005 \ 20 | --mlp_hidden_layers 0 \ 21 | --train_epochs 10 \ 22 | --use_amp \ 23 | --gpu 0 \ 24 | --cosine \ 25 | --tmax 10 \ 26 | --mix_embeds \ 27 | --drop_last 28 | 29 | # testing the model on all forecast lengths 30 | for test_pred_len in 96 192 336 720 31 | do 32 | python -u run.py \ 33 | --task_name long_term_forecast \ 34 | --is_training 0 \ 35 | --root_path ./dataset/ETT-small/ \ 36 | --data_path ETTh1.csv \ 37 | --model_id ETTh1_672_96 \ 38 | --model $model_name \ 39 | --data ETTh1 \ 40 | --seq_len 672 \ 41 | --label_len 576 \ 42 | --token_len 96 \ 43 | --test_seq_len 672 \ 44 | --test_label_len 576 \ 45 | --test_pred_len $test_pred_len \ 46 | --batch_size 256 \ 47 | --learning_rate 0.0005 \ 48 | --mlp_hidden_layers 0 \ 49 | --train_epochs 10 \ 50 | --use_amp \ 51 | --gpu 0 \ 52 | --cosine \ 53 | --tmax 10 \ 54 | --mix_embeds \ 55 | --drop_last \ 56 | --test_dir long_term_forecast_ETTh1_672_96_AutoTimes_Llama_ETTh1_sl672_ll576_tl96_lr0.0005_bt256_wd0_hd256_hl0_cosTrue_mixTrue_test_0 57 | done 58 | -------------------------------------------------------------------------------- /scripts/time_series_forecasting/long_term/AutoTimes_Solar.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | # training one model with a context length 4 | torchrun --nnodes 1 --nproc-per-node 4 run.py \ 5 | --task_name long_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/Solar/ \ 8 | --data_path solar_AL.txt \ 9 | --model_id solar_672_96 \ 10 | --model $model_name \ 11 | --data Solar \ 12 | --seq_len 672 \ 13 | --label_len 576 \ 14 | --token_len 96 \ 15 | --test_seq_len 672 \ 16 | --test_label_len 576 \ 17 | --test_pred_len 96 \ 18 | --batch_size 256 \ 19 | --learning_rate 0.000005 \ 20 | --train_epochs 2 \ 21 | --use_amp \ 22 | --mlp_hidden_dim 1024 \ 23 | --mlp_activation relu \ 24 | --des 'Exp' \ 25 | --use_multi_gpu \ 26 | --cosine \ 27 | --tmax 10 28 | 29 | # testing the model on all forecast lengths 30 | for test_pred_len in 96 192 336 720 31 | do 32 | python -u run.py \ 33 | --task_name long_term_forecast \ 34 | --is_training 0 \ 35 | --root_path ./dataset/Solar/ \ 36 | --data_path solar_AL.txt \ 37 | --model_id solar_672_96 \ 38 | --model 
$model_name \ 39 | --data Solar \ 40 | --seq_len 672 \ 41 | --label_len 576 \ 42 | --token_len 96 \ 43 | --test_seq_len 672 \ 44 | --test_label_len 576 \ 45 | --test_pred_len $test_pred_len \ 46 | --batch_size 256 \ 47 | --learning_rate 0.000005 \ 48 | --train_epochs 2 \ 49 | --use_amp \ 50 | --mlp_hidden_dim 1024 \ 51 | --mlp_activation relu \ 52 | --des 'Exp' \ 53 | --cosine \ 54 | --tmax 10 \ 55 | --test_dir long_term_forecast_solar_672_96_AutoTimes_Llama_Solar_sl672_ll576_tl96_lr5e-06_bt256_wd0_hd1024_hl2_cosTrue_mixFalse_Exp_0 56 | done -------------------------------------------------------------------------------- /scripts/time_series_forecasting/long_term/AutoTimes_Traffic.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | # training one model with a context length 4 | torchrun --nnodes 1 --nproc-per-node 8 run.py \ 5 | --task_name long_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/traffic/ \ 8 | --data_path traffic.csv \ 9 | --model_id traffic_672_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --seq_len 672 \ 13 | --label_len 576 \ 14 | --token_len 96 \ 15 | --test_seq_len 672 \ 16 | --test_label_len 576 \ 17 | --test_pred_len 96 \ 18 | --batch_size 256 \ 19 | --learning_rate 0.0001 \ 20 | --weight_decay 0.00001 \ 21 | --mlp_hidden_dim 1024 \ 22 | --mlp_activation relu \ 23 | --train_epochs 10 \ 24 | --use_amp \ 25 | --cosine \ 26 | --tmax 10 \ 27 | --mix_embeds \ 28 | --use_multi_gpu 29 | 30 | # testing the model on all forecast lengths 31 | for test_pred_len in 96 192 336 720 32 | do 33 | python -u run.py \ 34 | --task_name long_term_forecast \ 35 | --is_training 0 \ 36 | --root_path ./dataset/traffic/ \ 37 | --data_path traffic.csv \ 38 | --model_id traffic_672_96 \ 39 | --model $model_name \ 40 | --data custom \ 41 | --seq_len 672 \ 42 | --label_len 576 \ 43 | --token_len 96 \ 44 | --test_seq_len 672 \ 45 | --test_label_len 576 \ 46 | --test_pred_len $test_pred_len \ 47 | --batch_size 256 \ 48 | --learning_rate 0.0001 \ 49 | --weight_decay 0.00001 \ 50 | --mlp_hidden_dim 1024 \ 51 | --mlp_activation relu \ 52 | --train_epochs 10 \ 53 | --use_amp \ 54 | --gpu 0 \ 55 | --cosine \ 56 | --tmax 10 \ 57 | --mix_embeds \ 58 | --test_dir long_term_forecast_traffic_672_96_AutoTimes_Llama_custom_sl672_ll576_tl96_lr0.0001_bt256_wd1e-05_hd1024_hl2_cosTrue_mixTrue_test_0 59 | done -------------------------------------------------------------------------------- /scripts/time_series_forecasting/long_term/AutoTimes_Weather.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | # training one model with a context length 4 | torchrun --nnodes 1 --nproc-per-node 4 run.py \ 5 | --task_name long_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/weather/ \ 8 | --data_path weather.csv \ 9 | --model_id weather_672_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --seq_len 672 \ 13 | --label_len 576 \ 14 | --token_len 96 \ 15 | --test_seq_len 672 \ 16 | --test_label_len 576 \ 17 | --test_pred_len 96 \ 18 | --batch_size 384 \ 19 | --learning_rate 0.0005 \ 20 | --train_epochs 10 \ 21 | --use_amp \ 22 | --lradj type2 \ 23 | --des 'Exp' \ 24 | --mlp_hidden_dim 512 \ 25 | --mlp_activation relu \ 26 | --use_multi_gpu \ 27 | --mix_embeds 28 | 29 | # testing the model on all forecast lengths 30 | for test_pred_len in 96 192 336 720 31 | do 32 | python -u run.py \ 33 | --task_name long_term_forecast \ 34 | --is_training 
0 \ 35 | --root_path ./dataset/weather/ \ 36 | --data_path weather.csv \ 37 | --model_id weather_672_96 \ 38 | --model $model_name \ 39 | --data custom \ 40 | --seq_len 672 \ 41 | --label_len 576 \ 42 | --token_len 96 \ 43 | --test_seq_len 672 \ 44 | --test_label_len 576 \ 45 | --test_pred_len $test_pred_len \ 46 | --batch_size 384 \ 47 | --learning_rate 0.0005 \ 48 | --train_epochs 10 \ 49 | --use_amp \ 50 | --lradj type2 \ 51 | --des 'Exp' \ 52 | --mlp_hidden_dim 512 \ 53 | --mlp_activation relu \ 54 | --mix_embeds \ 55 | --test_dir long_term_forecast_weather_672_96_AutoTimes_Llama_custom_sl672_ll576_tl96_lr0.0005_bt384_wd0_hd512_hl2_cosFalse_mixTrue_Exp_0 56 | done -------------------------------------------------------------------------------- /scripts/time_series_forecasting/short_term/AutoTimes_M4.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=AutoTimes_Llama 4 | python -u run.py \ 5 | --task_name short_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/m4 \ 8 | --seasonal_patterns 'Yearly' \ 9 | --model_id m4_Yearly \ 10 | --model $model_name \ 11 | --data m4 \ 12 | --batch_size 16 \ 13 | --des 'Exp' \ 14 | --itr 1 \ 15 | --learning_rate 0.0001 \ 16 | --loss 'SMAPE' \ 17 | --use_amp \ 18 | --mlp_hidden_dim 512 \ 19 | --cosine \ 20 | --tmax 10 \ 21 | --weight_decay 0.00001 22 | 23 | python -u run.py \ 24 | --task_name short_term_forecast \ 25 | --is_training 1 \ 26 | --root_path ./dataset/m4 \ 27 | --seasonal_patterns 'Quarterly' \ 28 | --model_id m4_Quarterly \ 29 | --model $model_name \ 30 | --data m4 \ 31 | --batch_size 16 \ 32 | --des 'Exp' \ 33 | --itr 1 \ 34 | --learning_rate 0.00005 \ 35 | --loss 'SMAPE' \ 36 | --use_amp \ 37 | --mlp_hidden_dim 512 \ 38 | --cosine \ 39 | --tmax 10 \ 40 | --weight_decay 0.000005 41 | 42 | python -u run.py \ 43 | --task_name short_term_forecast \ 44 | --is_training 1 \ 45 | --root_path ./dataset/m4 \ 46 | --seasonal_patterns 'Monthly' \ 47 | --model_id m4_Monthly \ 48 | --model $model_name \ 49 | --data m4 \ 50 | --batch_size 16 \ 51 | --learning_rate 0.00005 \ 52 | --des 'Exp' \ 53 | --itr 1 \ 54 | --loss 'SMAPE' \ 55 | --use_amp \ 56 | --mlp_hidden_dim 1024 \ 57 | --cosine \ 58 | --tmax 10 \ 59 | --weight_decay 0.000001 60 | 61 | python -u run.py \ 62 | --task_name short_term_forecast \ 63 | --is_training 1 \ 64 | --root_path ./dataset/m4 \ 65 | --seasonal_patterns 'Weekly' \ 66 | --model_id m4_Weekly \ 67 | --model $model_name \ 68 | --data m4 \ 69 | --batch_size 16 \ 70 | --des 'Exp' \ 71 | --itr 1 \ 72 | --learning_rate 0.0001 \ 73 | --loss 'SMAPE' \ 74 | --use_amp \ 75 | --mlp_hidden_dim 1024 \ 76 | --cosine \ 77 | --tmax 10 78 | 79 | python -u run.py \ 80 | --task_name short_term_forecast \ 81 | --is_training 1 \ 82 | --root_path ./dataset/m4 \ 83 | --seasonal_patterns 'Daily' \ 84 | --model_id m4_Daily \ 85 | --model $model_name \ 86 | --data m4 \ 87 | --batch_size 16 \ 88 | --des 'Exp' \ 89 | --itr 1 \ 90 | --learning_rate 0.0005 \ 91 | --loss 'SMAPE' \ 92 | --use_amp \ 93 | --mlp_hidden_dim 1024 \ 94 | --weight_decay 0.000005 95 | 96 | python -u run.py \ 97 | --task_name short_term_forecast \ 98 | --is_training 1 \ 99 | --root_path ./dataset/m4 \ 100 | --seasonal_patterns 'Hourly' \ 101 | --model_id m4_Hourly \ 102 | --model $model_name \ 103 | --data m4 \ 104 | --batch_size 16 \ 105 | --des 'Exp' \ 106 | --itr 1 \ 107 | --learning_rate 0.0001 \ 108 | --loss 'SMAPE' \ 109 | --use_amp \ 110 | --mlp_hidden_dim 1024 \ 111 | --cosine \ 
112 | --tmax 10 -------------------------------------------------------------------------------- /scripts/zero_shot_forecasting/sM3_tM4.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | python -u run.py \ 4 | --task_name zero_shot_forecast \ 5 | --is_training 1 \ 6 | --root_path ./dataset/tsf \ 7 | --data_path m3_yearly_dataset.tsf \ 8 | --test_data_path m4_yearly_dataset.tsf \ 9 | --seasonal_patterns 'Yearly' \ 10 | --model_id m3_Yearly \ 11 | --model $model_name \ 12 | --data tsf \ 13 | --seq_len 6 \ 14 | --label_len 0 \ 15 | --token_len 6 \ 16 | --test_seq_len 6 \ 17 | --test_label_len 0 \ 18 | --test_pred_len 6 \ 19 | --learning_rate 0.0001 \ 20 | --mlp_hidden_dim 256 \ 21 | --mlp_hidden_layers 3 \ 22 | --batch_size 16 \ 23 | --des 'Exp' \ 24 | --itr 1 \ 25 | --loss 'SMAPE' \ 26 | --use_amp \ 27 | --cosine \ 28 | --tmax 10 \ 29 | --val_set_shuffle 30 | 31 | python -u run.py \ 32 | --task_name zero_shot_forecast \ 33 | --is_training 1 \ 34 | --root_path ./dataset/tsf \ 35 | --data_path m3_quarterly_dataset.tsf \ 36 | --test_data_path m4_quarterly_dataset.tsf \ 37 | --seasonal_patterns 'Quarterly' \ 38 | --model_id m3_Quarterly \ 39 | --model $model_name \ 40 | --data tsf \ 41 | --seq_len 8 \ 42 | --label_len 0 \ 43 | --token_len 8 \ 44 | --test_seq_len 8 \ 45 | --test_label_len 0 \ 46 | --test_pred_len 8 \ 47 | --learning_rate 0.000005 \ 48 | --mlp_hidden_dim 1024 \ 49 | --batch_size 16 \ 50 | --des 'Exp' \ 51 | --itr 1 \ 52 | --loss 'SMAPE' \ 53 | --use_amp \ 54 | --cosine \ 55 | --tmax 10 \ 56 | --val_set_shuffle 57 | 58 | python -u run.py \ 59 | --task_name zero_shot_forecast \ 60 | --is_training 1 \ 61 | --root_path ./dataset/tsf \ 62 | --data_path m3_monthly_dataset.tsf \ 63 | --test_data_path m4_monthly_dataset.tsf \ 64 | --seasonal_patterns 'Monthly' \ 65 | --model_id m3_Monthly \ 66 | --model $model_name \ 67 | --data tsf \ 68 | --seq_len 24 \ 69 | --label_len 0 \ 70 | --token_len 24 \ 71 | --test_seq_len 24 \ 72 | --test_label_len 0 \ 73 | --test_pred_len 24 \ 74 | --learning_rate 0.00001 \ 75 | --mlp_hidden_dim 512 \ 76 | --batch_size 16 \ 77 | --des 'Exp' \ 78 | --itr 1 \ 79 | --loss 'SMAPE' \ 80 | --use_amp \ 81 | --cosine \ 82 | --tmax 10 \ 83 | --val_set_shuffle 84 | 85 | python -u run.py \ 86 | --task_name zero_shot_forecast \ 87 | --is_training 1 \ 88 | --root_path ./dataset/tsf \ 89 | --data_path m3_monthly_dataset.tsf \ 90 | --test_data_path m4_weekly_dataset.tsf \ 91 | --seasonal_patterns 'Monthly' \ 92 | --model_id m3_Monthly \ 93 | --model $model_name \ 94 | --data tsf \ 95 | --seq_len 26 \ 96 | --label_len 13 \ 97 | --token_len 13 \ 98 | --test_seq_len 26 \ 99 | --test_label_len 13 \ 100 | --test_pred_len 13 \ 101 | --learning_rate 0.001 \ 102 | --mlp_hidden_dim 256 \ 103 | --batch_size 16 \ 104 | --des 'Exp' \ 105 | --itr 1 \ 106 | --loss 'SMAPE' \ 107 | --use_amp \ 108 | --cosine \ 109 | --tmax 10 \ 110 | --val_set_shuffle 111 | 112 | python -u run.py \ 113 | --task_name zero_shot_forecast \ 114 | --is_training 1 \ 115 | --root_path ./dataset/tsf \ 116 | --data_path m3_monthly_dataset.tsf \ 117 | --test_data_path m4_daily_dataset.tsf \ 118 | --seasonal_patterns 'Monthly' \ 119 | --model_id m3_Monthly \ 120 | --model $model_name \ 121 | --data tsf \ 122 | --seq_len 28 \ 123 | --label_len 14 \ 124 | --token_len 14 \ 125 | --test_seq_len 28 \ 126 | --test_label_len 14 \ 127 | --test_pred_len 14 \ 128 | --learning_rate 0.0001 \ 129 | --mlp_hidden_dim 256 \ 130 | --batch_size 16 \ 131 | --des 
'Exp' \ 132 | --itr 1 \ 133 | --loss 'SMAPE' \ 134 | --use_amp \ 135 | --cosine \ 136 | --tmax 10 \ 137 | --val_set_shuffle 138 | 139 | python -u run.py \ 140 | --task_name zero_shot_forecast \ 141 | --is_training 1 \ 142 | --root_path ./dataset/tsf \ 143 | --data_path m3_monthly_dataset.tsf \ 144 | --test_data_path m4_hourly_dataset.tsf \ 145 | --seasonal_patterns 'Monthly' \ 146 | --model_id m3_Monthly \ 147 | --model $model_name \ 148 | --data tsf \ 149 | --seq_len 48 \ 150 | --label_len 24 \ 151 | --token_len 24 \ 152 | --test_seq_len 48 \ 153 | --test_label_len 24 \ 154 | --test_pred_len 48 \ 155 | --learning_rate 0.001 \ 156 | --mlp_hidden_dim 128 \ 157 | --mlp_hidden_layers 3 \ 158 | --batch_size 16 \ 159 | --des 'Exp' \ 160 | --itr 1 \ 161 | --loss 'SMAPE' \ 162 | --use_amp \ 163 | --cosine \ 164 | --tmax 10 \ 165 | --val_set_shuffle -------------------------------------------------------------------------------- /scripts/zero_shot_forecasting/sM4_tM3.sh: -------------------------------------------------------------------------------- 1 | model_name=AutoTimes_Llama 2 | 3 | python -u run.py \ 4 | --task_name zero_shot_forecast \ 5 | --is_training 0 \ 6 | --root_path ./dataset/tsf \ 7 | --test_data_path m3_yearly_dataset.tsf \ 8 | --seasonal_patterns 'Yearly' \ 9 | --model_id m4_Yearly \ 10 | --model $model_name \ 11 | --data tsf \ 12 | --seq_len 12 \ 13 | --label_len 6 \ 14 | --token_len 6 \ 15 | --test_seq_len 12 \ 16 | --test_label_len 6 \ 17 | --test_pred_len 6 \ 18 | --batch_size 16 \ 19 | --des 'Exp' \ 20 | --itr 1 \ 21 | --loss 'SMAPE' \ 22 | --use_amp \ 23 | --mlp_hidden_dim 512 \ 24 | --test_dir short_term_forecast_m4_Yearly_AutoTimes_Llama_m4_sl12_ll6_tl6_lr0.0001_bt16_wd1e-05_hd512_hl2_cosTrue_mixFalse_Exp_0 25 | 26 | python -u run.py \ 27 | --task_name zero_shot_forecast \ 28 | --is_training 0 \ 29 | --root_path ./dataset/tsf \ 30 | --test_data_path m3_quarterly_dataset.tsf \ 31 | --seasonal_patterns 'Quarterly' \ 32 | --model_id m4_Quarterly \ 33 | --model $model_name \ 34 | --data tsf \ 35 | --seq_len 16 \ 36 | --label_len 8 \ 37 | --token_len 8 \ 38 | --test_seq_len 16 \ 39 | --test_label_len 8 \ 40 | --test_pred_len 8 \ 41 | --batch_size 16 \ 42 | --des 'Exp' \ 43 | --itr 1 \ 44 | --loss 'SMAPE' \ 45 | --use_amp \ 46 | --mlp_hidden_dim 512 \ 47 | --test_dir short_term_forecast_m4_Quarterly_AutoTimes_Llama_m4_sl16_ll8_tl8_lr5e-05_bt16_wd5e-06_hd512_hl2_cosTrue_mixFalse_Exp_0 48 | 49 | python -u run.py \ 50 | --task_name zero_shot_forecast \ 51 | --is_training 0 \ 52 | --root_path ./dataset/tsf \ 53 | --test_data_path m3_monthly_dataset.tsf \ 54 | --seasonal_patterns 'Monthly' \ 55 | --model_id m4_Monthly \ 56 | --model $model_name \ 57 | --data tsf \ 58 | --seq_len 36 \ 59 | --label_len 18 \ 60 | --token_len 18 \ 61 | --test_seq_len 36 \ 62 | --test_label_len 18 \ 63 | --test_pred_len 18 \ 64 | --batch_size 16 \ 65 | --des 'Exp' \ 66 | --itr 1 \ 67 | --loss 'SMAPE' \ 68 | --use_amp \ 69 | --mlp_hidden_dim 1024 \ 70 | --test_dir short_term_forecast_m4_Monthly_AutoTimes_Llama_m4_sl36_ll18_tl18_lr5e-05_bt16_wd1e-06_hd1024_hl2_cosTrue_mixFalse_Exp_0 71 | 72 | python -u run.py \ 73 | --task_name zero_shot_forecast \ 74 | --is_training 0 \ 75 | --root_path ./dataset/tsf \ 76 | --test_data_path m3_other_dataset.tsf \ 77 | --seasonal_patterns 'Quarterly' \ 78 | --model_id m4_Quarterly \ 79 | --model $model_name \ 80 | --data tsf \ 81 | --seq_len 16 \ 82 | --label_len 8 \ 83 | --token_len 8 \ 84 | --test_seq_len 16 \ 85 | --test_label_len 8 \ 86 | --test_pred_len 8 \ 87 | 
--batch_size 16 \ 88 | --des 'Exp' \ 89 | --itr 1 \ 90 | --loss 'SMAPE' \ 91 | --use_amp \ 92 | --mlp_hidden_dim 512 \ 93 | --test_dir short_term_forecast_m4_Quarterly_AutoTimes_Llama_m4_sl16_ll8_tl8_lr5e-05_bt16_wd5e-06_hd512_hl2_cosTrue_mixFalse_Exp_0 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/AutoTimes/197a84374d489a38b5ed0626d67db34194969079/utils/__init__.py -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | Loss functions for PyTorch. 17 | """ 18 | 19 | import torch as t 20 | import torch.nn as nn 21 | import numpy as np 22 | import pdb 23 | 24 | 25 | def divide_no_nan(a, b): 26 | """ 27 | a/b where the resulted NaN or Inf are replaced by 0. 28 | """ 29 | result = a / b 30 | result[result != result] = .0 31 | result[result == np.inf] = .0 32 | return result 33 | 34 | 35 | class mape_loss(nn.Module): 36 | def __init__(self): 37 | super(mape_loss, self).__init__() 38 | 39 | def forward(self, insample: t.Tensor, freq: int, 40 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 41 | """ 42 | MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error 43 | 44 | :param forecast: Forecast values. Shape: batch, time 45 | :param target: Target values. Shape: batch, time 46 | :param mask: 0/1 mask. Shape: batch, time 47 | :return: Loss value 48 | """ 49 | weights = divide_no_nan(mask, target) 50 | return t.mean(t.abs((forecast - target) * weights)) 51 | 52 | 53 | class smape_loss(nn.Module): 54 | def __init__(self): 55 | super(smape_loss, self).__init__() 56 | 57 | def forward(self, insample: t.Tensor, freq: int, 58 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 59 | """ 60 | sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993) 61 | 62 | :param forecast: Forecast values. Shape: batch, time 63 | :param target: Target values. Shape: batch, time 64 | :param mask: 0/1 mask. 
Shape: batch, time 65 | :return: Loss value 66 | """ 67 | return 200 * t.mean(divide_no_nan(t.abs(forecast - target), 68 | t.abs(forecast.data) + t.abs(target.data)) * mask) 69 | 70 | 71 | class mase_loss(nn.Module): 72 | def __init__(self): 73 | super(mase_loss, self).__init__() 74 | 75 | def forward(self, insample: t.Tensor, freq: int, 76 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 77 | """ 78 | MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf 79 | 80 | :param insample: Insample values. Shape: batch, time_i 81 | :param freq: Frequency value 82 | :param forecast: Forecast values. Shape: batch, time_o 83 | :param target: Target values. Shape: batch, time_o 84 | :param mask: 0/1 mask. Shape: batch, time_o 85 | :return: Loss value 86 | """ 87 | masep = t.mean(t.abs(insample[:, freq:] - insample[:, :-freq]), dim=1) 88 | masked_masep_inv = divide_no_nan(mask, masep[:, None]) 89 | return t.mean(t.abs(target - forecast) * masked_masep_inv) 90 | 91 | class zero_shot_smape_loss(nn.Module): 92 | def __init__(self): 93 | super(zero_shot_smape_loss, self).__init__() 94 | def forward(self, pred, true): 95 | return t.mean(200 * t.abs(pred - true) / (t.abs(pred) + t.abs(true) + 1e-8)) -------------------------------------------------------------------------------- /utils/m4_summary.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | M4 Summary 17 | """ 18 | from collections import OrderedDict 19 | 20 | import numpy as np 21 | import pandas as pd 22 | 23 | from data_provider.m4 import M4Dataset 24 | from data_provider.m4 import M4Meta 25 | import os 26 | 27 | 28 | def group_values(values, groups, group_name): 29 | return np.array([v[~np.isnan(v)] for v in values[groups == group_name]]) 30 | 31 | 32 | def mase(forecast, insample, outsample, frequency): 33 | return np.mean(np.abs(forecast - outsample)) / np.mean(np.abs(insample[:-frequency] - insample[frequency:])) 34 | 35 | 36 | def smape_2(forecast, target): 37 | denom = np.abs(target) + np.abs(forecast) 38 | # divide by 1.0 instead of 0.0, in case when denom is zero the enumerator will be 0.0 anyway. 39 | denom[denom == 0.0] = 1.0 40 | return 200 * np.abs(forecast - target) / denom 41 | 42 | 43 | def mape(forecast, target): 44 | denom = np.abs(target) 45 | # divide by 1.0 instead of 0.0, in case when denom is zero the enumerator will be 0.0 anyway. 
46 | denom[denom == 0.0] = 1.0 47 | return 100 * np.abs(forecast - target) / denom 48 | 49 | 50 | class M4Summary: 51 | def __init__(self, file_path, root_path): 52 | self.file_path = file_path 53 | self.training_set = M4Dataset.load(training=True, dataset_file=root_path) 54 | self.test_set = M4Dataset.load(training=False, dataset_file=root_path) 55 | self.naive_path = os.path.join(root_path, 'submission-Naive2.csv') 56 | 57 | def evaluate(self): 58 | """ 59 | Evaluate forecasts using M4 test dataset. 60 | 61 | :param forecast: Forecasts. Shape: timeseries, time. 62 | :return: sMAPE and OWA grouped by seasonal patterns. 63 | """ 64 | grouped_owa = OrderedDict() 65 | 66 | naive2_forecasts = pd.read_csv(self.naive_path).values[:, 1:].astype(np.float32) 67 | naive2_forecasts = np.array([v[~np.isnan(v)] for v in naive2_forecasts]) 68 | 69 | model_mases = {} 70 | naive2_smapes = {} 71 | naive2_mases = {} 72 | grouped_smapes = {} 73 | grouped_mapes = {} 74 | for group_name in M4Meta.seasonal_patterns: 75 | file_name = self.file_path + group_name + "_forecast.csv" 76 | if os.path.exists(file_name): 77 | model_forecast = pd.read_csv(file_name).values 78 | 79 | naive2_forecast = group_values(naive2_forecasts, self.test_set.groups, group_name) 80 | target = group_values(self.test_set.values, self.test_set.groups, group_name) 81 | # all timeseries within group have same frequency 82 | frequency = self.training_set.frequencies[self.test_set.groups == group_name][0] 83 | insample = group_values(self.training_set.values, self.test_set.groups, group_name) 84 | 85 | model_mases[group_name] = np.mean([mase(forecast=model_forecast[i], 86 | insample=insample[i], 87 | outsample=target[i], 88 | frequency=frequency) for i in range(len(model_forecast))]) 89 | naive2_mases[group_name] = np.mean([mase(forecast=naive2_forecast[i], 90 | insample=insample[i], 91 | outsample=target[i], 92 | frequency=frequency) for i in range(len(model_forecast))]) 93 | 94 | naive2_smapes[group_name] = np.mean(smape_2(naive2_forecast, target)) 95 | grouped_smapes[group_name] = np.mean(smape_2(forecast=model_forecast, target=target)) 96 | grouped_mapes[group_name] = np.mean(mape(forecast=model_forecast, target=target)) 97 | 98 | grouped_smapes = self.summarize_groups(grouped_smapes) 99 | grouped_mapes = self.summarize_groups(grouped_mapes) 100 | grouped_model_mases = self.summarize_groups(model_mases) 101 | grouped_naive2_smapes = self.summarize_groups(naive2_smapes) 102 | grouped_naive2_mases = self.summarize_groups(naive2_mases) 103 | for k in grouped_model_mases.keys(): 104 | grouped_owa[k] = (grouped_model_mases[k] / grouped_naive2_mases[k] + 105 | grouped_smapes[k] / grouped_naive2_smapes[k]) / 2 106 | 107 | def round_all(d): 108 | return dict(map(lambda kv: (kv[0], np.round(kv[1], 3)), d.items())) 109 | 110 | return round_all(grouped_smapes), round_all(grouped_owa), round_all(grouped_mapes), round_all( 111 | grouped_model_mases) 112 | 113 | def summarize_groups(self, scores): 114 | """ 115 | Re-group scores respecting M4 rules. 116 | :param scores: Scores per group. 117 | :return: Grouped scores. 
118 | """ 119 | scores_summary = OrderedDict() 120 | 121 | def group_count(group_name): 122 | return len(np.where(self.test_set.groups == group_name)[0]) 123 | 124 | weighted_score = {} 125 | for g in ['Yearly', 'Quarterly', 'Monthly']: 126 | weighted_score[g] = scores[g] * group_count(g) 127 | scores_summary[g] = scores[g] 128 | 129 | others_score = 0 130 | others_count = 0 131 | for g in ['Weekly', 'Daily', 'Hourly']: 132 | others_score += scores[g] * group_count(g) 133 | others_count += group_count(g) 134 | weighted_score['Others'] = others_score 135 | scores_summary['Others'] = others_score / others_count 136 | 137 | average = np.sum(list(weighted_score.values())) / len(self.test_set.groups) 138 | scores_summary['Average'] = average 139 | 140 | return scores_summary 141 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(pred - true)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((pred - true) ** 2) 20 | 21 | 22 | def RMSE(pred, true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | 26 | def MAPE(pred, true): 27 | return np.mean(np.abs((pred - true) / true)) 28 | 29 | 30 | def MSPE(pred, true): 31 | return np.mean(np.square((pred - true) / true)) 32 | 33 | 34 | def metric(pred, true): 35 | mae = MAE(pred, true) 36 | mse = MSE(pred, true) 37 | rmse = RMSE(pred, true) 38 | mape = MAPE(pred, true) 39 | mspe = MSPE(pred, true) 40 | 41 | return mae, mse, rmse, mape, mspe 42 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import math 6 | import torch.distributed as dist 7 | from distutils.util import strtobool 8 | from datetime import datetime 9 | plt.switch_backend('agg') 10 | 11 | 12 | def adjust_learning_rate(optimizer, epoch, args): 13 | if args.lradj == 'type1': 14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** (epoch - 1))} 15 | elif args.lradj == 'type2': 16 | lr_adjust = {epoch: args.learning_rate * (0.6 ** epoch)} 17 | elif args.lradj == "cosine": 18 | lr_adjust = {epoch: args.learning_rate /2 * (1 + math.cos(epoch / args.train_epochs * math.pi))} 19 | 20 | if epoch in lr_adjust.keys(): 21 | lr = lr_adjust[epoch] 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | if (args.use_multi_gpu and args.local_rank == 0) or not args.use_multi_gpu: 25 | print('next learning rate is {}'.format(lr)) 26 | 27 | class EarlyStopping: 28 | def __init__(self, args, verbose=False, delta=0): 29 | self.patience = args.patience 30 | self.verbose = verbose 31 | self.counter = 0 32 | self.best_score = None 33 | self.early_stop = False 34 | self.val_loss_min = np.Inf 35 | self.delta = delta 36 | self.use_multi_gpu = args.use_multi_gpu 37 | if self.use_multi_gpu: 38 | self.local_rank = args.local_rank 39 | else: 40 | self.local_rank = None 41 | 42 | def __call__(self, val_loss, model, path): 43 
| score = -val_loss 44 | if self.best_score is None: 45 | self.best_score = score 46 | if self.verbose: 47 | if (self.use_multi_gpu and self.local_rank == 0) or not self.use_multi_gpu: 48 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') 49 | self.val_loss_min = val_loss 50 | if self.use_multi_gpu: 51 | if self.local_rank == 0: 52 | self.save_checkpoint(val_loss, model, path) 53 | dist.barrier() 54 | else: 55 | self.save_checkpoint(val_loss, model, path) 56 | elif score < self.best_score + self.delta: 57 | self.counter += 1 58 | if (self.use_multi_gpu and self.local_rank == 0) or not self.use_multi_gpu: 59 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 60 | if self.counter >= self.patience: 61 | self.early_stop = True 62 | else: 63 | self.best_score = score 64 | if self.use_multi_gpu: 65 | if self.local_rank == 0: 66 | self.save_checkpoint(val_loss, model, path) 67 | dist.barrier() 68 | else: 69 | self.save_checkpoint(val_loss, model, path) 70 | if self.verbose: 71 | if (self.use_multi_gpu and self.local_rank == 0) or not self.use_multi_gpu: 72 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).') 73 | self.val_loss_min = val_loss 74 | self.counter = 0 75 | 76 | 77 | def save_checkpoint(self, val_loss, model, path): 78 | param_grad_dic = { 79 | k: v.requires_grad for (k, v) in model.named_parameters() 80 | } 81 | state_dict = model.state_dict() 82 | for k in list(state_dict.keys()): 83 | if k in param_grad_dic.keys() and not param_grad_dic[k]: 84 | # delete parameters that do not require gradient 85 | del state_dict[k] 86 | torch.save(state_dict, path + '/' + f'checkpoint.pth') 87 | 88 | 89 | 90 | class dotdict(dict): 91 | """dot.notation access to dictionary attributes""" 92 | __getattr__ = dict.get 93 | __setattr__ = dict.__setitem__ 94 | __delattr__ = dict.__delitem__ 95 | 96 | 97 | class StandardScaler(): 98 | def __init__(self, mean, std): 99 | self.mean = mean 100 | self.std = std 101 | 102 | def transform(self, data): 103 | return (data - self.mean) / self.std 104 | 105 | def inverse_transform(self, data): 106 | return (data * self.std) + self.mean 107 | 108 | 109 | def visual(true, preds=None, name='./pic/test.pdf'): 110 | """ 111 | Results visualization 112 | """ 113 | plt.figure() 114 | plt.plot(true, label='GroundTruth', linewidth=2) 115 | if preds is not None: 116 | plt.plot(preds, label='Prediction', linewidth=2) 117 | plt.legend() 118 | plt.savefig(name, bbox_inches='tight') 119 | 120 | 121 | def adjustment(gt, pred): 122 | anomaly_state = False 123 | for i in range(len(gt)): 124 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 125 | anomaly_state = True 126 | for j in range(i, 0, -1): 127 | if gt[j] == 0: 128 | break 129 | else: 130 | if pred[j] == 0: 131 | pred[j] = 1 132 | for j in range(i, len(gt)): 133 | if gt[j] == 0: 134 | break 135 | else: 136 | if pred[j] == 0: 137 | pred[j] = 1 138 | elif gt[i] == 0: 139 | anomaly_state = False 140 | if anomaly_state: 141 | pred[i] = 1 142 | return gt, pred 143 | 144 | 145 | def cal_accuracy(y_pred, y_true): 146 | return np.mean(y_pred == y_true) 147 | 148 | def convert_tsf_to_dataframe( 149 | full_file_path_and_name, 150 | replace_missing_vals_with="NaN", 151 | value_column_name="series_value", 152 | ): 153 | col_names = [] 154 | col_types = [] 155 | all_data = {} 156 | line_count = 0 157 | frequency = None 158 | forecast_horizon = None 159 | contain_missing_values = None 160 | contain_equal_length = None 161 | found_data_tag = False 
162 | found_data_section = False 163 | started_reading_data_section = False 164 | 165 | with open(full_file_path_and_name, "r", encoding="cp1252") as file: 166 | for line in file: 167 | # Strip white space from start/end of line 168 | line = line.strip() 169 | 170 | if line: 171 | if line.startswith("@"): # Read meta-data 172 | if not line.startswith("@data"): 173 | line_content = line.split(" ") 174 | if line.startswith("@attribute"): 175 | if ( 176 | len(line_content) != 3 177 | ): # Attributes have both name and type 178 | raise Exception("Invalid meta-data specification.") 179 | 180 | col_names.append(line_content[1]) 181 | col_types.append(line_content[2]) 182 | else: 183 | if ( 184 | len(line_content) != 2 185 | ): # Other meta-data have only values 186 | raise Exception("Invalid meta-data specification.") 187 | 188 | if line.startswith("@frequency"): 189 | frequency = line_content[1] 190 | elif line.startswith("@horizon"): 191 | forecast_horizon = int(line_content[1]) 192 | elif line.startswith("@missing"): 193 | contain_missing_values = bool( 194 | strtobool(line_content[1]) 195 | ) 196 | elif line.startswith("@equallength"): 197 | contain_equal_length = bool(strtobool(line_content[1])) 198 | 199 | else: 200 | if len(col_names) == 0: 201 | raise Exception( 202 | "Missing attribute section. Attribute section must come before data." 203 | ) 204 | 205 | found_data_tag = True 206 | elif not line.startswith("#"): 207 | if len(col_names) == 0: 208 | raise Exception( 209 | "Missing attribute section. Attribute section must come before data." 210 | ) 211 | elif not found_data_tag: 212 | raise Exception("Missing @data tag.") 213 | else: 214 | if not started_reading_data_section: 215 | started_reading_data_section = True 216 | found_data_section = True 217 | all_series = [] 218 | 219 | for col in col_names: 220 | all_data[col] = [] 221 | 222 | full_info = line.split(":") 223 | 224 | if len(full_info) != (len(col_names) + 1): 225 | raise Exception("Missing attributes/values in series.") 226 | 227 | series = full_info[len(full_info) - 1] 228 | series = series.split(",") 229 | 230 | if len(series) == 0: 231 | raise Exception( 232 | "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol" 233 | ) 234 | 235 | numeric_series = [] 236 | 237 | for val in series: 238 | if val == "?": 239 | numeric_series.append(replace_missing_vals_with) 240 | else: 241 | numeric_series.append(float(val)) 242 | 243 | if numeric_series.count(replace_missing_vals_with) == len( 244 | numeric_series 245 | ): 246 | raise Exception( 247 | "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series." 248 | ) 249 | 250 | all_series.append(pd.Series(numeric_series).array) 251 | 252 | for i in range(len(col_names)): 253 | att_val = None 254 | if col_types[i] == "numeric": 255 | att_val = int(full_info[i]) 256 | elif col_types[i] == "string": 257 | att_val = str(full_info[i]) 258 | elif col_types[i] == "date": 259 | att_val = datetime.strptime( 260 | full_info[i], "%Y-%m-%d %H-%M-%S" 261 | ) 262 | else: 263 | raise Exception( 264 | "Invalid attribute type." 265 | ) # Currently, the code supports only numeric, string and date types. Extend this as required. 
266 | 267 | if att_val is None: 268 | raise Exception("Invalid attribute value.") 269 | else: 270 | all_data[col_names[i]].append(att_val) 271 | 272 | line_count = line_count + 1 273 | 274 | if line_count == 0: 275 | raise Exception("Empty file.") 276 | if len(col_names) == 0: 277 | raise Exception("Missing attribute section.") 278 | if not found_data_section: 279 | raise Exception("Missing series information under data section.") 280 | 281 | all_data[value_column_name] = all_series 282 | loaded_data = pd.DataFrame(all_data) 283 | 284 | return ( 285 | loaded_data, 286 | frequency, 287 | forecast_horizon, 288 | contain_missing_values, 289 | contain_equal_length, 290 | ) 291 | 292 | class dotdict(dict): 293 | """dot.notation access to dictionary attributes""" 294 | __getattr__ = dict.get 295 | __setattr__ = dict.__setitem__ 296 | __delattr__ = dict.__delitem__ --------------------------------------------------------------------------------