├── .gitignore ├── LICENSE ├── README.md ├── data_provider ├── __init__.py ├── data_factory.py └── data_loader.py ├── environment.yml ├── exp ├── __init__.py ├── exp_basic.py └── exp_main.py ├── figures ├── arch.png ├── da.png ├── main_results.png ├── promotion.png ├── showcases.png └── ss.png ├── layers ├── AutoCorrelation.py ├── Autoformer_EncDec.py ├── Embed.py ├── SelfAttention_Family.py ├── Transformer_EncDec.py └── __init__.py ├── models ├── Autoformer.py ├── Informer.py ├── Transformer.py └── __init__.py ├── ns_layers ├── AutoCorrelation.py ├── Autoformer_EncDec.py ├── SelfAttention_Family.py ├── Transformer_EncDec.py └── __init__.py ├── ns_models ├── __init__.py ├── ns_Autoformer.py ├── ns_Informer.py └── ns_Transformer.py ├── requirements.txt ├── run.py ├── scripts ├── ECL_script │ ├── Autoformer.sh │ ├── Informer.sh │ ├── Transformer.sh │ ├── ns_Autoformer.sh │ ├── ns_Informer.sh │ └── ns_Transformer.sh ├── ETT_script │ ├── Autoformer.sh │ ├── Informer.sh │ ├── Transformer.sh │ ├── ns_Autoformer.sh │ ├── ns_Informer.sh │ └── ns_Transformer.sh ├── Exchange_script │ ├── Autoformer.sh │ ├── Informer.sh │ ├── Transformer.sh │ ├── ns_Autoformer.sh │ ├── ns_Informer.sh │ └── ns_Transformer.sh ├── ILI_script │ ├── Autoformer.sh │ ├── Informer.sh │ ├── Transformer.sh │ ├── ns_Autoformer.sh │ ├── ns_Informer.sh │ └── ns_Transformer.sh ├── Traffic_script │ ├── Autoformer.sh │ ├── Informer.sh │ ├── Transformer.sh │ ├── ns_Autoformer.sh │ ├── ns_Informer.sh │ └── ns_Transformer.sh └── Weather_script │ ├── Autoformer.sh │ ├── Informer.sh │ ├── Transformer.sh │ ├── ns_Autoformer.sh │ ├── ns_Informer.sh │ └── ns_Transformer.sh └── utils ├── __init__.py ├── masking.py ├── metrics.py ├── timefeatures.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | */.DS_Store 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Non-stationary Transformers 2 | 3 | This is the codebase for the paper: 4 | [Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting](https://arxiv.org/abs/2205.14415), NeurIPS 2022. [[Slides]](https://cloud.tsinghua.edu.cn/f/8d6ce7b18d3c468190e7/), [[Poster]](https://cloud.tsinghua.edu.cn/f/6eea66909aa7465ca9a4/). 5 | 6 | :triangular_flag_on_post: **News** (2023.02) Non-stationary Transformer has been included in [[Time-Series-Library]](https://github.com/thuml/Time-Series-Library), which covers long- and short-term forecasting, imputation, anomaly detection, and classification. 7 | 8 | ## Discussions 9 | 10 | There are already several discussions about our paper; we greatly appreciate their valuable comments and efforts: [[Official]](https://mp.weixin.qq.com/s/LkpkTiNBVBYA-FqzAdy4dw), [[OpenReview]](https://openreview.net/forum?id=ucNDIDRNjjv), [[Zhihu]](https://zhuanlan.zhihu.com/p/535931701).
11 | 12 | ## Architecture 13 | 14 | ![arch](./figures/arch.png) 15 | 16 | ### Series Stationarization 17 | 18 | Series Stationarization unifies the statistics of each input and converts the output with restored statistics for better predictability. 19 | 20 | ![arch](./figures/ss.png) 21 | 22 | ### De-stationary Attention 23 | 24 | De-stationary Attention is devised to recover the intrinsic non-stationary information into temporal dependencies by approximating distinguishable attentions learned from unstationarized series. 25 | 26 | ![arch](./figures/da.png) 27 | 28 | 29 | ## Showcases 30 | 31 | ![arch](./figures/showcases.png) 32 | 33 | ## Preparation 34 | 35 | 1. Install Python 3.7 and necessary dependencies. 36 | ``` 37 | pip install -r requirements.txt 38 | ``` 39 | 2. All the six benchmark datasets can be obtained from [Google Drive](https://drive.google.com/file/d/1CC4ZrUD4EKncndzgy5PSTzOPSqcuyqqj/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b8f4a78a39874ac9893e/?dl=1). 40 | 41 | ## Training scripts 42 | 43 | ### Non-stationary Transformer 44 | 45 | We provide the Non-stationary Transformer experiment scripts and hyperparameters for all benchmark datasets under the folder `./scripts`. 46 | 47 | ```bash 48 | # Transformer with our framework 49 | bash ./scripts/ECL_script/ns_Transformer.sh 50 | bash ./scripts/Traffic_script/ns_Transformer.sh 51 | bash ./scripts/Weather_script/ns_Transformer.sh 52 | bash ./scripts/ILI_script/ns_Transformer.sh 53 | bash ./scripts/Exchange_script/ns_Transformer.sh 54 | bash ./scripts/ETT_script/ns_Transformer.sh 55 | ``` 56 | 57 | ```bash 58 | # Transformer baseline 59 | bash ./scripts/ECL_script/Transformer.sh 60 | bash ./scripts/Traffic_script/Transformer.sh 61 | bash ./scripts/Weather_script/Transformer.sh 62 | bash ./scripts/ILI_script/Transformer.sh 63 | bash ./scripts/Exchange_script/Transformer.sh 64 | bash ./scripts/ETT_script/Transformer.sh 65 | ``` 66 | 67 | ### Non-stationary framework to promote other Attention-based models 68 | 69 | We also provide scripts for other Attention-based models (Informer, Autoformer), for example: 70 | 71 | ```bash 72 | # Informer promoted by our Non-stationary framework 73 | bash ./scripts/Exchange_script/Informer.sh 74 | bash ./scripts/Exchange_script/ns_Informer.sh 75 | 76 | # Autoformer promoted by our Non-stationary framework 77 | bash ./scripts/Weather_script/Autoformer.sh 78 | bash ./scripts/Weather_script/ns_Autoformer.sh 79 | ``` 80 | 81 | ## Experiment Results 82 | 83 | ### Main Results 84 | 85 | For multivariate forecasting, the vanilla Transformer equipped with our framework consistently achieves state-of-the-art performance on all six benchmarks and across all prediction lengths. 86 | 87 | ![arch](./figures/main_results.png) 88 | 89 | ### Model Promotion 90 | 91 | By applying our framework to six mainstream Attention-based models, our method consistently improves their forecasting ability. Overall, it achieves an average promotion of **49.43%** on Transformer, **47.34%** on Informer, **46.89%** on Reformer, **10.57%** on Autoformer, **5.17%** on ETSformer and **4.51%** on FEDformer, making each of them surpass the previous state-of-the-art. 92 | 93 | ![arch](./figures/promotion.png) 94 | 
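For readers who want the core idea before diving into the code, the following is a minimal, illustrative PyTorch sketch of the Series Stationarization step described above. The function names here are ours; the repository's actual implementation, including the projectors that estimate the de-stationary factors used by De-stationary Attention, lives in `ns_models/` and `ns_layers/`.

```python
import torch

def stationarize(x_enc, eps=1e-5):
    # Series Stationarization (sketch): normalize each input window by its own statistics.
    # x_enc: [batch, seq_len, n_vars]; the statistics are returned so the output can be restored.
    means = x_enc.mean(dim=1, keepdim=True)
    stdev = torch.sqrt(x_enc.var(dim=1, keepdim=True, unbiased=False) + eps)
    return (x_enc - means) / stdev, means, stdev

def de_stationarize(y_pred, means, stdev):
    # De-stationarization (sketch): convert the model output back with the restored statistics.
    return y_pred * stdev + means
```

De-stationary Attention additionally rescales the attention scores with learned factors (tau and delta) estimated from the unstationarized series, which this sketch omits.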
95 | ## Future Work 96 | 97 | We will keep equipping the following models with our proposed Non-stationary Transformers framework: 98 | 99 | - [x] Transformer 100 | - [x] iTransformer 101 | - [x] Informer 102 | - [x] Autoformer 103 | - [x] FEDformer 104 | - [x] Crossformer 105 | - [x] Reformer 106 | - [x] ...... 107 | 108 | Note: Series Stationarization, as an architecture-free module, has been widely applied to address non-stationarity in time series. Please refer to [time-series-library](https://github.com/thuml/Time-Series-Library/tree/main/models) for the implementation details. 109 | 110 | ## Citation 111 | 112 | If you find this repo useful, please cite our paper. 113 | 114 | ``` 115 | @inproceedings{liu2022non, 116 | title={Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting}, 117 | author={Liu, Yong and Wu, Haixu and Wang, Jianmin and Long, Mingsheng}, 118 | booktitle={Advances in Neural Information Processing Systems}, 119 | year={2022} 120 | } 121 | ``` 122 | 123 | ## Contact 124 | 125 | If you have any questions or want to use the code, please contact liuyong21@mails.tsinghua.edu.cn. 126 | 127 | 128 | ## Acknowledgement 129 | 130 | This repo is built on the [Autoformer repo](https://github.com/thuml/Autoformer); we greatly appreciate the authors for their valuable code and efforts. 131 | -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred 2 | from torch.utils.data import DataLoader 3 | 4 | data_dict = { 5 | 'ETTh1': Dataset_ETT_hour, 6 | 'ETTh2': Dataset_ETT_hour, 7 | 'ETTm1': Dataset_ETT_minute, 8 | 'ETTm2': Dataset_ETT_minute, 9 | 'custom': Dataset_Custom, 10 | } 11 | 12 | 13 | def data_provider(args, flag): 14 | Data = data_dict[args.data] 15 | timeenc = 0 if args.embed != 'timeF' else 1 16 | 17 | if flag == 'test': 18 | shuffle_flag = False 19 | drop_last = False 20 | batch_size = 1 21 | freq = args.freq 22 | elif flag == 'pred': 23 | shuffle_flag = False 24 | drop_last = False 25 | batch_size = 1 26 | freq = args.freq 27 | Data = Dataset_Pred 28 | else: 29 | shuffle_flag = True 30 | drop_last = True 31 | batch_size = args.batch_size 32 | freq = args.freq 33 | 34 | data_set = Data( 35 | root_path=args.root_path, 36 | data_path=args.data_path, 37 | flag=flag, 38 | size=[args.seq_len, args.label_len, args.pred_len], 39 | features=args.features, 40 | target=args.target, 41 | timeenc=timeenc, 42 | freq=freq 43 | ) 44 | print(flag, len(data_set)) 45 | data_loader = DataLoader( 46 | data_set, 47 | batch_size=batch_size, 48 | shuffle=shuffle_flag, 49 | num_workers=args.num_workers, 50 | drop_last=drop_last) 51 | return data_set, data_loader 52 | -------------------------------------------------------------------------------- /data_provider/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | from sklearn.preprocessing import StandardScaler 8 | from utils.timefeatures import
time_features 9 | import warnings 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | 14 | class Dataset_ETT_hour(Dataset): 15 | def __init__(self, root_path, flag='train', size=None, 16 | features='S', data_path='ETTh1.csv', 17 | target='OT', scale=True, timeenc=0, freq='h'): 18 | # size [seq_len, label_len, pred_len] 19 | # info 20 | if size == None: 21 | self.seq_len = 24 * 4 * 4 22 | self.label_len = 24 * 4 23 | self.pred_len = 24 * 4 24 | else: 25 | self.seq_len = size[0] 26 | self.label_len = size[1] 27 | self.pred_len = size[2] 28 | # init 29 | assert flag in ['train', 'test', 'val'] 30 | type_map = {'train': 0, 'val': 1, 'test': 2} 31 | self.set_type = type_map[flag] 32 | 33 | self.features = features 34 | self.target = target 35 | self.scale = scale 36 | self.timeenc = timeenc 37 | self.freq = freq 38 | 39 | self.root_path = root_path 40 | self.data_path = data_path 41 | self.__read_data__() 42 | 43 | def __read_data__(self): 44 | self.scaler = StandardScaler() 45 | df_raw = pd.read_csv(os.path.join(self.root_path, 46 | self.data_path)) 47 | 48 | border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len] 49 | border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24] 50 | border1 = border1s[self.set_type] 51 | border2 = border2s[self.set_type] 52 | 53 | if self.features == 'M' or self.features == 'MS': 54 | cols_data = df_raw.columns[1:] 55 | df_data = df_raw[cols_data] 56 | elif self.features == 'S': 57 | df_data = df_raw[[self.target]] 58 | 59 | if self.scale: 60 | train_data = df_data[border1s[0]:border2s[0]] 61 | self.scaler.fit(train_data.values) 62 | data = self.scaler.transform(df_data.values) 63 | else: 64 | data = df_data.values 65 | 66 | df_stamp = df_raw[['date']][border1:border2] 67 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 68 | if self.timeenc == 0: 69 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 70 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 71 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 72 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 73 | data_stamp = df_stamp.drop(['date'], 1).values 74 | elif self.timeenc == 1: 75 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 76 | data_stamp = data_stamp.transpose(1, 0) 77 | 78 | self.data_x = data[border1:border2] 79 | self.data_y = data[border1:border2] 80 | self.data_stamp = data_stamp 81 | 82 | def __getitem__(self, index): 83 | s_begin = index 84 | s_end = s_begin + self.seq_len 85 | r_begin = s_end - self.label_len 86 | r_end = r_begin + self.label_len + self.pred_len 87 | 88 | seq_x = self.data_x[s_begin:s_end] 89 | seq_y = self.data_y[r_begin:r_end] 90 | seq_x_mark = self.data_stamp[s_begin:s_end] 91 | seq_y_mark = self.data_stamp[r_begin:r_end] 92 | 93 | return seq_x, seq_y, seq_x_mark, seq_y_mark 94 | 95 | def __len__(self): 96 | return len(self.data_x) - self.seq_len - self.pred_len + 1 97 | 98 | def inverse_transform(self, data): 99 | return self.scaler.inverse_transform(data) 100 | 101 | 102 | class Dataset_ETT_minute(Dataset): 103 | def __init__(self, root_path, flag='train', size=None, 104 | features='S', data_path='ETTm1.csv', 105 | target='OT', scale=True, timeenc=0, freq='t'): 106 | # size [seq_len, label_len, pred_len] 107 | # info 108 | if size == None: 109 | self.seq_len = 24 * 4 * 4 110 | self.label_len = 24 * 4 111 | self.pred_len = 24 * 4 112 | else: 113 | self.seq_len = size[0] 114 | self.label_len = size[1] 115 | 
self.pred_len = size[2] 116 | # init 117 | assert flag in ['train', 'test', 'val'] 118 | type_map = {'train': 0, 'val': 1, 'test': 2} 119 | self.set_type = type_map[flag] 120 | 121 | self.features = features 122 | self.target = target 123 | self.scale = scale 124 | self.timeenc = timeenc 125 | self.freq = freq 126 | 127 | self.root_path = root_path 128 | self.data_path = data_path 129 | self.__read_data__() 130 | 131 | def __read_data__(self): 132 | self.scaler = StandardScaler() 133 | df_raw = pd.read_csv(os.path.join(self.root_path, 134 | self.data_path)) 135 | 136 | border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len] 137 | border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4] 138 | border1 = border1s[self.set_type] 139 | border2 = border2s[self.set_type] 140 | 141 | if self.features == 'M' or self.features == 'MS': 142 | cols_data = df_raw.columns[1:] 143 | df_data = df_raw[cols_data] 144 | elif self.features == 'S': 145 | df_data = df_raw[[self.target]] 146 | 147 | if self.scale: 148 | train_data = df_data[border1s[0]:border2s[0]] 149 | self.scaler.fit(train_data.values) 150 | data = self.scaler.transform(df_data.values) 151 | else: 152 | data = df_data.values 153 | 154 | df_stamp = df_raw[['date']][border1:border2] 155 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 156 | if self.timeenc == 0: 157 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 158 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 159 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 160 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 161 | df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1) 162 | df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15) 163 | data_stamp = df_stamp.drop(['date'], 1).values 164 | elif self.timeenc == 1: 165 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 166 | data_stamp = data_stamp.transpose(1, 0) 167 | 168 | self.data_x = data[border1:border2] 169 | self.data_y = data[border1:border2] 170 | self.data_stamp = data_stamp 171 | 172 | def __getitem__(self, index): 173 | s_begin = index 174 | s_end = s_begin + self.seq_len 175 | r_begin = s_end - self.label_len 176 | r_end = r_begin + self.label_len + self.pred_len 177 | 178 | seq_x = self.data_x[s_begin:s_end] 179 | seq_y = self.data_y[r_begin:r_end] 180 | seq_x_mark = self.data_stamp[s_begin:s_end] 181 | seq_y_mark = self.data_stamp[r_begin:r_end] 182 | 183 | return seq_x, seq_y, seq_x_mark, seq_y_mark 184 | 185 | def __len__(self): 186 | return len(self.data_x) - self.seq_len - self.pred_len + 1 187 | 188 | def inverse_transform(self, data): 189 | return self.scaler.inverse_transform(data) 190 | 191 | 192 | class Dataset_Custom(Dataset): 193 | def __init__(self, root_path, flag='train', size=None, 194 | features='S', data_path='ETTh1.csv', 195 | target='OT', scale=True, timeenc=0, freq='h'): 196 | # size [seq_len, label_len, pred_len] 197 | # info 198 | if size == None: 199 | self.seq_len = 24 * 4 * 4 200 | self.label_len = 24 * 4 201 | self.pred_len = 24 * 4 202 | else: 203 | self.seq_len = size[0] 204 | self.label_len = size[1] 205 | self.pred_len = size[2] 206 | # init 207 | assert flag in ['train', 'test', 'val'] 208 | type_map = {'train': 0, 'val': 1, 'test': 2} 209 | self.set_type = type_map[flag] 210 | 211 | self.features = features 212 | self.target = target 213 | self.scale = scale 214 | self.timeenc 
= timeenc 215 | self.freq = freq 216 | 217 | self.root_path = root_path 218 | self.data_path = data_path 219 | self.__read_data__() 220 | 221 | def __read_data__(self): 222 | self.scaler = StandardScaler() 223 | df_raw = pd.read_csv(os.path.join(self.root_path, 224 | self.data_path)) 225 | 226 | ''' 227 | df_raw.columns: ['date', ...(other features), target feature] 228 | ''' 229 | cols = list(df_raw.columns) 230 | cols.remove(self.target) 231 | cols.remove('date') 232 | df_raw = df_raw[['date'] + cols + [self.target]] 233 | # print(cols) 234 | num_train = int(len(df_raw) * 0.7) 235 | num_test = int(len(df_raw) * 0.2) 236 | num_vali = len(df_raw) - num_train - num_test 237 | border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] 238 | border2s = [num_train, num_train + num_vali, len(df_raw)] 239 | border1 = border1s[self.set_type] 240 | border2 = border2s[self.set_type] 241 | 242 | if self.features == 'M' or self.features == 'MS': 243 | cols_data = df_raw.columns[1:] 244 | df_data = df_raw[cols_data] 245 | elif self.features == 'S': 246 | df_data = df_raw[[self.target]] 247 | 248 | if self.scale: 249 | train_data = df_data[border1s[0]:border2s[0]] 250 | self.scaler.fit(train_data.values) 251 | data = self.scaler.transform(df_data.values) 252 | else: 253 | data = df_data.values 254 | 255 | df_stamp = df_raw[['date']][border1:border2] 256 | df_stamp['date'] = pd.to_datetime(df_stamp.date) 257 | if self.timeenc == 0: 258 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 259 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 260 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 261 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 262 | data_stamp = df_stamp.drop(['date'], 1).values 263 | elif self.timeenc == 1: 264 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 265 | data_stamp = data_stamp.transpose(1, 0) 266 | 267 | self.data_x = data[border1:border2] 268 | self.data_y = data[border1:border2] 269 | self.data_stamp = data_stamp 270 | 271 | def __getitem__(self, index): 272 | s_begin = index 273 | s_end = s_begin + self.seq_len 274 | r_begin = s_end - self.label_len 275 | r_end = r_begin + self.label_len + self.pred_len 276 | 277 | seq_x = self.data_x[s_begin:s_end] 278 | seq_y = self.data_y[r_begin:r_end] 279 | seq_x_mark = self.data_stamp[s_begin:s_end] 280 | seq_y_mark = self.data_stamp[r_begin:r_end] 281 | 282 | return seq_x, seq_y, seq_x_mark, seq_y_mark 283 | 284 | def __len__(self): 285 | return len(self.data_x) - self.seq_len - self.pred_len + 1 286 | 287 | def inverse_transform(self, data): 288 | return self.scaler.inverse_transform(data) 289 | 290 | 291 | class Dataset_Pred(Dataset): 292 | def __init__(self, root_path, flag='pred', size=None, 293 | features='S', data_path='ETTh1.csv', 294 | target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None): 295 | # size [seq_len, label_len, pred_len] 296 | # info 297 | if size == None: 298 | self.seq_len = 24 * 4 * 4 299 | self.label_len = 24 * 4 300 | self.pred_len = 24 * 4 301 | else: 302 | self.seq_len = size[0] 303 | self.label_len = size[1] 304 | self.pred_len = size[2] 305 | # init 306 | assert flag in ['pred'] 307 | 308 | self.features = features 309 | self.target = target 310 | self.scale = scale 311 | self.inverse = inverse 312 | self.timeenc = timeenc 313 | self.freq = freq 314 | self.cols = cols 315 | self.root_path = root_path 316 | self.data_path = data_path 317 | 
self.__read_data__() 318 | 319 | def __read_data__(self): 320 | self.scaler = StandardScaler() 321 | df_raw = pd.read_csv(os.path.join(self.root_path, 322 | self.data_path)) 323 | ''' 324 | df_raw.columns: ['date', ...(other features), target feature] 325 | ''' 326 | if self.cols: 327 | cols = self.cols.copy() 328 | cols.remove(self.target) 329 | else: 330 | cols = list(df_raw.columns) 331 | cols.remove(self.target) 332 | cols.remove('date') 333 | df_raw = df_raw[['date'] + cols + [self.target]] 334 | border1 = len(df_raw) - self.seq_len 335 | border2 = len(df_raw) 336 | 337 | if self.features == 'M' or self.features == 'MS': 338 | cols_data = df_raw.columns[1:] 339 | df_data = df_raw[cols_data] 340 | elif self.features == 'S': 341 | df_data = df_raw[[self.target]] 342 | 343 | if self.scale: 344 | self.scaler.fit(df_data.values) 345 | data = self.scaler.transform(df_data.values) 346 | else: 347 | data = df_data.values 348 | 349 | tmp_stamp = df_raw[['date']][border1:border2] 350 | tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date) 351 | pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq) 352 | 353 | df_stamp = pd.DataFrame(columns=['date']) 354 | df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:]) 355 | if self.timeenc == 0: 356 | df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) 357 | df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) 358 | df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) 359 | df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) 360 | df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1) 361 | df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15) 362 | data_stamp = df_stamp.drop(['date'], 1).values 363 | elif self.timeenc == 1: 364 | data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) 365 | data_stamp = data_stamp.transpose(1, 0) 366 | 367 | self.data_x = data[border1:border2] 368 | if self.inverse: 369 | self.data_y = df_data.values[border1:border2] 370 | else: 371 | self.data_y = data[border1:border2] 372 | self.data_stamp = data_stamp 373 | 374 | def __getitem__(self, index): 375 | s_begin = index 376 | s_end = s_begin + self.seq_len 377 | r_begin = s_end - self.label_len 378 | r_end = r_begin + self.label_len + self.pred_len 379 | 380 | seq_x = self.data_x[s_begin:s_end] 381 | if self.inverse: 382 | seq_y = self.data_x[r_begin:r_begin + self.label_len] 383 | else: 384 | seq_y = self.data_y[r_begin:r_begin + self.label_len] 385 | seq_x_mark = self.data_stamp[s_begin:s_end] 386 | seq_y_mark = self.data_stamp[r_begin:r_end] 387 | 388 | return seq_x, seq_y, seq_x_mark, seq_y_mark 389 | 390 | def __len__(self): 391 | return len(self.data_x) - self.seq_len + 1 392 | 393 | def inverse_transform(self, data): 394 | return self.scaler.inverse_transform(data) 395 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nsformer 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - python=3.7 8 | - pip 9 | - matplotlib 10 | - numpy 11 | - pandas 12 | - scikit-learn 13 | - pip: 14 | - torch==1.9.0 15 | -------------------------------------------------------------------------------- /exp/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/exp/__init__.py -------------------------------------------------------------------------------- /exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class Exp_Basic(object): 7 | def __init__(self, args): 8 | self.args = args 9 | self.device = self._acquire_device() 10 | self.model = self._build_model().to(self.device) 11 | 12 | def _build_model(self): 13 | raise NotImplementedError 14 | return None 15 | 16 | def _acquire_device(self): 17 | if self.args.use_gpu: 18 | os.environ["CUDA_VISIBLE_DEVICES"] = str( 19 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 20 | device = torch.device('cuda:{}'.format(self.args.gpu)) 21 | print('Use GPU: cuda:{}'.format(self.args.gpu)) 22 | else: 23 | device = torch.device('cpu') 24 | print('Use CPU') 25 | return device 26 | 27 | def _get_data(self): 28 | pass 29 | 30 | def vali(self): 31 | pass 32 | 33 | def train(self): 34 | pass 35 | 36 | def test(self): 37 | pass 38 | -------------------------------------------------------------------------------- /exp/exp_main.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_factory import data_provider 2 | from exp.exp_basic import Exp_Basic 3 | from models import Transformer, Informer, Autoformer 4 | from ns_models import ns_Transformer, ns_Informer, ns_Autoformer 5 | from utils.tools import EarlyStopping, adjust_learning_rate, visual 6 | from utils.metrics import metric 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch import optim 12 | 13 | import os 14 | import time 15 | 16 | import warnings 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | 20 | warnings.filterwarnings('ignore') 21 | 22 | 23 | class Exp_Main(Exp_Basic): 24 | def __init__(self, args): 25 | super(Exp_Main, self).__init__(args) 26 | 27 | def _build_model(self): 28 | model_dict = { 29 | 'Transformer': Transformer, 30 | 'Informer': Informer, 31 | 'Autoformer': Autoformer, 32 | 'ns_Transformer': ns_Transformer, 33 | 'ns_Informer': ns_Informer, 34 | 'ns_Autoformer': ns_Autoformer, 35 | } 36 | model = model_dict[self.args.model].Model(self.args).float() 37 | 38 | if self.args.use_multi_gpu and self.args.use_gpu: 39 | model = nn.DataParallel(model, device_ids=self.args.device_ids) 40 | return model 41 | 42 | def _get_data(self, flag): 43 | data_set, data_loader = data_provider(self.args, flag) 44 | return data_set, data_loader 45 | 46 | def _select_optimizer(self): 47 | model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate) 48 | return model_optim 49 | 50 | def _select_criterion(self): 51 | criterion = nn.MSELoss() 52 | return criterion 53 | 54 | def vali(self, vali_data, vali_loader, criterion): 55 | total_loss = [] 56 | self.model.eval() 57 | with torch.no_grad(): 58 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader): 59 | batch_x = batch_x.float().to(self.device) 60 | batch_y = batch_y.float() 61 | 62 | batch_x_mark = batch_x_mark.float().to(self.device) 63 | batch_y_mark = batch_y_mark.float().to(self.device) 64 | 65 | # decoder input 66 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 67 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 68 | # encoder - decoder 69 | if 
self.args.use_amp: 70 | with torch.cuda.amp.autocast(): 71 | if self.args.output_attention: 72 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 73 | else: 74 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 75 | else: 76 | if self.args.output_attention: 77 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 78 | else: 79 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 80 | f_dim = -1 if self.args.features == 'MS' else 0 81 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 82 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 83 | 84 | pred = outputs.detach().cpu() 85 | true = batch_y.detach().cpu() 86 | 87 | loss = criterion(pred, true) 88 | 89 | total_loss.append(loss) 90 | total_loss = np.average(total_loss) 91 | self.model.train() 92 | return total_loss 93 | 94 | def train(self, setting): 95 | train_data, train_loader = self._get_data(flag='train') 96 | vali_data, vali_loader = self._get_data(flag='val') 97 | test_data, test_loader = self._get_data(flag='test') 98 | 99 | path = os.path.join(self.args.checkpoints, setting) 100 | if not os.path.exists(path): 101 | os.makedirs(path) 102 | 103 | time_now = time.time() 104 | 105 | train_steps = len(train_loader) 106 | early_stopping = EarlyStopping(patience=self.args.patience, verbose=True) 107 | 108 | model_optim = self._select_optimizer() 109 | criterion = self._select_criterion() 110 | 111 | if self.args.use_amp: 112 | scaler = torch.cuda.amp.GradScaler() 113 | 114 | for epoch in range(self.args.train_epochs): 115 | iter_count = 0 116 | train_loss = [] 117 | 118 | self.model.train() 119 | epoch_time = time.time() 120 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): 121 | iter_count += 1 122 | model_optim.zero_grad() 123 | batch_x = batch_x.float().to(self.device) 124 | 125 | batch_y = batch_y.float().to(self.device) 126 | batch_x_mark = batch_x_mark.float().to(self.device) 127 | batch_y_mark = batch_y_mark.float().to(self.device) 128 | 129 | # decoder input 130 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 131 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 132 | 133 | # encoder - decoder 134 | if self.args.use_amp: 135 | with torch.cuda.amp.autocast(): 136 | if self.args.output_attention: 137 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 138 | else: 139 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 140 | 141 | f_dim = -1 if self.args.features == 'MS' else 0 142 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 143 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 144 | loss = criterion(outputs, batch_y) 145 | train_loss.append(loss.item()) 146 | else: 147 | if self.args.output_attention: 148 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 149 | else: 150 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 151 | 152 | f_dim = -1 if self.args.features == 'MS' else 0 153 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 154 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 155 | loss = criterion(outputs, batch_y) 156 | train_loss.append(loss.item()) 157 | 158 | if (i + 1) % 100 == 0: 159 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) 160 | speed = (time.time() - time_now) / iter_count 161 | left_time = speed * ((self.args.train_epochs - epoch) * 
train_steps - i) 162 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) 163 | iter_count = 0 164 | time_now = time.time() 165 | 166 | if self.args.use_amp: 167 | scaler.scale(loss).backward() 168 | scaler.step(model_optim) 169 | scaler.update() 170 | else: 171 | loss.backward() 172 | model_optim.step() 173 | 174 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) 175 | train_loss = np.average(train_loss) 176 | vali_loss = self.vali(vali_data, vali_loader, criterion) 177 | test_loss = self.vali(test_data, test_loader, criterion) 178 | 179 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format( 180 | epoch + 1, train_steps, train_loss, vali_loss, test_loss)) 181 | early_stopping(vali_loss, self.model, path) 182 | if early_stopping.early_stop: 183 | print("Early stopping") 184 | break 185 | 186 | adjust_learning_rate(model_optim, epoch + 1, self.args) 187 | 188 | best_model_path = path + '/' + 'checkpoint.pth' 189 | self.model.load_state_dict(torch.load(best_model_path)) 190 | 191 | return self.model 192 | 193 | def test(self, setting, test=0): 194 | test_data, test_loader = self._get_data(flag='test') 195 | if test: 196 | print('loading model') 197 | self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth'))) 198 | 199 | preds = [] 200 | trues = [] 201 | folder_path = './test_results/' + setting + '/' 202 | if not os.path.exists(folder_path): 203 | os.makedirs(folder_path) 204 | 205 | self.model.eval() 206 | with torch.no_grad(): 207 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader): 208 | batch_x = batch_x.float().to(self.device) 209 | batch_y = batch_y.float().to(self.device) 210 | 211 | batch_x_mark = batch_x_mark.float().to(self.device) 212 | batch_y_mark = batch_y_mark.float().to(self.device) 213 | 214 | # decoder input 215 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() 216 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 217 | # encoder - decoder 218 | if self.args.use_amp: 219 | with torch.cuda.amp.autocast(): 220 | if self.args.output_attention: 221 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 222 | else: 223 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 224 | else: 225 | if self.args.output_attention: 226 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 227 | 228 | else: 229 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 230 | 231 | f_dim = -1 if self.args.features == 'MS' else 0 232 | outputs = outputs[:, -self.args.pred_len:, f_dim:] 233 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) 234 | outputs = outputs.detach().cpu().numpy() 235 | batch_y = batch_y.detach().cpu().numpy() 236 | 237 | pred = outputs # outputs.detach().cpu().numpy() # .squeeze() 238 | true = batch_y # batch_y.detach().cpu().numpy() # .squeeze() 239 | 240 | preds.append(pred) 241 | trues.append(true) 242 | if i % 20 == 0: 243 | input = batch_x.detach().cpu().numpy() 244 | gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0) 245 | pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0) 246 | visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf')) 247 | 248 | preds = np.array(preds) 249 | trues = np.array(trues) 250 | print('test shape:', preds.shape, trues.shape) 251 | preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1]) 252 | 
trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1]) 253 | print('test shape:', preds.shape, trues.shape) 254 | 255 | # result save 256 | folder_path = './results/' + setting + '/' 257 | if not os.path.exists(folder_path): 258 | os.makedirs(folder_path) 259 | 260 | mae, mse, rmse, mape, mspe = metric(preds, trues) 261 | print('mse:{}, mae:{}'.format(mse, mae)) 262 | f = open("result.txt", 'a') 263 | f.write(setting + " \n") 264 | f.write('mse:{}, mae:{}'.format(mse, mae)) 265 | f.write('\n') 266 | f.write('\n') 267 | f.close() 268 | 269 | np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe])) 270 | np.save(folder_path + 'pred.npy', preds) 271 | np.save(folder_path + 'true.npy', trues) 272 | 273 | return 274 | 275 | def predict(self, setting, load=False): 276 | pred_data, pred_loader = self._get_data(flag='pred') 277 | 278 | if load: 279 | path = os.path.join(self.args.checkpoints, setting) 280 | best_model_path = path + '/' + 'checkpoint.pth' 281 | self.model.load_state_dict(torch.load(best_model_path)) 282 | 283 | preds = [] 284 | 285 | self.model.eval() 286 | with torch.no_grad(): 287 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader): 288 | batch_x = batch_x.float().to(self.device) 289 | batch_y = batch_y.float() 290 | batch_x_mark = batch_x_mark.float().to(self.device) 291 | batch_y_mark = batch_y_mark.float().to(self.device) 292 | 293 | # decoder input 294 | dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float() 295 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) 296 | # encoder - decoder 297 | if self.args.use_amp: 298 | with torch.cuda.amp.autocast(): 299 | if self.args.output_attention: 300 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 301 | else: 302 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 303 | else: 304 | if self.args.output_attention: 305 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] 306 | else: 307 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) 308 | pred = outputs.detach().cpu().numpy() # .squeeze() 309 | preds.append(pred) 310 | 311 | preds = np.array(preds) 312 | preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1]) 313 | 314 | # result save 315 | folder_path = './results/' + setting + '/' 316 | if not os.path.exists(folder_path): 317 | os.makedirs(folder_path) 318 | 319 | np.save(folder_path + 'real_prediction.npy', preds) 320 | 321 | return 322 | -------------------------------------------------------------------------------- /figures/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/figures/arch.png -------------------------------------------------------------------------------- /figures/da.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/figures/da.png -------------------------------------------------------------------------------- /figures/main_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/figures/main_results.png 
-------------------------------------------------------------------------------- /figures/promotion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/figures/promotion.png -------------------------------------------------------------------------------- /figures/showcases.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/figures/showcases.png -------------------------------------------------------------------------------- /figures/ss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/figures/ss.png -------------------------------------------------------------------------------- /layers/AutoCorrelation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class AutoCorrelation(nn.Module): 7 | """ 8 | AutoCorrelation Mechanism with the following two phases: 9 | (1) period-based dependencies discovery 10 | (2) time delay aggregation 11 | This block can replace the self-attention family mechanism seamlessly. 12 | """ 13 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): 14 | super(AutoCorrelation, self).__init__() 15 | self.factor = factor 16 | self.scale = scale 17 | self.mask_flag = mask_flag 18 | self.output_attention = output_attention 19 | self.dropout = nn.Dropout(attention_dropout) 20 | 21 | def time_delay_agg_training(self, values, corr): 22 | """ 23 | SpeedUp version of Autocorrelation (a batch-normalization style design) 24 | This is for the training phase. 25 | """ 26 | head = values.shape[1] 27 | channel = values.shape[2] 28 | length = values.shape[3] 29 | # find top k 30 | top_k = int(self.factor * math.log(length)) 31 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 32 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] 33 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) 34 | # update corr 35 | tmp_corr = torch.softmax(weights, dim=-1) 36 | # aggregation 37 | tmp_values = values 38 | delays_agg = torch.zeros_like(values).float() 39 | for i in range(top_k): 40 | pattern = torch.roll(tmp_values, -int(index[i]), -1) 41 | delays_agg = delays_agg + pattern * \ 42 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 43 | return delays_agg 44 | 45 | def time_delay_agg_inference(self, values, corr): 46 | """ 47 | SpeedUp version of Autocorrelation (a batch-normalization style design) 48 | This is for the inference phase. 
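Compared with the training version, the top-k delays here are selected per sample (not shared across the batch), and the delayed series is built with torch.gather on a doubled value tensor instead of torch.roll.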
49 | """ 50 | batch = values.shape[0] 51 | head = values.shape[1] 52 | channel = values.shape[2] 53 | length = values.shape[3] 54 | # index init 55 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\ 56 | .repeat(batch, head, channel, 1).to(values.device) 57 | # find top k 58 | top_k = int(self.factor * math.log(length)) 59 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 60 | weights, delay = torch.topk(mean_value, top_k, dim=-1) 61 | # update corr 62 | tmp_corr = torch.softmax(weights, dim=-1) 63 | # aggregation 64 | tmp_values = values.repeat(1, 1, 1, 2) 65 | delays_agg = torch.zeros_like(values).float() 66 | for i in range(top_k): 67 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 68 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 69 | delays_agg = delays_agg + pattern * \ 70 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 71 | return delays_agg 72 | 73 | def time_delay_agg_full(self, values, corr): 74 | """ 75 | Standard version of Autocorrelation 76 | """ 77 | batch = values.shape[0] 78 | head = values.shape[1] 79 | channel = values.shape[2] 80 | length = values.shape[3] 81 | # index init 82 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\ 83 | .repeat(batch, head, channel, 1).to(values.device) 84 | # find top k 85 | top_k = int(self.factor * math.log(length)) 86 | weights, delay = torch.topk(corr, top_k, dim=-1) 87 | # update corr 88 | tmp_corr = torch.softmax(weights, dim=-1) 89 | # aggregation 90 | tmp_values = values.repeat(1, 1, 1, 2) 91 | delays_agg = torch.zeros_like(values).float() 92 | for i in range(top_k): 93 | tmp_delay = init_index + delay[..., i].unsqueeze(-1) 94 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 95 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) 96 | return delays_agg 97 | 98 | def forward(self, queries, keys, values, attn_mask): 99 | B, L, H, E = queries.shape 100 | _, S, _, D = values.shape 101 | if L > S: 102 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float() 103 | values = torch.cat([values, zeros], dim=1) 104 | keys = torch.cat([keys, zeros], dim=1) 105 | else: 106 | values = values[:, :L, :, :] 107 | keys = keys[:, :L, :, :] 108 | 109 | # period-based dependencies 110 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) 111 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) 112 | res = q_fft * torch.conj(k_fft) 113 | corr = torch.fft.irfft(res, dim=-1) 114 | 115 | # time delay agg 116 | if self.training: 117 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 118 | else: 119 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 120 | 121 | if self.output_attention: 122 | return (V.contiguous(), corr.permute(0, 3, 1, 2)) 123 | else: 124 | return (V.contiguous(), None) 125 | 126 | 127 | class AutoCorrelationLayer(nn.Module): 128 | def __init__(self, correlation, d_model, n_heads, d_keys=None, 129 | d_values=None): 130 | super(AutoCorrelationLayer, self).__init__() 131 | 132 | d_keys = d_keys or (d_model // n_heads) 133 | d_values = d_values or (d_model // n_heads) 134 | 135 | self.inner_correlation = correlation 136 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 137 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 138 | self.value_projection = nn.Linear(d_model, 
d_values * n_heads) 139 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 140 | self.n_heads = n_heads 141 | 142 | def forward(self, queries, keys, values, attn_mask): 143 | B, L, _ = queries.shape 144 | _, S, _ = keys.shape 145 | H = self.n_heads 146 | 147 | queries = self.query_projection(queries).view(B, L, H, -1) 148 | keys = self.key_projection(keys).view(B, S, H, -1) 149 | values = self.value_projection(values).view(B, S, H, -1) 150 | 151 | out, attn = self.inner_correlation( 152 | queries, 153 | keys, 154 | values, 155 | attn_mask 156 | ) 157 | out = out.view(B, L, -1) 158 | 159 | return self.out_projection(out), attn 160 | -------------------------------------------------------------------------------- /layers/Autoformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class my_Layernorm(nn.Module): 7 | """ 8 | Special designed layernorm for the seasonal part 9 | """ 10 | def __init__(self, channels): 11 | super(my_Layernorm, self).__init__() 12 | self.layernorm = nn.LayerNorm(channels) 13 | 14 | def forward(self, x): 15 | x_hat = self.layernorm(x) 16 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) 17 | return x_hat - bias 18 | 19 | 20 | class moving_avg(nn.Module): 21 | """ 22 | Moving average block to highlight the trend of time series 23 | """ 24 | def __init__(self, kernel_size, stride): 25 | super(moving_avg, self).__init__() 26 | self.kernel_size = kernel_size 27 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 28 | 29 | def forward(self, x): 30 | # padding on the both ends of time series 31 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 32 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 33 | x = torch.cat([front, x, end], dim=1) 34 | x = self.avg(x.permute(0, 2, 1)) 35 | x = x.permute(0, 2, 1) 36 | return x 37 | 38 | 39 | class series_decomp(nn.Module): 40 | """ 41 | Series decomposition block 42 | """ 43 | def __init__(self, kernel_size): 44 | super(series_decomp, self).__init__() 45 | self.moving_avg = moving_avg(kernel_size, stride=1) 46 | 47 | def forward(self, x): 48 | moving_mean = self.moving_avg(x) 49 | res = x - moving_mean 50 | return res, moving_mean 51 | 52 | 53 | class EncoderLayer(nn.Module): 54 | """ 55 | Autoformer encoder layer with the progressive decomposition architecture 56 | """ 57 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): 58 | super(EncoderLayer, self).__init__() 59 | d_ff = d_ff or 4 * d_model 60 | self.attention = attention 61 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 62 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 63 | self.decomp1 = series_decomp(moving_avg) 64 | self.decomp2 = series_decomp(moving_avg) 65 | self.dropout = nn.Dropout(dropout) 66 | self.activation = F.relu if activation == "relu" else F.gelu 67 | 68 | def forward(self, x, attn_mask=None): 69 | new_x, attn = self.attention( 70 | x, x, x, 71 | attn_mask=attn_mask 72 | ) 73 | x = x + self.dropout(new_x) 74 | x, _ = self.decomp1(x) 75 | y = x 76 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 77 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 78 | res, _ = self.decomp2(x + y) 79 | return res, attn 80 | 81 | 82 | class Encoder(nn.Module): 83 | """ 84 | Autoformer encoder 85 | """ 86 | 
def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 87 | super(Encoder, self).__init__() 88 | self.attn_layers = nn.ModuleList(attn_layers) 89 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 90 | self.norm = norm_layer 91 | 92 | def forward(self, x, attn_mask=None): 93 | attns = [] 94 | if self.conv_layers is not None: 95 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 96 | x, attn = attn_layer(x, attn_mask=attn_mask) 97 | x = conv_layer(x) 98 | attns.append(attn) 99 | x, attn = self.attn_layers[-1](x) 100 | attns.append(attn) 101 | else: 102 | for attn_layer in self.attn_layers: 103 | x, attn = attn_layer(x, attn_mask=attn_mask) 104 | attns.append(attn) 105 | 106 | if self.norm is not None: 107 | x = self.norm(x) 108 | 109 | return x, attns 110 | 111 | 112 | class DecoderLayer(nn.Module): 113 | """ 114 | Autoformer decoder layer with the progressive decomposition architecture 115 | """ 116 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None, 117 | moving_avg=25, dropout=0.1, activation="relu"): 118 | super(DecoderLayer, self).__init__() 119 | d_ff = d_ff or 4 * d_model 120 | self.self_attention = self_attention 121 | self.cross_attention = cross_attention 122 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 123 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 124 | self.decomp1 = series_decomp(moving_avg) 125 | self.decomp2 = series_decomp(moving_avg) 126 | self.decomp3 = series_decomp(moving_avg) 127 | self.dropout = nn.Dropout(dropout) 128 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1, 129 | padding_mode='circular', bias=False) 130 | self.activation = F.relu if activation == "relu" else F.gelu 131 | 132 | def forward(self, x, cross, x_mask=None, cross_mask=None): 133 | x = x + self.dropout(self.self_attention( 134 | x, x, x, 135 | attn_mask=x_mask 136 | )[0]) 137 | x, trend1 = self.decomp1(x) 138 | x = x + self.dropout(self.cross_attention( 139 | x, cross, cross, 140 | attn_mask=cross_mask 141 | )[0]) 142 | x, trend2 = self.decomp2(x) 143 | y = x 144 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 145 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 146 | x, trend3 = self.decomp3(x + y) 147 | 148 | residual_trend = trend1 + trend2 + trend3 149 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) 150 | return x, residual_trend 151 | 152 | 153 | class Decoder(nn.Module): 154 | """ 155 | Autoformer encoder 156 | """ 157 | def __init__(self, layers, norm_layer=None, projection=None): 158 | super(Decoder, self).__init__() 159 | self.layers = nn.ModuleList(layers) 160 | self.norm = norm_layer 161 | self.projection = projection 162 | 163 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): 164 | for layer in self.layers: 165 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 166 | trend = trend + residual_trend 167 | 168 | if self.norm is not None: 169 | x = self.norm(x) 170 | 171 | if self.projection is not None: 172 | x = self.projection(x) 173 | return x, trend 174 | -------------------------------------------------------------------------------- /layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class 
PositionalEmbedding(nn.Module): 7 | def __init__(self, d_model, max_len=5000): 8 | super(PositionalEmbedding, self).__init__() 9 | # Compute the positional encodings once in log space. 10 | pe = torch.zeros(max_len, d_model).float() 11 | pe.require_grad = False 12 | 13 | position = torch.arange(0, max_len).float().unsqueeze(1) 14 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 15 | 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | 19 | pe = pe.unsqueeze(0) 20 | self.register_buffer('pe', pe) 21 | 22 | def forward(self, x): 23 | return self.pe[:, :x.size(1)] 24 | 25 | 26 | class TokenEmbedding(nn.Module): 27 | def __init__(self, c_in, d_model): 28 | super(TokenEmbedding, self).__init__() 29 | padding = 1 if torch.__version__ >= '1.5.0' else 2 30 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 31 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv1d): 34 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') 35 | 36 | def forward(self, x): 37 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 38 | return x 39 | 40 | 41 | class FixedEmbedding(nn.Module): 42 | def __init__(self, c_in, d_model): 43 | super(FixedEmbedding, self).__init__() 44 | 45 | w = torch.zeros(c_in, d_model).float() 46 | w.require_grad = False 47 | 48 | position = torch.arange(0, c_in).float().unsqueeze(1) 49 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 50 | 51 | w[:, 0::2] = torch.sin(position * div_term) 52 | w[:, 1::2] = torch.cos(position * div_term) 53 | 54 | self.emb = nn.Embedding(c_in, d_model) 55 | self.emb.weight = nn.Parameter(w, requires_grad=False) 56 | 57 | def forward(self, x): 58 | return self.emb(x).detach() 59 | 60 | 61 | class TemporalEmbedding(nn.Module): 62 | def __init__(self, d_model, embed_type='fixed', freq='h'): 63 | super(TemporalEmbedding, self).__init__() 64 | 65 | minute_size = 4 66 | hour_size = 24 67 | weekday_size = 7 68 | day_size = 32 69 | month_size = 13 70 | 71 | Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding 72 | if freq == 't': 73 | self.minute_embed = Embed(minute_size, d_model) 74 | self.hour_embed = Embed(hour_size, d_model) 75 | self.weekday_embed = Embed(weekday_size, d_model) 76 | self.day_embed = Embed(day_size, d_model) 77 | self.month_embed = Embed(month_size, d_model) 78 | 79 | def forward(self, x): 80 | x = x.long() 81 | 82 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0. 
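# note: 'minute_embed' is only created when freq == 't' (see __init__ above), so other frequencies contribute 0 for the minute term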
83 | hour_x = self.hour_embed(x[:, :, 3]) 84 | weekday_x = self.weekday_embed(x[:, :, 2]) 85 | day_x = self.day_embed(x[:, :, 1]) 86 | month_x = self.month_embed(x[:, :, 0]) 87 | 88 | return hour_x + weekday_x + day_x + month_x + minute_x 89 | 90 | 91 | class TimeFeatureEmbedding(nn.Module): 92 | def __init__(self, d_model, embed_type='timeF', freq='h'): 93 | super(TimeFeatureEmbedding, self).__init__() 94 | 95 | freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 96 | d_inp = freq_map[freq] 97 | self.embed = nn.Linear(d_inp, d_model, bias=False) 98 | 99 | def forward(self, x): 100 | return self.embed(x) 101 | 102 | 103 | class DataEmbedding(nn.Module): 104 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 105 | super(DataEmbedding, self).__init__() 106 | 107 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 108 | self.position_embedding = PositionalEmbedding(d_model=d_model) 109 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 110 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 111 | d_model=d_model, embed_type=embed_type, freq=freq) 112 | self.dropout = nn.Dropout(p=dropout) 113 | 114 | def forward(self, x, x_mark): 115 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 116 | return self.dropout(x) 117 | 118 | 119 | class DataEmbedding_wo_pos(nn.Module): 120 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 121 | super(DataEmbedding_wo_pos, self).__init__() 122 | 123 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 124 | self.position_embedding = PositionalEmbedding(d_model=d_model) 125 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 126 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 127 | d_model=d_model, embed_type=embed_type, freq=freq) 128 | self.dropout = nn.Dropout(p=dropout) 129 | 130 | def forward(self, x, x_mark): 131 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 132 | return self.dropout(x) 133 | -------------------------------------------------------------------------------- /layers/SelfAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from math import sqrt 5 | from utils.masking import TriangularCausalMask, ProbMask 6 | 7 | 8 | class FullAttention(nn.Module): 9 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 10 | super(FullAttention, self).__init__() 11 | self.scale = scale 12 | self.mask_flag = mask_flag 13 | self.output_attention = output_attention 14 | self.dropout = nn.Dropout(attention_dropout) 15 | 16 | def forward(self, queries, keys, values, attn_mask): 17 | B, L, H, E = queries.shape 18 | _, S, _, D = values.shape 19 | scale = self.scale or 1. 
/ sqrt(E) 20 | 21 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 22 | 23 | if self.mask_flag: 24 | if attn_mask is None: 25 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 26 | 27 | scores.masked_fill_(attn_mask.mask, -np.inf) 28 | 29 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 30 | V = torch.einsum("bhls,bshd->blhd", A, values) 31 | 32 | if self.output_attention: 33 | return (V.contiguous(), A) 34 | else: 35 | return (V.contiguous(), None) 36 | 37 | 38 | class ProbAttention(nn.Module): 39 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 40 | super(ProbAttention, self).__init__() 41 | self.factor = factor 42 | self.scale = scale 43 | self.mask_flag = mask_flag 44 | self.output_attention = output_attention 45 | self.dropout = nn.Dropout(attention_dropout) 46 | 47 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 48 | # Q [B, H, L, D] 49 | B, H, L_K, E = K.shape 50 | _, _, L_Q, _ = Q.shape 51 | 52 | # calculate the sampled Q_K 53 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 54 | index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q 55 | K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] 56 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 57 | 58 | # find the Top_k query with sparisty measurement 59 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 60 | M_top = M.topk(n_top, sorted=False)[1] 61 | 62 | # use the reduced Q to calculate Q_K 63 | Q_reduce = Q[torch.arange(B)[:, None, None], 64 | torch.arange(H)[None, :, None], 65 | M_top, :] # factor*ln(L_q) 66 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 67 | 68 | return Q_K, M_top 69 | 70 | def _get_initial_context(self, V, L_Q): 71 | B, H, L_V, D = V.shape 72 | if not self.mask_flag: 73 | # V_sum = V.sum(dim=-2) 74 | V_sum = V.mean(dim=-2) 75 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() 76 | else: # use mask 77 | assert (L_Q == L_V) # requires that L_Q == L_V, i.e. 
for self-attention only 78 | contex = V.cumsum(dim=-2) 79 | return contex 80 | 81 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 82 | B, H, L_V, D = V.shape 83 | 84 | if self.mask_flag: 85 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 86 | scores.masked_fill_(attn_mask.mask, -np.inf) 87 | 88 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 89 | 90 | context_in[torch.arange(B)[:, None, None], 91 | torch.arange(H)[None, :, None], 92 | index, :] = torch.matmul(attn, V).type_as(context_in) 93 | if self.output_attention: 94 | attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) 95 | attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn 96 | return (context_in, attns) 97 | else: 98 | return (context_in, None) 99 | 100 | def forward(self, queries, keys, values, attn_mask): 101 | B, L_Q, H, D = queries.shape 102 | _, L_K, _, _ = keys.shape 103 | 104 | queries = queries.transpose(2, 1) 105 | keys = keys.transpose(2, 1) 106 | values = values.transpose(2, 1) 107 | 108 | U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 109 | u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 110 | 111 | U_part = U_part if U_part < L_K else L_K 112 | u = u if u < L_Q else L_Q 113 | 114 | scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 115 | 116 | # add scale factor 117 | scale = self.scale or 1. / sqrt(D) 118 | if scale is not None: 119 | scores_top = scores_top * scale 120 | # get the context 121 | context = self._get_initial_context(values, L_Q) 122 | # update the context with selected top_k queries 123 | context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) 124 | 125 | return context.contiguous(), attn 126 | 127 | 128 | class AttentionLayer(nn.Module): 129 | def __init__(self, attention, d_model, n_heads, d_keys=None, 130 | d_values=None): 131 | super(AttentionLayer, self).__init__() 132 | 133 | d_keys = d_keys or (d_model // n_heads) 134 | d_values = d_values or (d_model // n_heads) 135 | 136 | self.inner_attention = attention 137 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 138 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 139 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 140 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 141 | self.n_heads = n_heads 142 | 143 | def forward(self, queries, keys, values, attn_mask): 144 | B, L, _ = queries.shape 145 | _, S, _ = keys.shape 146 | H = self.n_heads 147 | 148 | queries = self.query_projection(queries).view(B, L, H, -1) 149 | keys = self.key_projection(keys).view(B, S, H, -1) 150 | values = self.value_projection(values).view(B, S, H, -1) 151 | 152 | out, attn = self.inner_attention( 153 | queries, 154 | keys, 155 | values, 156 | attn_mask 157 | ) 158 | out = out.view(B, L, -1) 159 | 160 | return self.out_projection(out), attn 161 | -------------------------------------------------------------------------------- /layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ConvLayer(nn.Module): 6 | def __init__(self, c_in): 7 | super(ConvLayer, self).__init__() 8 | self.downConv = nn.Conv1d(in_channels=c_in, 9 | out_channels=c_in, 10 | kernel_size=3, 11 | padding=2, 12 | padding_mode='circular') 13 | self.norm = nn.BatchNorm1d(c_in) 14 | 
self.activation = nn.ELU() 15 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 16 | 17 | def forward(self, x): 18 | x = self.downConv(x.permute(0, 2, 1)) 19 | x = self.norm(x) 20 | x = self.activation(x) 21 | x = self.maxPool(x) 22 | x = x.transpose(1, 2) 23 | return x 24 | 25 | 26 | class EncoderLayer(nn.Module): 27 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 28 | super(EncoderLayer, self).__init__() 29 | d_ff = d_ff or 4 * d_model 30 | self.attention = attention 31 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 32 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 33 | self.norm1 = nn.LayerNorm(d_model) 34 | self.norm2 = nn.LayerNorm(d_model) 35 | self.dropout = nn.Dropout(dropout) 36 | self.activation = F.relu if activation == "relu" else F.gelu 37 | 38 | def forward(self, x, attn_mask=None): 39 | new_x, attn = self.attention( 40 | x, x, x, 41 | attn_mask=attn_mask 42 | ) 43 | x = x + self.dropout(new_x) 44 | 45 | y = x = self.norm1(x) 46 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 47 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 48 | 49 | return self.norm2(x + y), attn 50 | 51 | 52 | class Encoder(nn.Module): 53 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 54 | super(Encoder, self).__init__() 55 | self.attn_layers = nn.ModuleList(attn_layers) 56 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 57 | self.norm = norm_layer 58 | 59 | def forward(self, x, attn_mask=None): 60 | # x [B, L, D] 61 | attns = [] 62 | if self.conv_layers is not None: 63 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 64 | x, attn = attn_layer(x, attn_mask=attn_mask) 65 | x = conv_layer(x) 66 | attns.append(attn) 67 | x, attn = self.attn_layers[-1](x) 68 | attns.append(attn) 69 | else: 70 | for attn_layer in self.attn_layers: 71 | x, attn = attn_layer(x, attn_mask=attn_mask) 72 | attns.append(attn) 73 | 74 | if self.norm is not None: 75 | x = self.norm(x) 76 | 77 | return x, attns 78 | 79 | 80 | class DecoderLayer(nn.Module): 81 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 82 | dropout=0.1, activation="relu"): 83 | super(DecoderLayer, self).__init__() 84 | d_ff = d_ff or 4 * d_model 85 | self.self_attention = self_attention 86 | self.cross_attention = cross_attention 87 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 88 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 89 | self.norm1 = nn.LayerNorm(d_model) 90 | self.norm2 = nn.LayerNorm(d_model) 91 | self.norm3 = nn.LayerNorm(d_model) 92 | self.dropout = nn.Dropout(dropout) 93 | self.activation = F.relu if activation == "relu" else F.gelu 94 | 95 | def forward(self, x, cross, x_mask=None, cross_mask=None): 96 | x = x + self.dropout(self.self_attention( 97 | x, x, x, 98 | attn_mask=x_mask 99 | )[0]) 100 | x = self.norm1(x) 101 | 102 | x = x + self.dropout(self.cross_attention( 103 | x, cross, cross, 104 | attn_mask=cross_mask 105 | )[0]) 106 | 107 | y = x = self.norm2(x) 108 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 109 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 110 | 111 | return self.norm3(x + y) 112 | 113 | 114 | class Decoder(nn.Module): 115 | def __init__(self, layers, norm_layer=None, projection=None): 116 | super(Decoder, self).__init__() 117 | self.layers = nn.ModuleList(layers) 118 | self.norm = 
norm_layer 119 | self.projection = projection 120 | 121 | def forward(self, x, cross, x_mask=None, cross_mask=None): 122 | for layer in self.layers: 123 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 124 | 125 | if self.norm is not None: 126 | x = self.norm(x) 127 | 128 | if self.projection is not None: 129 | x = self.projection(x) 130 | return x 131 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/layers/__init__.py -------------------------------------------------------------------------------- /models/Autoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Embed import DataEmbedding_wo_pos 4 | from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer 5 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp 6 | 7 | 8 | 9 | class Model(nn.Module): 10 | """ 11 | Autoformer is the first method to achieve the series-wise connection, 12 | with inherent O(LlogL) complexity 13 | """ 14 | def __init__(self, configs): 15 | super(Model, self).__init__() 16 | self.seq_len = configs.seq_len 17 | self.label_len = configs.label_len 18 | self.pred_len = configs.pred_len 19 | self.output_attention = configs.output_attention 20 | 21 | # Decomp 22 | kernel_size = configs.moving_avg 23 | self.decomp = series_decomp(kernel_size) 24 | 25 | # Embedding 26 | # The series-wise connection inherently contains the sequential information. 27 | # Thus, we can discard the position embedding of transformers. 
28 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 29 | configs.dropout) 30 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 31 | configs.dropout) 32 | 33 | # Encoder 34 | self.encoder = Encoder( 35 | [ 36 | EncoderLayer( 37 | AutoCorrelationLayer( 38 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 39 | output_attention=configs.output_attention), 40 | configs.d_model, configs.n_heads), 41 | configs.d_model, 42 | configs.d_ff, 43 | moving_avg=configs.moving_avg, 44 | dropout=configs.dropout, 45 | activation=configs.activation 46 | ) for l in range(configs.e_layers) 47 | ], 48 | norm_layer=my_Layernorm(configs.d_model) 49 | ) 50 | # Decoder 51 | self.decoder = Decoder( 52 | [ 53 | DecoderLayer( 54 | AutoCorrelationLayer( 55 | AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout, 56 | output_attention=False), 57 | configs.d_model, configs.n_heads), 58 | AutoCorrelationLayer( 59 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 60 | output_attention=False), 61 | configs.d_model, configs.n_heads), 62 | configs.d_model, 63 | configs.c_out, 64 | configs.d_ff, 65 | moving_avg=configs.moving_avg, 66 | dropout=configs.dropout, 67 | activation=configs.activation, 68 | ) 69 | for l in range(configs.d_layers) 70 | ], 71 | norm_layer=my_Layernorm(configs.d_model), 72 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 73 | ) 74 | 75 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 76 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 77 | # decomp init 78 | mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) 79 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device) 80 | seasonal_init, trend_init = self.decomp(x_enc) 81 | # decoder input 82 | trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) 83 | seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1) 84 | # enc 85 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 86 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 87 | # dec 88 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 89 | seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, 90 | trend=trend_init) 91 | # final 92 | dec_out = trend_part + seasonal_part 93 | 94 | if self.output_attention: 95 | return dec_out[:, -self.pred_len:, :], attns 96 | else: 97 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 98 | -------------------------------------------------------------------------------- /models/Informer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 4 | from layers.SelfAttention_Family import ProbAttention, AttentionLayer 5 | from layers.Embed import DataEmbedding 6 | 7 | 8 | class Model(nn.Module): 9 | """ 10 | Informer with Propspare attention in O(LlogL) complexity 11 | """ 12 | def __init__(self, configs): 13 | super(Model, self).__init__() 14 | self.pred_len = configs.pred_len 15 | self.output_attention = configs.output_attention 16 | 17 | # Embedding 18 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 19 | configs.dropout) 20 | self.dec_embedding = 
DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 21 | configs.dropout) 22 | 23 | # Encoder 24 | self.encoder = Encoder( 25 | [ 26 | EncoderLayer( 27 | AttentionLayer( 28 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, 29 | output_attention=configs.output_attention), 30 | configs.d_model, configs.n_heads), 31 | configs.d_model, 32 | configs.d_ff, 33 | dropout=configs.dropout, 34 | activation=configs.activation 35 | ) for l in range(configs.e_layers) 36 | ], 37 | [ 38 | ConvLayer( 39 | configs.d_model 40 | ) for l in range(configs.e_layers - 1) 41 | ] if configs.distil else None, 42 | norm_layer=torch.nn.LayerNorm(configs.d_model) 43 | ) 44 | # Decoder 45 | self.decoder = Decoder( 46 | [ 47 | DecoderLayer( 48 | AttentionLayer( 49 | ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 50 | configs.d_model, configs.n_heads), 51 | AttentionLayer( 52 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 53 | configs.d_model, configs.n_heads), 54 | configs.d_model, 55 | configs.d_ff, 56 | dropout=configs.dropout, 57 | activation=configs.activation, 58 | ) 59 | for l in range(configs.d_layers) 60 | ], 61 | norm_layer=torch.nn.LayerNorm(configs.d_model), 62 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 63 | ) 64 | 65 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 66 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 67 | 68 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 69 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 70 | 71 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 72 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) 73 | 74 | if self.output_attention: 75 | return dec_out[:, -self.pred_len:, :], attns 76 | else: 77 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 78 | -------------------------------------------------------------------------------- /models/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer 4 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 5 | from layers.Embed import DataEmbedding 6 | 7 | 8 | class Model(nn.Module): 9 | """ 10 | Vanilla Transformer 11 | """ 12 | def __init__(self, configs): 13 | super(Model, self).__init__() 14 | self.pred_len = configs.pred_len 15 | self.output_attention = configs.output_attention 16 | 17 | # Embedding 18 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 19 | configs.dropout) 20 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 21 | configs.dropout) 22 | # Encoder 23 | self.encoder = Encoder( 24 | [ 25 | EncoderLayer( 26 | AttentionLayer( 27 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 28 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 29 | configs.d_model, 30 | configs.d_ff, 31 | dropout=configs.dropout, 32 | activation=configs.activation 33 | ) for l in range(configs.e_layers) 34 | ], 35 | norm_layer=torch.nn.LayerNorm(configs.d_model) 36 | ) 37 | # Decoder 38 | self.decoder = Decoder( 39 | [ 40 | DecoderLayer( 41 | AttentionLayer( 42 | FullAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 43 | 
configs.d_model, configs.n_heads), 44 | AttentionLayer( 45 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 46 | configs.d_model, configs.n_heads), 47 | configs.d_model, 48 | configs.d_ff, 49 | dropout=configs.dropout, 50 | activation=configs.activation, 51 | ) 52 | for l in range(configs.d_layers) 53 | ], 54 | norm_layer=torch.nn.LayerNorm(configs.d_model), 55 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 56 | ) 57 | 58 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 59 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 60 | 61 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 62 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 63 | 64 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 65 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) 66 | 67 | if self.output_attention: 68 | return dec_out[:, -self.pred_len:, :], attns 69 | else: 70 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 71 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/models/__init__.py -------------------------------------------------------------------------------- /ns_layers/AutoCorrelation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class DSAutoCorrelation(nn.Module): 7 | """ 8 | AutoCorrelation Mechanism with the following two phases: 9 | (1) period-based dependencies discovery 10 | (2) time delay aggregation 11 | This block can replace the self-attention family mechanism seamlessly. 12 | """ 13 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): 14 | super(DSAutoCorrelation, self).__init__() 15 | self.factor = factor 16 | self.scale = scale 17 | self.mask_flag = mask_flag 18 | self.output_attention = output_attention 19 | self.dropout = nn.Dropout(attention_dropout) 20 | 21 | def time_delay_agg_training(self, values, corr): 22 | """ 23 | SpeedUp version of Autocorrelation (a batch-normalization style design) 24 | This is for the training phase. 25 | """ 26 | head = values.shape[1] 27 | channel = values.shape[2] 28 | length = values.shape[3] 29 | # find top k 30 | top_k = int(self.factor * math.log(length)) 31 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 32 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] 33 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) 34 | # update corr 35 | tmp_corr = torch.softmax(weights, dim=-1) 36 | # aggregation 37 | tmp_values = values 38 | delays_agg = torch.zeros_like(values).float() 39 | for i in range(top_k): 40 | pattern = torch.roll(tmp_values, -int(index[i]), -1) 41 | delays_agg = delays_agg + pattern * \ 42 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 43 | return delays_agg 44 | 45 | def time_delay_agg_inference(self, values, corr): 46 | """ 47 | SpeedUp version of Autocorrelation (a batch-normalization style design) 48 | This is for the inference phase. 
49 | """ 50 | batch = values.shape[0] 51 | head = values.shape[1] 52 | channel = values.shape[2] 53 | length = values.shape[3] 54 | # index init 55 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\ 56 | .repeat(batch, head, channel, 1).to(values.device) 57 | # find top k 58 | top_k = int(self.factor * math.log(length)) 59 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 60 | weights, delay = torch.topk(mean_value, top_k, dim=-1) 61 | # update corr 62 | tmp_corr = torch.softmax(weights, dim=-1) 63 | # aggregation 64 | tmp_values = values.repeat(1, 1, 1, 2) 65 | delays_agg = torch.zeros_like(values).float() 66 | for i in range(top_k): 67 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 68 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 69 | delays_agg = delays_agg + pattern * \ 70 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 71 | return delays_agg 72 | 73 | def time_delay_agg_full(self, values, corr): 74 | """ 75 | Standard version of Autocorrelation 76 | """ 77 | batch = values.shape[0] 78 | head = values.shape[1] 79 | channel = values.shape[2] 80 | length = values.shape[3] 81 | # index init 82 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0)\ 83 | .repeat(batch, head, channel, 1).to(values.device) 84 | # find top k 85 | top_k = int(self.factor * math.log(length)) 86 | weights, delay = torch.topk(corr, top_k, dim=-1) 87 | # update corr 88 | tmp_corr = torch.softmax(weights, dim=-1) 89 | # aggregation 90 | tmp_values = values.repeat(1, 1, 1, 2) 91 | delays_agg = torch.zeros_like(values).float() 92 | for i in range(top_k): 93 | tmp_delay = init_index + delay[..., i].unsqueeze(-1) 94 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 95 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) 96 | return delays_agg 97 | 98 | 99 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 100 | B, L, H, E = queries.shape 101 | _, S, _, D = values.shape 102 | if L > S: 103 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float() 104 | values = torch.cat([values, zeros], dim=1) 105 | keys = torch.cat([keys, zeros], dim=1) 106 | else: 107 | values = values[:, :L, :, :] 108 | keys = keys[:, :L, :, :] 109 | 110 | # period-based dependencies 111 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) # B x H x E x S//2 112 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) 113 | res = q_fft * torch.conj(k_fft) 114 | corr = torch.fft.irfft(res, dim=-1) # B x H x E x S 115 | 116 | tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1 117 | delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x S 118 | corr = corr * tau + delta 119 | 120 | # time delay agg 121 | if self.training: 122 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 123 | else: 124 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 125 | 126 | if self.output_attention: 127 | return (V.contiguous(), corr.permute(0, 3, 1, 2)) # B x S x H x E 128 | else: 129 | return (V.contiguous(), None) 130 | 131 | 132 | class AutoCorrelationLayer(nn.Module): 133 | def __init__(self, correlation, d_model, n_heads, d_keys=None, 134 | d_values=None): 135 | super(AutoCorrelationLayer, self).__init__() 136 | 137 | d_keys = d_keys or (d_model // 
n_heads) 138 | d_values = d_values or (d_model // n_heads) 139 | 140 | self.inner_correlation = correlation 141 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 142 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 143 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 144 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 145 | self.n_heads = n_heads 146 | 147 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 148 | B, L, _ = queries.shape 149 | _, S, _ = keys.shape 150 | H = self.n_heads 151 | 152 | queries = self.query_projection(queries).view(B, L, H, -1) 153 | keys = self.key_projection(keys).view(B, S, H, -1) 154 | values = self.value_projection(values).view(B, S, H, -1) 155 | 156 | out, attn = self.inner_correlation( 157 | queries, 158 | keys, 159 | values, 160 | attn_mask, 161 | tau, delta 162 | ) 163 | out = out.view(B, L, -1) 164 | 165 | return self.out_projection(out), attn 166 | -------------------------------------------------------------------------------- /ns_layers/Autoformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class my_Layernorm(nn.Module): 7 | """ 8 | Special designed layernorm for the seasonal part 9 | """ 10 | def __init__(self, channels): 11 | super(my_Layernorm, self).__init__() 12 | self.layernorm = nn.LayerNorm(channels) 13 | 14 | def forward(self, x): 15 | x_hat = self.layernorm(x) 16 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) 17 | return x_hat - bias 18 | 19 | 20 | class moving_avg(nn.Module): 21 | """ 22 | Moving average block to highlight the trend of time series 23 | """ 24 | def __init__(self, kernel_size, stride): 25 | super(moving_avg, self).__init__() 26 | self.kernel_size = kernel_size 27 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 28 | 29 | def forward(self, x): 30 | # padding on the both ends of time series 31 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 32 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 33 | x = torch.cat([front, x, end], dim=1) 34 | x = self.avg(x.permute(0, 2, 1)) 35 | x = x.permute(0, 2, 1) 36 | return x 37 | 38 | 39 | class series_decomp(nn.Module): 40 | """ 41 | Series decomposition block 42 | """ 43 | def __init__(self, kernel_size): 44 | super(series_decomp, self).__init__() 45 | self.moving_avg = moving_avg(kernel_size, stride=1) 46 | 47 | def forward(self, x): 48 | moving_mean = self.moving_avg(x) 49 | res = x - moving_mean 50 | return res, moving_mean 51 | 52 | 53 | class EncoderLayer(nn.Module): 54 | """ 55 | Autoformer encoder layer with the progressive decomposition architecture 56 | """ 57 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): 58 | super(EncoderLayer, self).__init__() 59 | d_ff = d_ff or 4 * d_model 60 | self.attention = attention 61 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 62 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 63 | self.decomp1 = series_decomp(moving_avg) 64 | self.decomp2 = series_decomp(moving_avg) 65 | self.dropout = nn.Dropout(dropout) 66 | self.activation = F.relu if activation == "relu" else F.gelu 67 | 68 | def forward(self, x, attn_mask=None, tau=None, delta=None): 69 | new_x, attn = self.attention( 70 | x, x, x, 71 | 
attn_mask=attn_mask, 72 | tau=tau, delta=delta 73 | ) 74 | x = x + self.dropout(new_x) 75 | x, _ = self.decomp1(x) 76 | y = x 77 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 78 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 79 | res, _ = self.decomp2(x + y) 80 | return res, attn 81 | 82 | 83 | class Encoder(nn.Module): 84 | """ 85 | Autoformer encoder 86 | """ 87 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 88 | super(Encoder, self).__init__() 89 | self.attn_layers = nn.ModuleList(attn_layers) 90 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 91 | self.norm = norm_layer 92 | 93 | def forward(self, x, attn_mask=None, tau=None, delta=None): 94 | attns = [] 95 | if self.conv_layers is not None: 96 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 97 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 98 | x = conv_layer(x) 99 | attns.append(attn) 100 | x, attn = self.attn_layers[-1](x, tau=tau, delta=delta) 101 | attns.append(attn) 102 | else: 103 | for attn_layer in self.attn_layers: 104 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 105 | attns.append(attn) 106 | 107 | if self.norm is not None: 108 | x = self.norm(x) 109 | 110 | return x, attns 111 | 112 | 113 | class DecoderLayer(nn.Module): 114 | """ 115 | Autoformer decoder layer with the progressive decomposition architecture 116 | """ 117 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None, 118 | moving_avg=25, dropout=0.1, activation="relu"): 119 | super(DecoderLayer, self).__init__() 120 | d_ff = d_ff or 4 * d_model 121 | self.self_attention = self_attention 122 | self.cross_attention = cross_attention 123 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 124 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 125 | self.decomp1 = series_decomp(moving_avg) 126 | self.decomp2 = series_decomp(moving_avg) 127 | self.decomp3 = series_decomp(moving_avg) 128 | self.dropout = nn.Dropout(dropout) 129 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1, 130 | padding_mode='circular', bias=False) 131 | self.activation = F.relu if activation == "relu" else F.gelu 132 | 133 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 134 | # Note that delta only used for Self-Attention(x_enc with x_enc) 135 | # and Cross-Attention(x_enc with x_dec), 136 | # but not suitable for Self-Attention(x_dec with x_dec) 137 | 138 | x = x + self.dropout(self.self_attention( 139 | x, x, x, 140 | attn_mask=x_mask, 141 | tau=tau, delta=None 142 | )[0]) 143 | x, trend1 = self.decomp1(x) 144 | x = x + self.dropout(self.cross_attention( 145 | x, cross, cross, 146 | attn_mask=cross_mask, 147 | tau=tau, delta=delta 148 | )[0]) 149 | x, trend2 = self.decomp2(x) 150 | y = x 151 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 152 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 153 | x, trend3 = self.decomp3(x + y) 154 | 155 | residual_trend = trend1 + trend2 + trend3 156 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) 157 | return x, residual_trend 158 | 159 | 160 | class Decoder(nn.Module): 161 | """ 162 | Autoformer encoder 163 | """ 164 | def __init__(self, layers, norm_layer=None, projection=None): 165 | super(Decoder, self).__init__() 166 | self.layers = 
nn.ModuleList(layers) 167 | self.norm = norm_layer 168 | self.projection = projection 169 | 170 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None, tau=None, delta=None): 171 | for layer in self.layers: 172 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 173 | trend = trend + residual_trend 174 | 175 | if self.norm is not None: 176 | x = self.norm(x) 177 | 178 | if self.projection is not None: 179 | x = self.projection(x) 180 | return x, trend 181 | -------------------------------------------------------------------------------- /ns_layers/SelfAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from math import sqrt 5 | from utils.masking import TriangularCausalMask, ProbMask 6 | 7 | 8 | class DSAttention(nn.Module): 9 | '''De-stationary Attention''' 10 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 11 | super(DSAttention, self).__init__() 12 | self.scale = scale 13 | self.mask_flag = mask_flag 14 | self.output_attention = output_attention 15 | self.dropout = nn.Dropout(attention_dropout) 16 | 17 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 18 | B, L, H, E = queries.shape 19 | _, S, _, D = values.shape 20 | scale = self.scale or 1. / sqrt(E) 21 | 22 | tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1 23 | delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x S 24 | 25 | # De-stationary Attention, rescaling pre-softmax score with learned de-stationary factors 26 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) * tau + delta 27 | 28 | if self.mask_flag: 29 | if attn_mask is None: 30 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 31 | 32 | scores.masked_fill_(attn_mask.mask, -np.inf) 33 | 34 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 35 | V = torch.einsum("bhls,bshd->blhd", A, values) 36 | 37 | if self.output_attention: 38 | return (V.contiguous(), A) 39 | else: 40 | return (V.contiguous(), None) 41 | 42 | 43 | class DSProbAttention(nn.Module): 44 | '''De-stationary ProbAttention for Informer''' 45 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 46 | super(DSProbAttention, self).__init__() 47 | self.factor = factor 48 | self.scale = scale 49 | self.mask_flag = mask_flag 50 | self.output_attention = output_attention 51 | self.dropout = nn.Dropout(attention_dropout) 52 | 53 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 54 | # Q [B, H, L, D] 55 | B, H, L_K, E = K.shape 56 | _, _, L_Q, _ = Q.shape 57 | 58 | # calculate the sampled Q_K 59 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 60 | index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q 61 | K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] 62 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 63 | 64 | # find the Top_k query with sparisty measurement 65 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 66 | M_top = M.topk(n_top, sorted=False)[1] 67 | 68 | # use the reduced Q to calculate Q_K 69 | Q_reduce = Q[torch.arange(B)[:, None, None], 70 | torch.arange(H)[None, :, None], 71 | M_top, :] # factor*ln(L_q) 72 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 
73 | 74 | return Q_K, M_top 75 | 76 | def _get_initial_context(self, V, L_Q): 77 | B, H, L_V, D = V.shape 78 | if not self.mask_flag: 79 | # V_sum = V.sum(dim=-2) 80 | V_sum = V.mean(dim=-2) 81 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() 82 | else: # use mask 83 | assert (L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only 84 | contex = V.cumsum(dim=-2) 85 | return contex 86 | 87 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 88 | B, H, L_V, D = V.shape 89 | 90 | if self.mask_flag: 91 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 92 | scores.masked_fill_(attn_mask.mask, -np.inf) 93 | 94 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 95 | 96 | context_in[torch.arange(B)[:, None, None], 97 | torch.arange(H)[None, :, None], 98 | index, :] = torch.matmul(attn, V).type_as(context_in) 99 | if self.output_attention: 100 | attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) 101 | attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn 102 | return (context_in, attns) 103 | else: 104 | return (context_in, None) 105 | 106 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 107 | B, L_Q, H, D = queries.shape 108 | _, L_K, _, _ = keys.shape 109 | 110 | queries = queries.transpose(2, 1) 111 | keys = keys.transpose(2, 1) 112 | values = values.transpose(2, 1) 113 | 114 | U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 115 | u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 116 | 117 | U_part = U_part if U_part < L_K else L_K 118 | u = u if u < L_Q else L_Q 119 | 120 | scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 121 | 122 | tau = 1.0 if tau is None else tau.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x 1 123 | delta = 0.0 if delta is None else delta.unsqueeze(1).unsqueeze(1) # B x 1 x 1 x S 124 | scores_top = scores_top * tau + delta 125 | 126 | # add scale factor 127 | scale = self.scale or 1. 
/ sqrt(D) 128 | if scale is not None: 129 | scores_top = scores_top * scale 130 | # get the context 131 | context = self._get_initial_context(values, L_Q) 132 | # update the context with selected top_k queries 133 | context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) 134 | 135 | return context.contiguous(), attn 136 | 137 | 138 | class AttentionLayer(nn.Module): 139 | def __init__(self, attention, d_model, n_heads, d_keys=None, 140 | d_values=None): 141 | super(AttentionLayer, self).__init__() 142 | 143 | d_keys = d_keys or (d_model // n_heads) 144 | d_values = d_values or (d_model // n_heads) 145 | 146 | self.inner_attention = attention 147 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 148 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 149 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 150 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 151 | self.n_heads = n_heads 152 | 153 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 154 | B, L, _ = queries.shape 155 | _, S, _ = keys.shape 156 | H = self.n_heads 157 | 158 | queries = self.query_projection(queries).view(B, L, H, -1) 159 | keys = self.key_projection(keys).view(B, S, H, -1) 160 | values = self.value_projection(values).view(B, S, H, -1) 161 | 162 | out, attn = self.inner_attention( 163 | queries, 164 | keys, 165 | values, 166 | attn_mask, 167 | tau, delta 168 | ) 169 | out = out.view(B, L, -1) 170 | 171 | return self.out_projection(out), attn 172 | 173 | -------------------------------------------------------------------------------- /ns_layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ConvLayer(nn.Module): 6 | def __init__(self, c_in): 7 | super(ConvLayer, self).__init__() 8 | self.downConv = nn.Conv1d(in_channels=c_in, 9 | out_channels=c_in, 10 | kernel_size=3, 11 | padding=2, 12 | padding_mode='circular') 13 | self.norm = nn.BatchNorm1d(c_in) 14 | self.activation = nn.ELU() 15 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 16 | 17 | def forward(self, x): 18 | x = self.downConv(x.permute(0, 2, 1)) # BxExS 19 | x = self.norm(x) 20 | x = self.activation(x) 21 | x = self.maxPool(x) 22 | x = x.transpose(1, 2) 23 | return x 24 | 25 | 26 | class EncoderLayer(nn.Module): 27 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 28 | super(EncoderLayer, self).__init__() 29 | d_ff = d_ff or 4 * d_model 30 | self.attention = attention 31 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 32 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 33 | self.norm1 = nn.LayerNorm(d_model) 34 | self.norm2 = nn.LayerNorm(d_model) 35 | self.dropout = nn.Dropout(dropout) 36 | self.activation = F.relu if activation == "relu" else F.gelu 37 | 38 | def forward(self, x, attn_mask=None, tau=None, delta=None): 39 | new_x, attn = self.attention( 40 | x, x, x, 41 | attn_mask=attn_mask, 42 | tau=tau, delta=delta 43 | ) 44 | x = x + self.dropout(new_x) 45 | 46 | y = x = self.norm1(x) 47 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 48 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 49 | 50 | return self.norm2(x + y), attn 51 | 52 | 53 | class Encoder(nn.Module): 54 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 55 | super(Encoder, self).__init__() 
56 | self.attn_layers = nn.ModuleList(attn_layers) 57 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 58 | self.norm = norm_layer 59 | 60 | def forward(self, x, attn_mask=None, tau=None, delta=None): 61 | # x [B, L, D] 62 | attns = [] 63 | if self.conv_layers is not None: 64 | # The reason why we only import delta for the first attn_block of Encoder 65 | # is to integrate Informer into our framework, where row size of attention of Informer is changing each layer 66 | # and inconsistent to the sequence length of the initial input, 67 | # then no way to add delta to every row, so we make delta=0.0 (See our Appendix E.2) 68 | # 69 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 70 | delta = delta if i==0 else None 71 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 72 | x = conv_layer(x) 73 | attns.append(attn) 74 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 75 | attns.append(attn) 76 | else: 77 | for attn_layer in self.attn_layers: 78 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 79 | attns.append(attn) 80 | 81 | if self.norm is not None: 82 | x = self.norm(x) 83 | 84 | return x, attns 85 | 86 | 87 | class DecoderLayer(nn.Module): 88 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 89 | dropout=0.1, activation="relu"): 90 | super(DecoderLayer, self).__init__() 91 | d_ff = d_ff or 4 * d_model 92 | self.self_attention = self_attention 93 | self.cross_attention = cross_attention 94 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 95 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 96 | self.norm1 = nn.LayerNorm(d_model) 97 | self.norm2 = nn.LayerNorm(d_model) 98 | self.norm3 = nn.LayerNorm(d_model) 99 | self.dropout = nn.Dropout(dropout) 100 | self.activation = F.relu if activation == "relu" else F.gelu 101 | 102 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 103 | # Note that delta only used for Self-Attention(x_enc with x_enc) 104 | # and Cross-Attention(x_enc with x_dec), 105 | # but not suitable for Self-Attention(x_dec with x_dec) 106 | 107 | x = x + self.dropout(self.self_attention( 108 | x, x, x, 109 | attn_mask=x_mask, 110 | tau=tau, delta=None 111 | )[0]) 112 | x = self.norm1(x) 113 | 114 | x = x + self.dropout(self.cross_attention( 115 | x, cross, cross, 116 | attn_mask=cross_mask, 117 | tau=tau, delta=delta 118 | )[0]) 119 | 120 | y = x = self.norm2(x) 121 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 122 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 123 | 124 | return self.norm3(x + y) 125 | 126 | 127 | class Decoder(nn.Module): 128 | def __init__(self, layers, norm_layer=None, projection=None): 129 | super(Decoder, self).__init__() 130 | self.layers = nn.ModuleList(layers) 131 | self.norm = norm_layer 132 | self.projection = projection 133 | 134 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 135 | for layer in self.layers: 136 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 137 | 138 | if self.norm is not None: 139 | x = self.norm(x) 140 | 141 | if self.projection is not None: 142 | x = self.projection(x) 143 | return x 144 | -------------------------------------------------------------------------------- /ns_layers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/ns_layers/__init__.py -------------------------------------------------------------------------------- /ns_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/ns_models/__init__.py -------------------------------------------------------------------------------- /ns_models/ns_Autoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ns_layers.AutoCorrelation import DSAutoCorrelation, AutoCorrelationLayer 4 | from ns_layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp 5 | from layers.Embed import DataEmbedding_wo_pos 6 | 7 | 8 | class Projector(nn.Module): 9 | ''' 10 | MLP to learn the De-stationary factors 11 | ''' 12 | def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3): 13 | super(Projector, self).__init__() 14 | 15 | padding = 1 if torch.__version__ >= '1.5.0' else 2 16 | self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size, padding=padding, padding_mode='circular', bias=False) 17 | 18 | layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()] 19 | for i in range(hidden_layers-1): 20 | layers += [nn.Linear(hidden_dims[i], hidden_dims[i+1]), nn.ReLU()] 21 | 22 | layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)] 23 | self.backbone = nn.Sequential(*layers) 24 | 25 | def forward(self, x, stats): 26 | # x: B x S x E 27 | # stats: B x 1 x E 28 | # y: B x O 29 | batch_size = x.shape[0] 30 | x = self.series_conv(x) # B x 1 x E 31 | x = torch.cat([x, stats], dim=1) # B x 2 x E 32 | x = x.view(batch_size, -1) # B x 2E 33 | y = self.backbone(x) # B x O 34 | 35 | return y 36 | 37 | class Model(nn.Module): 38 | """ 39 | Autoformer is the first method to achieve the series-wise connection, 40 | with inherent O(LlogL) complexity 41 | """ 42 | def __init__(self, configs): 43 | super(Model, self).__init__() 44 | self.seq_len = configs.seq_len 45 | self.label_len = configs.label_len 46 | self.pred_len = configs.pred_len 47 | self.output_attention = configs.output_attention 48 | 49 | # Decomp 50 | kernel_size = configs.moving_avg 51 | self.decomp = series_decomp(kernel_size) 52 | 53 | # Embedding 54 | # The series-wise connection inherently contains the sequential information. 55 | # Thus, we can discard the position embedding of transformers. 
56 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 57 | configs.dropout) 58 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 59 | configs.dropout) 60 | 61 | # Encoder 62 | self.encoder = Encoder( 63 | [ 64 | EncoderLayer( 65 | AutoCorrelationLayer( 66 | DSAutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 67 | output_attention=configs.output_attention), 68 | configs.d_model, configs.n_heads), 69 | configs.d_model, 70 | configs.d_ff, 71 | moving_avg=configs.moving_avg, 72 | dropout=configs.dropout, 73 | activation=configs.activation 74 | ) for l in range(configs.e_layers) 75 | ], 76 | norm_layer=my_Layernorm(configs.d_model) 77 | ) 78 | # Decoder 79 | self.decoder = Decoder( 80 | [ 81 | DecoderLayer( 82 | AutoCorrelationLayer( 83 | DSAutoCorrelation(True, configs.factor, attention_dropout=configs.dropout, 84 | output_attention=False), 85 | configs.d_model, configs.n_heads), 86 | AutoCorrelationLayer( 87 | DSAutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 88 | output_attention=False), 89 | configs.d_model, configs.n_heads), 90 | configs.d_model, 91 | configs.c_out, 92 | configs.d_ff, 93 | moving_avg=configs.moving_avg, 94 | dropout=configs.dropout, 95 | activation=configs.activation, 96 | ) 97 | for l in range(configs.d_layers) 98 | ], 99 | norm_layer=my_Layernorm(configs.d_model), 100 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 101 | ) 102 | 103 | self.tau_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers, output_dim=1) 104 | self.delta_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers, output_dim=configs.seq_len) 105 | 106 | 107 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 108 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 109 | 110 | x_raw = x_enc.clone().detach() 111 | 112 | # Normalization 113 | mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E 114 | x_enc = x_enc - mean_enc 115 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E 116 | x_enc = x_enc / std_enc 117 | x_dec_new = torch.cat([x_enc[:, -self.label_len: , :], torch.zeros_like(x_dec[:, -self.pred_len:, :])], dim=1).to(x_enc.device).clone() 118 | 119 | tau = self.tau_learner(x_raw, std_enc).exp() # B x S x E, B x 1 x E -> B x 1, positive scalar 120 | delta = self.delta_learner(x_raw, mean_enc) # B x S x E, B x 1 x E -> B x S 121 | 122 | # Model Inference 123 | 124 | # decomp init 125 | mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) 126 | zeros = torch.zeros([x_dec_new.shape[0], self.pred_len, x_dec_new.shape[2]], device=x_enc.device) 127 | seasonal_init, trend_init = self.decomp(x_enc) 128 | # decoder input 129 | trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) 130 | seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1) 131 | # enc 132 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 133 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask, tau=tau, delta=delta) 134 | # dec 135 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 136 | seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, 137 | trend=trend_init, tau=tau, delta=None) 
138 | # final 139 | dec_out = trend_part + seasonal_part 140 | 141 | # De-normalization 142 | dec_out = dec_out * std_enc + mean_enc 143 | 144 | if self.output_attention: 145 | return dec_out[:, -self.pred_len:, :], attns 146 | else: 147 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 148 | -------------------------------------------------------------------------------- /ns_models/ns_Informer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ns_layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 4 | from ns_layers.SelfAttention_Family import DSProbAttention, AttentionLayer 5 | from layers.Embed import DataEmbedding 6 | 7 | 8 | class Projector(nn.Module): 9 | ''' 10 | MLP to learn the De-stationary factors 11 | ''' 12 | def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3): 13 | super(Projector, self).__init__() 14 | 15 | padding = 1 if torch.__version__ >= '1.5.0' else 2 16 | self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size, padding=padding, padding_mode='circular', bias=False) 17 | 18 | layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()] 19 | for i in range(hidden_layers-1): 20 | layers += [nn.Linear(hidden_dims[i], hidden_dims[i+1]), nn.ReLU()] 21 | 22 | layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)] 23 | self.backbone = nn.Sequential(*layers) 24 | 25 | def forward(self, x, stats): 26 | # x: B x S x E 27 | # stats: B x 1 x E 28 | # y: B x O 29 | batch_size = x.shape[0] 30 | x = self.series_conv(x) # B x 1 x E 31 | x = torch.cat([x, stats], dim=1) # B x 2 x E 32 | x = x.view(batch_size, -1) # B x 2E 33 | y = self.backbone(x) # B x O 34 | 35 | return y 36 | 37 | class Model(nn.Module): 38 | """ 39 | Non-stationary Informer 40 | """ 41 | def __init__(self, configs): 42 | super(Model, self).__init__() 43 | self.pred_len = configs.pred_len 44 | self.label_len = configs.label_len 45 | self.output_attention = configs.output_attention 46 | 47 | # Embedding 48 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 49 | configs.dropout) 50 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 51 | configs.dropout) 52 | 53 | # Encoder 54 | self.encoder = Encoder( 55 | [ 56 | EncoderLayer( 57 | AttentionLayer( 58 | DSProbAttention(False, configs.factor, attention_dropout=configs.dropout, 59 | output_attention=configs.output_attention), 60 | configs.d_model, configs.n_heads), 61 | configs.d_model, 62 | configs.d_ff, 63 | dropout=configs.dropout, 64 | activation=configs.activation 65 | ) for l in range(configs.e_layers) 66 | ], 67 | [ 68 | ConvLayer( 69 | configs.d_model 70 | ) for l in range(configs.e_layers - 1) 71 | ] if configs.distil else None, 72 | norm_layer=torch.nn.LayerNorm(configs.d_model) 73 | ) 74 | # Decoder 75 | self.decoder = Decoder( 76 | [ 77 | DecoderLayer( 78 | AttentionLayer( 79 | DSProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 80 | configs.d_model, configs.n_heads), 81 | AttentionLayer( 82 | DSProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 83 | configs.d_model, configs.n_heads), 84 | configs.d_model, 85 | configs.d_ff, 86 | dropout=configs.dropout, 87 | activation=configs.activation, 88 | ) 89 | for l in range(configs.d_layers) 90 | ], 91 | 
norm_layer=torch.nn.LayerNorm(configs.d_model), 92 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 93 | ) 94 | 95 | self.tau_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers, output_dim=1) 96 | self.delta_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers, output_dim=configs.seq_len) 97 | 98 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 99 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 100 | 101 | x_raw = x_enc.clone().detach() 102 | 103 | # Normalization 104 | mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E 105 | x_enc = x_enc - mean_enc 106 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E 107 | x_enc = x_enc / std_enc 108 | x_dec_new = torch.cat([x_enc[:, -self.label_len: , :], torch.zeros_like(x_dec[:, -self.pred_len:, :])], dim=1).to(x_enc.device).clone() 109 | 110 | tau = self.tau_learner(x_raw, std_enc).exp() # B x S x E, B x 1 x E -> B x 1, positive scalar 111 | delta = self.delta_learner(x_raw, mean_enc) # B x S x E, B x 1 x E -> B x S 112 | 113 | # Model Inference 114 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 115 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask, tau=tau, delta=delta) 116 | 117 | dec_out = self.dec_embedding(x_dec_new, x_mark_dec) 118 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, tau=tau, delta=None) 119 | 120 | # De-normalization 121 | dec_out = dec_out * std_enc + mean_enc 122 | 123 | if self.output_attention: 124 | return dec_out[:, -self.pred_len:, :], attns 125 | else: 126 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 127 | -------------------------------------------------------------------------------- /ns_models/ns_Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ns_layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer 4 | from ns_layers.SelfAttention_Family import DSAttention, AttentionLayer 5 | from layers.Embed import DataEmbedding 6 | 7 | 8 | class Projector(nn.Module): 9 | ''' 10 | MLP to learn the De-stationary factors 11 | ''' 12 | def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3): 13 | super(Projector, self).__init__() 14 | 15 | padding = 1 if torch.__version__ >= '1.5.0' else 2 16 | self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size, padding=padding, padding_mode='circular', bias=False) 17 | 18 | layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()] 19 | for i in range(hidden_layers-1): 20 | layers += [nn.Linear(hidden_dims[i], hidden_dims[i+1]), nn.ReLU()] 21 | 22 | layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)] 23 | self.backbone = nn.Sequential(*layers) 24 | 25 | def forward(self, x, stats): 26 | # x: B x S x E 27 | # stats: B x 1 x E 28 | # y: B x O 29 | batch_size = x.shape[0] 30 | x = self.series_conv(x) # B x 1 x E 31 | x = torch.cat([x, stats], dim=1) # B x 2 x E 32 | x = x.view(batch_size, -1) # B x 2E 33 | y = self.backbone(x) # B x O 34 | 35 | return y 36 | 37 | class Model(nn.Module): 38 | """ 39 | Non-stationary Transformer 40 | """ 41 | def __init__(self, configs): 42 | super(Model, self).__init__() 43 | self.pred_len = configs.pred_len 44 | 
self.seq_len = configs.seq_len 45 | self.label_len = configs.label_len 46 | self.output_attention = configs.output_attention 47 | 48 | # Embedding 49 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 50 | configs.dropout) 51 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 52 | configs.dropout) 53 | # Encoder 54 | self.encoder = Encoder( 55 | [ 56 | EncoderLayer( 57 | AttentionLayer( 58 | DSAttention(False, configs.factor, attention_dropout=configs.dropout, 59 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 60 | configs.d_model, 61 | configs.d_ff, 62 | dropout=configs.dropout, 63 | activation=configs.activation 64 | ) for l in range(configs.e_layers) 65 | ], 66 | norm_layer=torch.nn.LayerNorm(configs.d_model) 67 | ) 68 | # Decoder 69 | self.decoder = Decoder( 70 | [ 71 | DecoderLayer( 72 | AttentionLayer( 73 | DSAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 74 | configs.d_model, configs.n_heads), 75 | AttentionLayer( 76 | DSAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 77 | configs.d_model, configs.n_heads), 78 | configs.d_model, 79 | configs.d_ff, 80 | dropout=configs.dropout, 81 | activation=configs.activation, 82 | ) 83 | for l in range(configs.d_layers) 84 | ], 85 | norm_layer=torch.nn.LayerNorm(configs.d_model), 86 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 87 | ) 88 | 89 | self.tau_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers, output_dim=1) 90 | self.delta_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers, output_dim=configs.seq_len) 91 | 92 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 93 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 94 | 95 | x_raw = x_enc.clone().detach() 96 | 97 | # Normalization 98 | mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E 99 | x_enc = x_enc - mean_enc 100 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E 101 | x_enc = x_enc / std_enc 102 | x_dec_new = torch.cat([x_enc[:, -self.label_len: , :], torch.zeros_like(x_dec[:, -self.pred_len:, :])], dim=1).to(x_enc.device).clone() 103 | 104 | tau = self.tau_learner(x_raw, std_enc).exp() # B x S x E, B x 1 x E -> B x 1, positive scalar 105 | delta = self.delta_learner(x_raw, mean_enc) # B x S x E, B x 1 x E -> B x S 106 | 107 | # Model Inference 108 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 109 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask, tau=tau, delta=delta) 110 | 111 | dec_out = self.dec_embedding(x_dec_new, x_mark_dec) 112 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, tau=tau, delta=delta) 113 | 114 | # De-normalization 115 | dec_out = dec_out * std_enc + mean_enc 116 | 117 | if self.output_attention: 118 | return dec_out[:, -self.pred_len:, :], attns 119 | else: 120 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | sklearn 3 | torchvision 4 | numpy 5 | matplotlib 6 | torch 7 | 
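# Note: with recent pip releases the 'sklearn' entry above is a deprecated alias and may
# fail to install; 'scikit-learn' is the maintained package name. No versions are pinned
# here; the environment.yml at the repository root can serve as an alternative way to set
# up the environment. Also, the Projector modules guard a padding value with a plain string
# comparison on torch.__version__, which mis-orders two-digit minor versions such as 1.10;
# a parsed-version check (e.g. packaging.version) is the safer pattern.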
-------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from exp.exp_main import Exp_Main 4 | import random 5 | import numpy as np 6 | 7 | parser = argparse.ArgumentParser(description='Non-stationary Transformers for Time Series Forecasting') 8 | 9 | # basic config 10 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status') 11 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') 12 | parser.add_argument('--model', type=str, required=True, default='Transformer', 13 | help='model name, options: [ns_Transformer, Transformer]') 14 | 15 | # data loader 16 | parser.add_argument('--data', type=str, required=True, default='ETTh2', help='dataset type') 17 | parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file') 18 | parser.add_argument('--data_path', type=str, default='ETTh2.csv', help='data file') 19 | parser.add_argument('--features', type=str, default='M', 20 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate') 21 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task') 22 | parser.add_argument('--freq', type=str, default='h', 23 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') 24 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints') 25 | 26 | # forecasting task 27 | parser.add_argument('--seq_len', type=int, default=96, help='input sequence length') 28 | parser.add_argument('--label_len', type=int, default=48, help='start token length') 29 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length') 30 | 31 | # model define 32 | parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') 33 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size') 34 | parser.add_argument('--c_out', type=int, default=7, help='output size') 35 | parser.add_argument('--d_model', type=int, default=512, help='dimension of model') 36 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads') 37 | parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers') 38 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers') 39 | parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn') 40 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average') 41 | parser.add_argument('--factor', type=int, default=1, help='attn factor') 42 | parser.add_argument('--distil', action='store_false', 43 | help='whether to use distilling in encoder, using this argument means not using distilling', 44 | default=True) 45 | parser.add_argument('--dropout', type=float, default=0.05, help='dropout') 46 | parser.add_argument('--embed', type=str, default='timeF', 47 | help='time features encoding, options:[timeF, fixed, learned]') 48 | parser.add_argument('--activation', type=str, default='gelu', help='activation') 49 | parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder') 50 | 
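# A minimal example invocation (assumes the ETTh2 CSV is available under ./dataset/ETT-small/
# as in the provided scripts; unspecified model arguments fall back to the argparse defaults):
#   python -u run.py --is_training 1 --model_id ETTh2_96_96 --model ns_Transformer \
#     --data ETTh2 --root_path ./dataset/ETT-small/ --data_path ETTh2.csv --features M \
#     --seq_len 96 --label_len 48 --pred_len 96 --p_hidden_dims 256 256 --p_hidden_layers 2 --itr 1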
parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data') 51 | 52 | # optimization 53 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') 54 | parser.add_argument('--itr', type=int, default=2, help='experiments times') 55 | parser.add_argument('--train_epochs', type=int, default=10, help='train epochs') 56 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data') 57 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience') 58 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate') 59 | parser.add_argument('--des', type=str, default='test', help='exp description') 60 | parser.add_argument('--loss', type=str, default='mse', help='loss function') 61 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate') 62 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False) 63 | 64 | # GPU 65 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu') 66 | parser.add_argument('--gpu', type=int, default=0, help='gpu') 67 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False) 68 | parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus') 69 | parser.add_argument('--seed', type=int, default=2021, help='random seed') 70 | 71 | # de-stationary projector params 72 | parser.add_argument('--p_hidden_dims', type=int, nargs='+', default=[128, 128], help='hidden layer dimensions of projector (List)') 73 | parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector') 74 | 75 | args = parser.parse_args() 76 | 77 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False 78 | 79 | fix_seed = args.seed 80 | random.seed(fix_seed) 81 | torch.manual_seed(fix_seed) 82 | np.random.seed(fix_seed) 83 | 84 | if args.use_gpu: 85 | if args.use_multi_gpu: 86 | args.devices = args.devices.replace(' ', '') 87 | device_ids = args.devices.split(',') 88 | args.device_ids = [int(id_) for id_ in device_ids] 89 | args.gpu = args.device_ids[0] 90 | else: 91 | torch.cuda.set_device(args.gpu) 92 | 93 | print('Args in experiment:') 94 | print(args) 95 | 96 | Exp = Exp_Main 97 | 98 | if args.is_training: 99 | for ii in range(args.itr): 100 | # setting record of experiments 101 | setting = '{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 102 | args.model_id, 103 | args.model, 104 | args.data, 105 | args.features, 106 | args.seq_len, 107 | args.label_len, 108 | args.pred_len, 109 | args.d_model, 110 | args.n_heads, 111 | args.e_layers, 112 | args.d_layers, 113 | args.d_ff, 114 | args.factor, 115 | args.embed, 116 | args.distil, 117 | args.des, ii) 118 | 119 | exp = Exp(args) # set experiments 120 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) 121 | exp.train(setting) 122 | 123 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 124 | exp.test(setting) 125 | 126 | if args.do_predict: 127 | print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 128 | exp.predict(setting, True) 129 | 130 | torch.cuda.empty_cache() 131 | else: 132 | ii = 0 133 | setting = '{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(args.model_id, 134 | args.model, 
135 | args.data, 136 | args.features, 137 | args.seq_len, 138 | args.label_len, 139 | args.pred_len, 140 | args.d_model, 141 | args.n_heads, 142 | args.e_layers, 143 | args.d_layers, 144 | args.d_ff, 145 | args.factor, 146 | args.embed, 147 | args.distil, 148 | args.des, ii) 149 | 150 | exp = Exp(args) # set experiments 151 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 152 | exp.test(setting, test=1) 153 | torch.cuda.empty_cache() 154 | -------------------------------------------------------------------------------- /scripts/ECL_script/Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/electricity/ \ 4 | --data_path electricity.csv \ 5 | --model_id ECL_96_96 \ 6 | --model Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --gpu 4 \ 16 | --enc_in 321 \ 17 | --dec_in 321 \ 18 | --c_out 321 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/electricity/ \ 25 | --data_path electricity.csv \ 26 | --model_id ECL_96_192 \ 27 | --model Autoformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --gpu 5 \ 37 | --enc_in 321 \ 38 | --dec_in 321 \ 39 | --c_out 321 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/electricity/ \ 46 | --data_path electricity.csv \ 47 | --model_id ECL_96_336 \ 48 | --model Autoformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --gpu 6 \ 58 | --enc_in 321 \ 59 | --dec_in 321 \ 60 | --c_out 321 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/electricity/ \ 67 | --data_path electricity.csv \ 68 | --model_id ECL_96_720 \ 69 | --model Autoformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --gpu 7 \ 79 | --enc_in 321 \ 80 | --dec_in 321 \ 81 | --c_out 321 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/ECL_script/Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/electricity/ \ 4 | --data_path electricity.csv \ 5 | --model_id ECL_96_96 \ 6 | --model Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --gpu 4 \ 16 | --enc_in 321 \ 17 | --dec_in 321 \ 18 | --c_out 321 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/electricity/ \ 25 | --data_path electricity.csv \ 26 | --model_id ECL_96_192 \ 27 | --model Informer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --gpu 5 \ 37 | --enc_in 321 \ 38 | --dec_in 321 \ 39 | --c_out 321 \ 40 | --des 'Exp' \ 41 
| --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/electricity/ \ 46 | --data_path electricity.csv \ 47 | --model_id ECL_96_336 \ 48 | --model Informer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --gpu 6 \ 58 | --enc_in 321 \ 59 | --dec_in 321 \ 60 | --c_out 321 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/electricity/ \ 67 | --data_path electricity.csv \ 68 | --model_id ECL_96_720 \ 69 | --model Informer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --gpu 7 \ 79 | --enc_in 321 \ 80 | --dec_in 321 \ 81 | --c_out 321 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/ECL_script/Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/electricity/ \ 4 | --data_path electricity.csv \ 5 | --model_id ECL_96_96 \ 6 | --model Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --gpu 4 \ 16 | --enc_in 321 \ 17 | --dec_in 321 \ 18 | --c_out 321 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/electricity/ \ 25 | --data_path electricity.csv \ 26 | --model_id ECL_96_192 \ 27 | --model Transformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --gpu 5 \ 37 | --enc_in 321 \ 38 | --dec_in 321 \ 39 | --c_out 321 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/electricity/ \ 46 | --data_path electricity.csv \ 47 | --model_id ECL_96_336 \ 48 | --model Transformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --gpu 6 \ 58 | --enc_in 321 \ 59 | --dec_in 321 \ 60 | --c_out 321 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/electricity/ \ 67 | --data_path electricity.csv \ 68 | --model_id ECL_96_720 \ 69 | --model Transformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --gpu 7 \ 79 | --enc_in 321 \ 80 | --dec_in 321 \ 81 | --c_out 321 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/ECL_script/ns_Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/electricity/ \ 4 | --data_path electricity.csv \ 5 | --model_id ECL_96_96 \ 6 | --model ns_Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 321 \ 16 | --dec_in 321 \ 17 | --c_out 321 \ 18 | --gpu 0 \ 19 | --des 
'Exp_h256_l2' \ 20 | --p_hidden_dims 256 256 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/electricity/ \ 28 | --data_path electricity.csv \ 29 | --model_id ECL_96_192 \ 30 | --model ns_Autoformer \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 192 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 321 \ 40 | --dec_in 321 \ 41 | --c_out 321 \ 42 | --gpu 1 \ 43 | --des 'Exp_h256_l2' \ 44 | --p_hidden_dims 256 256 \ 45 | --p_hidden_layers 2 \ 46 | --itr 1 & 47 | 48 | 49 | python -u run.py \ 50 | --is_training 1 \ 51 | --root_path ./dataset/electricity/ \ 52 | --data_path electricity.csv \ 53 | --model_id ECL_96_336 \ 54 | --model ns_Autoformer \ 55 | --data custom \ 56 | --features M \ 57 | --seq_len 96 \ 58 | --label_len 48 \ 59 | --pred_len 336 \ 60 | --e_layers 2 \ 61 | --d_layers 1 \ 62 | --factor 3 \ 63 | --enc_in 321 \ 64 | --dec_in 321 \ 65 | --c_out 321 \ 66 | --gpu 2 \ 67 | --des 'Exp_h128_l2' \ 68 | --p_hidden_dims 128 128 \ 69 | --p_hidden_layers 2 \ 70 | --itr 1 & 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/electricity/ \ 75 | --data_path electricity.csv \ 76 | --model_id ECL_96_720 \ 77 | --model ns_Autoformer \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 720 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 321 \ 87 | --dec_in 321 \ 88 | --c_out 321 \ 89 | --gpu 3 \ 90 | --des 'Exp_h128_l2' \ 91 | --p_hidden_dims 128 128 \ 92 | --p_hidden_layers 2 \ 93 | --itr 1 & 94 | -------------------------------------------------------------------------------- /scripts/ECL_script/ns_Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/electricity/ \ 4 | --data_path electricity.csv \ 5 | --model_id ECL_96_96 \ 6 | --model ns_Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 321 \ 16 | --dec_in 321 \ 17 | --c_out 321 \ 18 | --gpu 0 \ 19 | --des 'Exp_h256_l2' \ 20 | --p_hidden_dims 256 256 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/electricity/ \ 28 | --data_path electricity.csv \ 29 | --model_id ECL_96_192 \ 30 | --model ns_Informer \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 192 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 321 \ 40 | --dec_in 321 \ 41 | --c_out 321 \ 42 | --gpu 1 \ 43 | --des 'Exp_h256_l2' \ 44 | --p_hidden_dims 256 256 \ 45 | --p_hidden_layers 2 \ 46 | --itr 1 & 47 | 48 | 49 | python -u run.py \ 50 | --is_training 1 \ 51 | --root_path ./dataset/electricity/ \ 52 | --data_path electricity.csv \ 53 | --model_id ECL_96_336 \ 54 | --model ns_Informer \ 55 | --data custom \ 56 | --features M \ 57 | --seq_len 96 \ 58 | --label_len 48 \ 59 | --pred_len 336 \ 60 | --e_layers 2 \ 61 | --d_layers 1 \ 62 | --factor 3 \ 63 | --enc_in 321 \ 64 | --dec_in 321 \ 65 | --c_out 321 \ 66 | --gpu 2 \ 67 | --des 'Exp_h128_l2' \ 68 | --p_hidden_dims 128 128 \ 69 | --p_hidden_layers 2 \ 70 | --itr 1 & 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/electricity/ \ 75 | --data_path 
electricity.csv \ 76 | --model_id ECL_96_720 \ 77 | --model ns_Informer \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 720 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 321 \ 87 | --dec_in 321 \ 88 | --c_out 321 \ 89 | --gpu 3 \ 90 | --des 'Exp_h128_l2' \ 91 | --p_hidden_dims 128 128 \ 92 | --p_hidden_layers 2 \ 93 | --itr 1 & 94 | -------------------------------------------------------------------------------- /scripts/ECL_script/ns_Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/electricity/ \ 4 | --data_path electricity.csv \ 5 | --model_id ECL_96_96 \ 6 | --model ns_Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 321 \ 16 | --dec_in 321 \ 17 | --c_out 321 \ 18 | --gpu 0 \ 19 | --des 'Exp_h256_l2' \ 20 | --p_hidden_dims 256 256 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/electricity/ \ 28 | --data_path electricity.csv \ 29 | --model_id ECL_96_192 \ 30 | --model ns_Transformer \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 192 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 321 \ 40 | --dec_in 321 \ 41 | --c_out 321 \ 42 | --gpu 1 \ 43 | --des 'Exp_h256_l2' \ 44 | --p_hidden_dims 256 256 \ 45 | --p_hidden_layers 2 \ 46 | --itr 1 & 47 | 48 | 49 | python -u run.py \ 50 | --is_training 1 \ 51 | --root_path ./dataset/electricity/ \ 52 | --data_path electricity.csv \ 53 | --model_id ECL_96_336 \ 54 | --model ns_Transformer \ 55 | --data custom \ 56 | --features M \ 57 | --seq_len 96 \ 58 | --label_len 48 \ 59 | --pred_len 336 \ 60 | --e_layers 2 \ 61 | --d_layers 1 \ 62 | --factor 3 \ 63 | --enc_in 321 \ 64 | --dec_in 321 \ 65 | --c_out 321 \ 66 | --gpu 2 \ 67 | --des 'Exp_h128_l2' \ 68 | --p_hidden_dims 128 128 \ 69 | --p_hidden_layers 2 \ 70 | --itr 1 & 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/electricity/ \ 75 | --data_path electricity.csv \ 76 | --model_id ECL_96_720 \ 77 | --model ns_Transformer \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 720 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 321 \ 87 | --dec_in 321 \ 88 | --c_out 321 \ 89 | --gpu 3 \ 90 | --des 'Exp_h128_l2' \ 91 | --p_hidden_dims 128 128 \ 92 | --p_hidden_layers 2 \ 93 | --itr 1 & 94 | -------------------------------------------------------------------------------- /scripts/ETT_script/Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/ETT-small/ \ 4 | --data_path ETTh2.csv \ 5 | --model_id ETTh2_96_96 \ 6 | --model Autoformer \ 7 | --data ETTh2 \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --enc_in 7 \ 15 | --dec_in 7 \ 16 | --c_out 7 \ 17 | --gpu 4 \ 18 | --des 'Exp' \ 19 | --itr 1 & 20 | 21 | python -u run.py \ 22 | --is_training 1 \ 23 | --root_path ./dataset/ETT-small/ \ 24 | --data_path ETTh2.csv \ 25 | --model_id ETTh2_96_192 \ 26 | --model Autoformer \ 27 | --data ETTh2 \ 28 | --features M \ 29 | --seq_len 
96 \ 30 | --label_len 48 \ 31 | --pred_len 192 \ 32 | --e_layers 2 \ 33 | --d_layers 1 \ 34 | --enc_in 7 \ 35 | --dec_in 7 \ 36 | --c_out 7 \ 37 | --gpu 5 \ 38 | --des 'Exp' \ 39 | --itr 1 & 40 | 41 | python -u run.py \ 42 | --is_training 1 \ 43 | --root_path ./dataset/ETT-small/ \ 44 | --data_path ETTh2.csv \ 45 | --model_id ETTh2_96_336 \ 46 | --model Autoformer \ 47 | --data ETTh2 \ 48 | --features M \ 49 | --seq_len 96 \ 50 | --label_len 48 \ 51 | --pred_len 336 \ 52 | --e_layers 2 \ 53 | --d_layers 1 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --gpu 6 \ 58 | --des 'Exp' \ 59 | --itr 1 & 60 | 61 | python -u run.py \ 62 | --is_training 1 \ 63 | --root_path ./dataset/ETT-small/ \ 64 | --data_path ETTh2.csv \ 65 | --model_id ETTh2_96_720 \ 66 | --model Autoformer \ 67 | --data ETTh2 \ 68 | --features M \ 69 | --seq_len 96 \ 70 | --label_len 48 \ 71 | --pred_len 720 \ 72 | --e_layers 2 \ 73 | --d_layers 1 \ 74 | --enc_in 7 \ 75 | --dec_in 7 \ 76 | --c_out 7 \ 77 | --gpu 7 \ 78 | --des 'Exp' \ 79 | --itr 1 & 80 | -------------------------------------------------------------------------------- /scripts/ETT_script/Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/ETT-small/ \ 4 | --data_path ETTh2.csv \ 5 | --model_id ETTh2_96_96 \ 6 | --model Informer \ 7 | --data ETTh2 \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --enc_in 7 \ 15 | --dec_in 7 \ 16 | --c_out 7 \ 17 | --gpu 4 \ 18 | --des 'Exp' \ 19 | --itr 1 & 20 | 21 | python -u run.py \ 22 | --is_training 1 \ 23 | --root_path ./dataset/ETT-small/ \ 24 | --data_path ETTh2.csv \ 25 | --model_id ETTh2_96_192 \ 26 | --model Informer \ 27 | --data ETTh2 \ 28 | --features M \ 29 | --seq_len 96 \ 30 | --label_len 48 \ 31 | --pred_len 192 \ 32 | --e_layers 2 \ 33 | --d_layers 1 \ 34 | --enc_in 7 \ 35 | --dec_in 7 \ 36 | --c_out 7 \ 37 | --gpu 5 \ 38 | --des 'Exp' \ 39 | --itr 1 & 40 | 41 | python -u run.py \ 42 | --is_training 1 \ 43 | --root_path ./dataset/ETT-small/ \ 44 | --data_path ETTh2.csv \ 45 | --model_id ETTh2_96_336 \ 46 | --model Informer \ 47 | --data ETTh2 \ 48 | --features M \ 49 | --seq_len 96 \ 50 | --label_len 48 \ 51 | --pred_len 336 \ 52 | --e_layers 2 \ 53 | --d_layers 1 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --gpu 6 \ 58 | --des 'Exp' \ 59 | --itr 1 & 60 | 61 | python -u run.py \ 62 | --is_training 1 \ 63 | --root_path ./dataset/ETT-small/ \ 64 | --data_path ETTh2.csv \ 65 | --model_id ETTh2_96_720 \ 66 | --model Informer \ 67 | --data ETTh2 \ 68 | --features M \ 69 | --seq_len 96 \ 70 | --label_len 48 \ 71 | --pred_len 720 \ 72 | --e_layers 2 \ 73 | --d_layers 1 \ 74 | --enc_in 7 \ 75 | --dec_in 7 \ 76 | --c_out 7 \ 77 | --gpu 7 \ 78 | --des 'Exp' \ 79 | --itr 1 & 80 | -------------------------------------------------------------------------------- /scripts/ETT_script/Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/ETT-small/ \ 4 | --data_path ETTh2.csv \ 5 | --model_id ETTh2_96_96 \ 6 | --model Transformer \ 7 | --data ETTh2 \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --enc_in 7 \ 15 | --dec_in 7 \ 16 | --c_out 7 \ 17 | --gpu 4 \ 18 | --des 'Exp' \ 19 | --itr 1 & 20 | 21 | python -u run.py \ 22 
| --is_training 1 \ 23 | --root_path ./dataset/ETT-small/ \ 24 | --data_path ETTh2.csv \ 25 | --model_id ETTh2_96_192 \ 26 | --model Transformer \ 27 | --data ETTh2 \ 28 | --features M \ 29 | --seq_len 96 \ 30 | --label_len 48 \ 31 | --pred_len 192 \ 32 | --e_layers 2 \ 33 | --d_layers 1 \ 34 | --enc_in 7 \ 35 | --dec_in 7 \ 36 | --c_out 7 \ 37 | --gpu 5 \ 38 | --des 'Exp' \ 39 | --itr 1 & 40 | 41 | python -u run.py \ 42 | --is_training 1 \ 43 | --root_path ./dataset/ETT-small/ \ 44 | --data_path ETTh2.csv \ 45 | --model_id ETTh2_96_336 \ 46 | --model Transformer \ 47 | --data ETTh2 \ 48 | --features M \ 49 | --seq_len 96 \ 50 | --label_len 48 \ 51 | --pred_len 336 \ 52 | --e_layers 2 \ 53 | --d_layers 1 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --gpu 6 \ 58 | --des 'Exp' \ 59 | --itr 1 & 60 | 61 | python -u run.py \ 62 | --is_training 1 \ 63 | --root_path ./dataset/ETT-small/ \ 64 | --data_path ETTh2.csv \ 65 | --model_id ETTh2_96_720 \ 66 | --model Transformer \ 67 | --data ETTh2 \ 68 | --features M \ 69 | --seq_len 96 \ 70 | --label_len 48 \ 71 | --pred_len 720 \ 72 | --e_layers 2 \ 73 | --d_layers 1 \ 74 | --enc_in 7 \ 75 | --dec_in 7 \ 76 | --c_out 7 \ 77 | --gpu 7 \ 78 | --des 'Exp' \ 79 | --itr 1 & 80 | -------------------------------------------------------------------------------- /scripts/ETT_script/ns_Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/ETT-small/ \ 4 | --data_path ETTh2.csv \ 5 | --model_id ETTh2_96_96 \ 6 | --model ns_Autoformer \ 7 | --data ETTh2 \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --enc_in 7 \ 15 | --dec_in 7 \ 16 | --c_out 7 \ 17 | --gpu 0 \ 18 | --des 'Exp_h256_l2' \ 19 | --p_hidden_dims 256 256 \ 20 | --p_hidden_layers 2 \ 21 | --itr 1 & 22 | 23 | python -u run.py \ 24 | --is_training 1 \ 25 | --root_path ./dataset/ETT-small/ \ 26 | --data_path ETTh2.csv \ 27 | --model_id ETTh2_96_192 \ 28 | --model ns_Autoformer \ 29 | --data ETTh2 \ 30 | --features M \ 31 | --seq_len 96 \ 32 | --label_len 48 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --d_layers 1 \ 36 | --enc_in 7 \ 37 | --dec_in 7 \ 38 | --c_out 7 \ 39 | --gpu 1 \ 40 | --des 'Exp_h64_l2' \ 41 | --p_hidden_dims 64 64 \ 42 | --p_hidden_layers 2 \ 43 | --itr 1 & 44 | 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path ./dataset/ETT-small/ \ 48 | --data_path ETTh2.csv \ 49 | --model_id ETTh2_96_336 \ 50 | --model ns_Autoformer \ 51 | --data ETTh2 \ 52 | --features M \ 53 | --seq_len 96 \ 54 | --label_len 48 \ 55 | --pred_len 336 \ 56 | --e_layers 2 \ 57 | --d_layers 1 \ 58 | --enc_in 7 \ 59 | --dec_in 7 \ 60 | --c_out 7 \ 61 | --gpu 2 \ 62 | --des 'Exp_h256_l2' \ 63 | --p_hidden_dims 256 256 \ 64 | --p_hidden_layers 2 \ 65 | --itr 1 & 66 | 67 | python -u run.py \ 68 | --is_training 1 \ 69 | --root_path ./dataset/ETT-small/ \ 70 | --data_path ETTh2.csv \ 71 | --model_id ETTh2_96_720 \ 72 | --model ns_Autoformer \ 73 | --data ETTh2 \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --label_len 48 \ 77 | --pred_len 720 \ 78 | --e_layers 2 \ 79 | --d_layers 1 \ 80 | --enc_in 7 \ 81 | --dec_in 7 \ 82 | --c_out 7 \ 83 | --gpu 3 \ 84 | --des 'Exp_h256_l2' \ 85 | --p_hidden_dims 256 256 \ 86 | --p_hidden_layers 2 \ 87 | --itr 1 & 88 | -------------------------------------------------------------------------------- /scripts/ETT_script/ns_Informer.sh: 
-------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/ETT-small/ \ 4 | --data_path ETTh2.csv \ 5 | --model_id ETTh2_96_96 \ 6 | --model ns_Informer \ 7 | --data ETTh2 \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --enc_in 7 \ 15 | --dec_in 7 \ 16 | --c_out 7 \ 17 | --gpu 0 \ 18 | --des 'Exp_h256_l2' \ 19 | --p_hidden_dims 256 256 \ 20 | --p_hidden_layers 2 \ 21 | --itr 1 & 22 | 23 | python -u run.py \ 24 | --is_training 1 \ 25 | --root_path ./dataset/ETT-small/ \ 26 | --data_path ETTh2.csv \ 27 | --model_id ETTh2_96_192 \ 28 | --model ns_Informer \ 29 | --data ETTh2 \ 30 | --features M \ 31 | --seq_len 96 \ 32 | --label_len 48 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --d_layers 1 \ 36 | --enc_in 7 \ 37 | --dec_in 7 \ 38 | --c_out 7 \ 39 | --gpu 1 \ 40 | --des 'Exp_h64_l2' \ 41 | --p_hidden_dims 64 64 \ 42 | --p_hidden_layers 2 \ 43 | --itr 1 & 44 | 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path ./dataset/ETT-small/ \ 48 | --data_path ETTh2.csv \ 49 | --model_id ETTh2_96_336 \ 50 | --model ns_Informer \ 51 | --data ETTh2 \ 52 | --features M \ 53 | --seq_len 96 \ 54 | --label_len 48 \ 55 | --pred_len 336 \ 56 | --e_layers 2 \ 57 | --d_layers 1 \ 58 | --enc_in 7 \ 59 | --dec_in 7 \ 60 | --c_out 7 \ 61 | --gpu 2 \ 62 | --des 'Exp_h256_l2' \ 63 | --p_hidden_dims 256 256 \ 64 | --p_hidden_layers 2 \ 65 | --itr 1 & 66 | 67 | python -u run.py \ 68 | --is_training 1 \ 69 | --root_path ./dataset/ETT-small/ \ 70 | --data_path ETTh2.csv \ 71 | --model_id ETTh2_96_720 \ 72 | --model ns_Informer \ 73 | --data ETTh2 \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --label_len 48 \ 77 | --pred_len 720 \ 78 | --e_layers 2 \ 79 | --d_layers 1 \ 80 | --enc_in 7 \ 81 | --dec_in 7 \ 82 | --c_out 7 \ 83 | --gpu 3 \ 84 | --des 'Exp_h256_l2' \ 85 | --p_hidden_dims 256 256 \ 86 | --p_hidden_layers 2 \ 87 | --itr 1 & 88 | -------------------------------------------------------------------------------- /scripts/ETT_script/ns_Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/ETT-small/ \ 4 | --data_path ETTh2.csv \ 5 | --model_id ETTh2_96_96 \ 6 | --model ns_Transformer \ 7 | --data ETTh2 \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --enc_in 7 \ 15 | --dec_in 7 \ 16 | --c_out 7 \ 17 | --gpu 0 \ 18 | --des 'Exp_h256_l2' \ 19 | --p_hidden_dims 256 256 \ 20 | --p_hidden_layers 2 \ 21 | --itr 1 & 22 | 23 | python -u run.py \ 24 | --is_training 1 \ 25 | --root_path ./dataset/ETT-small/ \ 26 | --data_path ETTh2.csv \ 27 | --model_id ETTh2_96_192 \ 28 | --model ns_Transformer \ 29 | --data ETTh2 \ 30 | --features M \ 31 | --seq_len 96 \ 32 | --label_len 48 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --d_layers 1 \ 36 | --enc_in 7 \ 37 | --dec_in 7 \ 38 | --c_out 7 \ 39 | --gpu 1 \ 40 | --des 'Exp_h64_l2' \ 41 | --p_hidden_dims 64 64 \ 42 | --p_hidden_layers 2 \ 43 | --itr 1 & 44 | 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path ./dataset/ETT-small/ \ 48 | --data_path ETTh2.csv \ 49 | --model_id ETTh2_96_336 \ 50 | --model ns_Transformer \ 51 | --data ETTh2 \ 52 | --features M \ 53 | --seq_len 96 \ 54 | --label_len 48 \ 55 | --pred_len 336 \ 56 | --e_layers 2 \ 57 | --d_layers 1 \ 58 | --enc_in 7 \ 59 | 
--dec_in 7 \ 60 | --c_out 7 \ 61 | --gpu 2 \ 62 | --des 'Exp_h256_l2' \ 63 | --p_hidden_dims 256 256 \ 64 | --p_hidden_layers 2 \ 65 | --itr 1 & 66 | 67 | python -u run.py \ 68 | --is_training 1 \ 69 | --root_path ./dataset/ETT-small/ \ 70 | --data_path ETTh2.csv \ 71 | --model_id ETTh2_96_720 \ 72 | --model ns_Transformer \ 73 | --data ETTh2 \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --label_len 48 \ 77 | --pred_len 720 \ 78 | --e_layers 2 \ 79 | --d_layers 1 \ 80 | --enc_in 7 \ 81 | --dec_in 7 \ 82 | --c_out 7 \ 83 | --gpu 3 \ 84 | --des 'Exp_h256_l2' \ 85 | --p_hidden_dims 256 256 \ 86 | --p_hidden_layers 2 \ 87 | --itr 1 & 88 | -------------------------------------------------------------------------------- /scripts/Exchange_script/Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/exchange_rate/ \ 4 | --data_path exchange_rate.csv \ 5 | --model_id Exchange_96_96 \ 6 | --model Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 8 \ 16 | --dec_in 8 \ 17 | --c_out 8 \ 18 | --gpu 4 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/exchange_rate/ \ 25 | --data_path exchange_rate.csv \ 26 | --model_id Exchange_96_192 \ 27 | --model Autoformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 8 \ 37 | --dec_in 8 \ 38 | --c_out 8 \ 39 | --gpu 5 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/exchange_rate/ \ 46 | --data_path exchange_rate.csv \ 47 | --model_id Exchange_96_336 \ 48 | --model Autoformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 8 \ 58 | --dec_in 8 \ 59 | --c_out 8 \ 60 | --gpu 6 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/exchange_rate/ \ 67 | --data_path exchange_rate.csv \ 68 | --model_id Exchange_96_720 \ 69 | --model Autoformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 8 \ 79 | --dec_in 8 \ 80 | --c_out 8 \ 81 | --gpu 7 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/Exchange_script/Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/exchange_rate/ \ 4 | --data_path exchange_rate.csv \ 5 | --model_id Exchange_96_96 \ 6 | --model Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 8 \ 16 | --dec_in 8 \ 17 | --c_out 8 \ 18 | --gpu 4 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/exchange_rate/ \ 25 | --data_path exchange_rate.csv \ 26 | --model_id Exchange_96_192 \ 27 | --model Informer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 
96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 8 \ 37 | --dec_in 8 \ 38 | --c_out 8 \ 39 | --gpu 5 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/exchange_rate/ \ 46 | --data_path exchange_rate.csv \ 47 | --model_id Exchange_96_336 \ 48 | --model Informer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 8 \ 58 | --dec_in 8 \ 59 | --c_out 8 \ 60 | --gpu 6 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/exchange_rate/ \ 67 | --data_path exchange_rate.csv \ 68 | --model_id Exchange_96_720 \ 69 | --model Informer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 8 \ 79 | --dec_in 8 \ 80 | --c_out 8 \ 81 | --gpu 7 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/Exchange_script/Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/exchange_rate/ \ 4 | --data_path exchange_rate.csv \ 5 | --model_id Exchange_96_96 \ 6 | --model Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 8 \ 16 | --dec_in 8 \ 17 | --c_out 8 \ 18 | --gpu 4 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/exchange_rate/ \ 25 | --data_path exchange_rate.csv \ 26 | --model_id Exchange_96_192 \ 27 | --model Transformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 8 \ 37 | --dec_in 8 \ 38 | --c_out 8 \ 39 | --gpu 5 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/exchange_rate/ \ 46 | --data_path exchange_rate.csv \ 47 | --model_id Exchange_96_336 \ 48 | --model Transformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 8 \ 58 | --dec_in 8 \ 59 | --c_out 8 \ 60 | --gpu 6 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/exchange_rate/ \ 67 | --data_path exchange_rate.csv \ 68 | --model_id Exchange_96_720 \ 69 | --model Transformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 8 \ 79 | --dec_in 8 \ 80 | --c_out 8 \ 81 | --gpu 7 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/Exchange_script/ns_Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/exchange_rate/ \ 4 | --data_path exchange_rate.csv \ 5 | --model_id Exchange_96_96 \ 6 | --model ns_Autoformer \ 7 | --data 
custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 8 \ 16 | --dec_in 8 \ 17 | --c_out 8 \ 18 | --gpu 0 \ 19 | --des 'Exp_h16_l2' \ 20 | --p_hidden_dims 16 16 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/exchange_rate/ \ 27 | --data_path exchange_rate.csv \ 28 | --model_id Exchange_96_192 \ 29 | --model ns_Autoformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 8 \ 39 | --dec_in 8 \ 40 | --c_out 8 \ 41 | --gpu 1 \ 42 | --des 'Exp_h16_l2' \ 43 | --p_hidden_dims 16 16 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/exchange_rate/ \ 50 | --data_path exchange_rate.csv \ 51 | --model_id Exchange_96_336 \ 52 | --model ns_Autoformer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --label_len 48 \ 57 | --pred_len 336 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 8 \ 62 | --dec_in 8 \ 63 | --c_out 8 \ 64 | --gpu 2 \ 65 | --des 'Exp_h64_l1' \ 66 | --p_hidden_dims 64 \ 67 | --p_hidden_layers 1 \ 68 | --itr 1 & 69 | 70 | 71 | python -u run.py \ 72 | --is_training 1 \ 73 | --root_path ./dataset/exchange_rate/ \ 74 | --data_path exchange_rate.csv \ 75 | --model_id Exchange_96_720 \ 76 | --model ns_Autoformer \ 77 | --data custom \ 78 | --features M \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 2 \ 83 | --d_layers 1 \ 84 | --factor 3 \ 85 | --enc_in 8 \ 86 | --dec_in 8 \ 87 | --c_out 8 \ 88 | --gpu 3 \ 89 | --des 'Exp_h64_l2' \ 90 | --p_hidden_dims 64 64 \ 91 | --p_hidden_layers 2 \ 92 | --itr 1 & 93 | -------------------------------------------------------------------------------- /scripts/Exchange_script/ns_Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/exchange_rate/ \ 4 | --data_path exchange_rate.csv \ 5 | --model_id Exchange_96_96 \ 6 | --model ns_Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 8 \ 16 | --dec_in 8 \ 17 | --c_out 8 \ 18 | --gpu 0 \ 19 | --des 'Exp_h16_l2' \ 20 | --p_hidden_dims 16 16 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/exchange_rate/ \ 27 | --data_path exchange_rate.csv \ 28 | --model_id Exchange_96_192 \ 29 | --model ns_Informer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 8 \ 39 | --dec_in 8 \ 40 | --c_out 8 \ 41 | --gpu 1 \ 42 | --des 'Exp_h16_l2' \ 43 | --p_hidden_dims 16 16 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/exchange_rate/ \ 50 | --data_path exchange_rate.csv \ 51 | --model_id Exchange_96_336 \ 52 | --model ns_Informer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --label_len 48 \ 57 | --pred_len 336 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 8 \ 62 | --dec_in 8 \ 63 | --c_out 8 \ 64 | --gpu 2 \ 65 | 
--des 'Exp_h64_l1' \ 66 | --p_hidden_dims 64 \ 67 | --p_hidden_layers 1 \ 68 | --itr 1 & 69 | 70 | 71 | python -u run.py \ 72 | --is_training 1 \ 73 | --root_path ./dataset/exchange_rate/ \ 74 | --data_path exchange_rate.csv \ 75 | --model_id Exchange_96_720 \ 76 | --model ns_Informer \ 77 | --data custom \ 78 | --features M \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 2 \ 83 | --d_layers 1 \ 84 | --factor 3 \ 85 | --enc_in 8 \ 86 | --dec_in 8 \ 87 | --c_out 8 \ 88 | --gpu 3 \ 89 | --des 'Exp_h64_l2' \ 90 | --p_hidden_dims 64 64 \ 91 | --p_hidden_layers 2 \ 92 | --itr 1 & 93 | -------------------------------------------------------------------------------- /scripts/Exchange_script/ns_Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/exchange_rate/ \ 4 | --data_path exchange_rate.csv \ 5 | --model_id Exchange_96_96 \ 6 | --model ns_Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 8 \ 16 | --dec_in 8 \ 17 | --c_out 8 \ 18 | --gpu 0 \ 19 | --des 'Exp_h16_l2' \ 20 | --p_hidden_dims 16 16 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/exchange_rate/ \ 27 | --data_path exchange_rate.csv \ 28 | --model_id Exchange_96_192 \ 29 | --model ns_Transformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 8 \ 39 | --dec_in 8 \ 40 | --c_out 8 \ 41 | --gpu 1 \ 42 | --des 'Exp_h16_l2' \ 43 | --p_hidden_dims 16 16 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/exchange_rate/ \ 50 | --data_path exchange_rate.csv \ 51 | --model_id Exchange_96_336 \ 52 | --model ns_Transformer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --label_len 48 \ 57 | --pred_len 336 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 8 \ 62 | --dec_in 8 \ 63 | --c_out 8 \ 64 | --gpu 2 \ 65 | --des 'Exp_h64_l1' \ 66 | --p_hidden_dims 64 \ 67 | --p_hidden_layers 1 \ 68 | --itr 1 & 69 | 70 | 71 | python -u run.py \ 72 | --is_training 1 \ 73 | --root_path ./dataset/exchange_rate/ \ 74 | --data_path exchange_rate.csv \ 75 | --model_id Exchange_96_720 \ 76 | --model ns_Transformer \ 77 | --data custom \ 78 | --features M \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 2 \ 83 | --d_layers 1 \ 84 | --factor 3 \ 85 | --enc_in 8 \ 86 | --dec_in 8 \ 87 | --c_out 8 \ 88 | --gpu 3 \ 89 | --des 'Exp_h64_l2' \ 90 | --p_hidden_dims 64 64 \ 91 | --p_hidden_layers 2 \ 92 | --itr 1 & 93 | -------------------------------------------------------------------------------- /scripts/ILI_script/Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/illness/ \ 4 | --data_path national_illness.csv \ 5 | --model_id ili_36_24 \ 6 | --model Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 36 \ 10 | --label_len 18 \ 11 | --pred_len 24 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 7 \ 16 | --dec_in 7 \ 17 | --c_out 7 \ 18 | --gpu 4 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 
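# Note: the illness (ILI) runs in this folder use much shorter windows (seq_len 36, label_len 18,
# pred_len 24/36/48/60) than the other datasets, matching the weekly and comparatively short
# national_illness series used in the long-term forecasting benchmark.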
22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/illness/ \ 25 | --data_path national_illness.csv \ 26 | --model_id ili_36_36 \ 27 | --model Autoformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 36 \ 31 | --label_len 18 \ 32 | --pred_len 36 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 7 \ 37 | --dec_in 7 \ 38 | --c_out 7 \ 39 | --gpu 5 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/illness/ \ 46 | --data_path national_illness.csv \ 47 | --model_id ili_36_48 \ 48 | --model Autoformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 36 \ 52 | --label_len 18 \ 53 | --pred_len 48 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 7 \ 58 | --dec_in 7 \ 59 | --c_out 7 \ 60 | --gpu 6 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/illness/ \ 67 | --data_path national_illness.csv \ 68 | --model_id ili_36_60 \ 69 | --model Autoformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 36 \ 73 | --label_len 18 \ 74 | --pred_len 60 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 7 \ 79 | --dec_in 7 \ 80 | --c_out 7 \ 81 | --gpu 7 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/ILI_script/Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/illness/ \ 4 | --data_path national_illness.csv \ 5 | --model_id ili_36_24 \ 6 | --model Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 36 \ 10 | --label_len 18 \ 11 | --pred_len 24 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 7 \ 16 | --dec_in 7 \ 17 | --c_out 7 \ 18 | --gpu 4 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/illness/ \ 25 | --data_path national_illness.csv \ 26 | --model_id ili_36_36 \ 27 | --model Informer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 36 \ 31 | --label_len 18 \ 32 | --pred_len 36 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 7 \ 37 | --dec_in 7 \ 38 | --c_out 7 \ 39 | --gpu 5 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/illness/ \ 46 | --data_path national_illness.csv \ 47 | --model_id ili_36_48 \ 48 | --model Informer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 36 \ 52 | --label_len 18 \ 53 | --pred_len 48 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 7 \ 58 | --dec_in 7 \ 59 | --c_out 7 \ 60 | --gpu 6 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/illness/ \ 67 | --data_path national_illness.csv \ 68 | --model_id ili_36_60 \ 69 | --model Informer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 36 \ 73 | --label_len 18 \ 74 | --pred_len 60 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 7 \ 79 | --dec_in 7 \ 80 | --c_out 7 \ 81 | --gpu 7 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/ILI_script/Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 
\ 3 | --root_path ./dataset/illness/ \ 4 | --data_path national_illness.csv \ 5 | --model_id ili_36_24 \ 6 | --model Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 36 \ 10 | --label_len 18 \ 11 | --pred_len 24 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 7 \ 16 | --dec_in 7 \ 17 | --c_out 7 \ 18 | --gpu 4 \ 19 | --des 'Exp' \ 20 | --itr 1 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/illness/ \ 25 | --data_path national_illness.csv \ 26 | --model_id ili_36_36 \ 27 | --model Transformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 36 \ 31 | --label_len 18 \ 32 | --pred_len 36 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 7 \ 37 | --dec_in 7 \ 38 | --c_out 7 \ 39 | --gpu 5 \ 40 | --des 'Exp' \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/illness/ \ 46 | --data_path national_illness.csv \ 47 | --model_id ili_36_48 \ 48 | --model Transformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 36 \ 52 | --label_len 18 \ 53 | --pred_len 48 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 7 \ 58 | --dec_in 7 \ 59 | --c_out 7 \ 60 | --gpu 6 \ 61 | --des 'Exp' \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/illness/ \ 67 | --data_path national_illness.csv \ 68 | --model_id ili_36_60 \ 69 | --model Transformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 36 \ 73 | --label_len 18 \ 74 | --pred_len 60 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 7 \ 79 | --dec_in 7 \ 80 | --c_out 7 \ 81 | --gpu 7 \ 82 | --des 'Exp' \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/ILI_script/ns_Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/illness/ \ 4 | --data_path national_illness.csv \ 5 | --model_id ili_36_24 \ 6 | --model ns_Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 36 \ 10 | --label_len 18 \ 11 | --pred_len 24 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 7 \ 16 | --dec_in 7 \ 17 | --c_out 7 \ 18 | --gpu 0 \ 19 | --des 'Exp_h32_l2' \ 20 | --p_hidden_dims 32 32 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/illness/ \ 27 | --data_path national_illness.csv \ 28 | --model_id ili_36_36 \ 29 | --model ns_Autoformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 36 \ 33 | --label_len 18 \ 34 | --pred_len 36 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 7 \ 39 | --dec_in 7 \ 40 | --c_out 7 \ 41 | --gpu 1 \ 42 | --des 'Exp_h32_l2' \ 43 | --p_hidden_dims 32 32 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/illness/ \ 50 | --data_path national_illness.csv \ 51 | --model_id ili_36_48 \ 52 | --model ns_Autoformer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 36 \ 56 | --label_len 18 \ 57 | --pred_len 48 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 7 \ 62 | --dec_in 7 \ 63 | --c_out 7 \ 64 | --gpu 2 \ 65 | --des 'Exp_h16_l2' \ 66 | --p_hidden_dims 16 16 \ 67 | --p_hidden_layers 2 \ 68 | --itr 1 & 69 | 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path 
./dataset/illness/ \ 73 | --data_path national_illness.csv \ 74 | --model_id ili_36_60 \ 75 | --model ns_Autoformer \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 36 \ 79 | --label_len 18 \ 80 | --pred_len 60 \ 81 | --e_layers 2 \ 82 | --d_layers 1 \ 83 | --factor 3 \ 84 | --enc_in 7 \ 85 | --dec_in 7 \ 86 | --c_out 7 \ 87 | --gpu 3 \ 88 | --des 'Exp_h8_l2' \ 89 | --p_hidden_dims 8 8 \ 90 | --p_hidden_layers 2 \ 91 | --itr 1 & 92 | -------------------------------------------------------------------------------- /scripts/ILI_script/ns_Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/illness/ \ 4 | --data_path national_illness.csv \ 5 | --model_id ili_36_24 \ 6 | --model ns_Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 36 \ 10 | --label_len 18 \ 11 | --pred_len 24 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 7 \ 16 | --dec_in 7 \ 17 | --c_out 7 \ 18 | --gpu 0 \ 19 | --des 'Exp_h32_l2' \ 20 | --p_hidden_dims 32 32 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/illness/ \ 27 | --data_path national_illness.csv \ 28 | --model_id ili_36_36 \ 29 | --model ns_Informer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 36 \ 33 | --label_len 18 \ 34 | --pred_len 36 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 7 \ 39 | --dec_in 7 \ 40 | --c_out 7 \ 41 | --gpu 1 \ 42 | --des 'Exp_h32_l2' \ 43 | --p_hidden_dims 32 32 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/illness/ \ 50 | --data_path national_illness.csv \ 51 | --model_id ili_36_48 \ 52 | --model ns_Informer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 36 \ 56 | --label_len 18 \ 57 | --pred_len 48 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 7 \ 62 | --dec_in 7 \ 63 | --c_out 7 \ 64 | --gpu 2 \ 65 | --des 'Exp_h16_l2' \ 66 | --p_hidden_dims 16 16 \ 67 | --p_hidden_layers 2 \ 68 | --itr 1 & 69 | 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path ./dataset/illness/ \ 73 | --data_path national_illness.csv \ 74 | --model_id ili_36_60 \ 75 | --model ns_Informer \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 36 \ 79 | --label_len 18 \ 80 | --pred_len 60 \ 81 | --e_layers 2 \ 82 | --d_layers 1 \ 83 | --factor 3 \ 84 | --enc_in 7 \ 85 | --dec_in 7 \ 86 | --c_out 7 \ 87 | --gpu 3 \ 88 | --des 'Exp_h8_l2' \ 89 | --p_hidden_dims 8 8 \ 90 | --p_hidden_layers 2 \ 91 | --itr 1 & 92 | -------------------------------------------------------------------------------- /scripts/ILI_script/ns_Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/illness/ \ 4 | --data_path national_illness.csv \ 5 | --model_id ili_36_24 \ 6 | --model ns_Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 36 \ 10 | --label_len 18 \ 11 | --pred_len 24 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 7 \ 16 | --dec_in 7 \ 17 | --c_out 7 \ 18 | --gpu 0 \ 19 | --des 'Exp_h32_l2' \ 20 | --p_hidden_dims 32 32 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/illness/ \ 27 | --data_path national_illness.csv \ 28 | --model_id ili_36_36 \ 29 | --model 
ns_Transformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 36 \ 33 | --label_len 18 \ 34 | --pred_len 36 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 7 \ 39 | --dec_in 7 \ 40 | --c_out 7 \ 41 | --gpu 1 \ 42 | --des 'Exp_h32_l2' \ 43 | --p_hidden_dims 32 32 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/illness/ \ 50 | --data_path national_illness.csv \ 51 | --model_id ili_36_48 \ 52 | --model ns_Transformer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 36 \ 56 | --label_len 18 \ 57 | --pred_len 48 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 7 \ 62 | --dec_in 7 \ 63 | --c_out 7 \ 64 | --gpu 2 \ 65 | --des 'Exp_h16_l2' \ 66 | --p_hidden_dims 16 16 \ 67 | --p_hidden_layers 2 \ 68 | --itr 1 & 69 | 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path ./dataset/illness/ \ 73 | --data_path national_illness.csv \ 74 | --model_id ili_36_60 \ 75 | --model ns_Transformer \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 36 \ 79 | --label_len 18 \ 80 | --pred_len 60 \ 81 | --e_layers 2 \ 82 | --d_layers 1 \ 83 | --factor 3 \ 84 | --enc_in 7 \ 85 | --dec_in 7 \ 86 | --c_out 7 \ 87 | --gpu 3 \ 88 | --des 'Exp_h8_l2' \ 89 | --p_hidden_dims 8 8 \ 90 | --p_hidden_layers 2 \ 91 | --itr 1 & 92 | -------------------------------------------------------------------------------- /scripts/Traffic_script/Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/traffic/ \ 4 | --data_path traffic.csv \ 5 | --model_id traffic_96_96 \ 6 | --model Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 862 \ 16 | --dec_in 862 \ 17 | --c_out 862 \ 18 | --des 'Exp' \ 19 | --itr 1 \ 20 | --gpu 4 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/traffic/ \ 25 | --data_path traffic.csv \ 26 | --model_id traffic_96_192 \ 27 | --model Autoformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 862 \ 37 | --dec_in 862 \ 38 | --c_out 862 \ 39 | --des 'Exp' \ 40 | --itr 1 \ 41 | --gpu 5 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/traffic/ \ 46 | --data_path traffic.csv \ 47 | --model_id traffic_96_336 \ 48 | --model Autoformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 862 \ 58 | --dec_in 862 \ 59 | --c_out 862 \ 60 | --des 'Exp' \ 61 | --itr 1 \ 62 | --gpu 6 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/traffic/ \ 67 | --data_path traffic.csv \ 68 | --model_id traffic_96_720 \ 69 | --model Autoformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 862 \ 79 | --dec_in 862 \ 80 | --c_out 862 \ 81 | --des 'Exp' \ 82 | --itr 1 \ 83 | --gpu 7 & 84 | -------------------------------------------------------------------------------- /scripts/Traffic_script/Informer.sh: 
-------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/traffic/ \ 4 | --data_path traffic.csv \ 5 | --model_id traffic_96_96 \ 6 | --model Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 862 \ 16 | --dec_in 862 \ 17 | --c_out 862 \ 18 | --des 'Exp' \ 19 | --itr 1 \ 20 | --gpu 4 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/traffic/ \ 25 | --data_path traffic.csv \ 26 | --model_id traffic_96_192 \ 27 | --model Informer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 862 \ 37 | --dec_in 862 \ 38 | --c_out 862 \ 39 | --des 'Exp' \ 40 | --itr 1 \ 41 | --gpu 5 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/traffic/ \ 46 | --data_path traffic.csv \ 47 | --model_id traffic_96_336 \ 48 | --model Informer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 862 \ 58 | --dec_in 862 \ 59 | --c_out 862 \ 60 | --des 'Exp' \ 61 | --itr 1 \ 62 | --gpu 6 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/traffic/ \ 67 | --data_path traffic.csv \ 68 | --model_id traffic_96_720 \ 69 | --model Informer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 862 \ 79 | --dec_in 862 \ 80 | --c_out 862 \ 81 | --des 'Exp' \ 82 | --itr 1 \ 83 | --gpu 7 & 84 | -------------------------------------------------------------------------------- /scripts/Traffic_script/Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/traffic/ \ 4 | --data_path traffic.csv \ 5 | --model_id traffic_96_96 \ 6 | --model Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 862 \ 16 | --dec_in 862 \ 17 | --c_out 862 \ 18 | --des 'Exp' \ 19 | --itr 1 \ 20 | --gpu 4 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/traffic/ \ 25 | --data_path traffic.csv \ 26 | --model_id traffic_96_192 \ 27 | --model Transformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 862 \ 37 | --dec_in 862 \ 38 | --c_out 862 \ 39 | --des 'Exp' \ 40 | --itr 1 \ 41 | --gpu 5 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/traffic/ \ 46 | --data_path traffic.csv \ 47 | --model_id traffic_96_336 \ 48 | --model Transformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 862 \ 58 | --dec_in 862 \ 59 | --c_out 862 \ 60 | --des 'Exp' \ 61 | --itr 1 \ 62 | --gpu 6 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/traffic/ \ 67 | --data_path traffic.csv \ 68 | --model_id 
traffic_96_720 \ 69 | --model Transformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 862 \ 79 | --dec_in 862 \ 80 | --c_out 862 \ 81 | --des 'Exp' \ 82 | --itr 1 \ 83 | --gpu 7 & 84 | -------------------------------------------------------------------------------- /scripts/Traffic_script/ns_Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/traffic/ \ 4 | --data_path traffic.csv \ 5 | --model_id traffic_96_96 \ 6 | --model ns_Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 862 \ 16 | --dec_in 862 \ 17 | --c_out 862 \ 18 | --gpu 0 \ 19 | --des 'Exp_h128_l2' \ 20 | --p_hidden_dims 128 128 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/traffic/ \ 27 | --data_path traffic.csv \ 28 | --model_id traffic_96_192 \ 29 | --model ns_Autoformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 862 \ 39 | --dec_in 862 \ 40 | --c_out 862 \ 41 | --gpu 1 \ 42 | --des 'Exp_h128_l2' \ 43 | --p_hidden_dims 128 128 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/traffic/ \ 51 | --data_path traffic.csv \ 52 | --model_id traffic_96_336 \ 53 | --model ns_Autoformer \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 96 \ 57 | --label_len 48 \ 58 | --pred_len 336 \ 59 | --e_layers 2 \ 60 | --d_layers 1 \ 61 | --factor 3 \ 62 | --enc_in 862 \ 63 | --dec_in 862 \ 64 | --c_out 862 \ 65 | --gpu 2 \ 66 | --des 'Exp_h256_l2' \ 67 | --p_hidden_dims 256 256 \ 68 | --p_hidden_layers 2 \ 69 | --itr 1 & 70 | 71 | python -u run.py \ 72 | --is_training 1 \ 73 | --root_path ./dataset/traffic/ \ 74 | --data_path traffic.csv \ 75 | --model_id traffic_96_720 \ 76 | --model ns_Autoformer \ 77 | --data custom \ 78 | --features M \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 2 \ 83 | --d_layers 1 \ 84 | --factor 3 \ 85 | --enc_in 862 \ 86 | --dec_in 862 \ 87 | --c_out 862 \ 88 | --gpu 3 \ 89 | --des 'Exp_h128_l2' \ 90 | --p_hidden_dims 128 128 \ 91 | --p_hidden_layers 2 \ 92 | --itr 1 & 93 | -------------------------------------------------------------------------------- /scripts/Traffic_script/ns_Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/traffic/ \ 4 | --data_path traffic.csv \ 5 | --model_id traffic_96_96 \ 6 | --model ns_Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 862 \ 16 | --dec_in 862 \ 17 | --c_out 862 \ 18 | --gpu 0 \ 19 | --des 'Exp_h128_l2' \ 20 | --p_hidden_dims 128 128 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/traffic/ \ 27 | --data_path traffic.csv \ 28 | --model_id traffic_96_192 \ 29 | --model ns_Informer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | 
--label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 862 \ 39 | --dec_in 862 \ 40 | --c_out 862 \ 41 | --gpu 1 \ 42 | --des 'Exp_h128_l2' \ 43 | --p_hidden_dims 128 128 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/traffic/ \ 51 | --data_path traffic.csv \ 52 | --model_id traffic_96_336 \ 53 | --model ns_Informer \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 96 \ 57 | --label_len 48 \ 58 | --pred_len 336 \ 59 | --e_layers 2 \ 60 | --d_layers 1 \ 61 | --factor 3 \ 62 | --enc_in 862 \ 63 | --dec_in 862 \ 64 | --c_out 862 \ 65 | --gpu 2 \ 66 | --des 'Exp_h256_l2' \ 67 | --p_hidden_dims 256 256 \ 68 | --p_hidden_layers 2 \ 69 | --itr 1 & 70 | 71 | python -u run.py \ 72 | --is_training 1 \ 73 | --root_path ./dataset/traffic/ \ 74 | --data_path traffic.csv \ 75 | --model_id traffic_96_720 \ 76 | --model ns_Informer \ 77 | --data custom \ 78 | --features M \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 2 \ 83 | --d_layers 1 \ 84 | --factor 3 \ 85 | --enc_in 862 \ 86 | --dec_in 862 \ 87 | --c_out 862 \ 88 | --gpu 3 \ 89 | --des 'Exp_h128_l2' \ 90 | --p_hidden_dims 128 128 \ 91 | --p_hidden_layers 2 \ 92 | --itr 1 & 93 | -------------------------------------------------------------------------------- /scripts/Traffic_script/ns_Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/traffic/ \ 4 | --data_path traffic.csv \ 5 | --model_id traffic_96_96 \ 6 | --model ns_Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 862 \ 16 | --dec_in 862 \ 17 | --c_out 862 \ 18 | --gpu 0 \ 19 | --des 'Exp_h128_l2' \ 20 | --p_hidden_dims 128 128 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/traffic/ \ 27 | --data_path traffic.csv \ 28 | --model_id traffic_96_192 \ 29 | --model ns_Transformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 862 \ 39 | --dec_in 862 \ 40 | --c_out 862 \ 41 | --gpu 1 \ 42 | --des 'Exp_h128_l2' \ 43 | --p_hidden_dims 128 128 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/traffic/ \ 51 | --data_path traffic.csv \ 52 | --model_id traffic_96_336 \ 53 | --model ns_Transformer \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 96 \ 57 | --label_len 48 \ 58 | --pred_len 336 \ 59 | --e_layers 2 \ 60 | --d_layers 1 \ 61 | --factor 3 \ 62 | --enc_in 862 \ 63 | --dec_in 862 \ 64 | --c_out 862 \ 65 | --gpu 2 \ 66 | --des 'Exp_h256_l2' \ 67 | --p_hidden_dims 256 256 \ 68 | --p_hidden_layers 2 \ 69 | --itr 1 & 70 | 71 | python -u run.py \ 72 | --is_training 1 \ 73 | --root_path ./dataset/traffic/ \ 74 | --data_path traffic.csv \ 75 | --model_id traffic_96_720 \ 76 | --model ns_Transformer \ 77 | --data custom \ 78 | --features M \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 2 \ 83 | --d_layers 1 \ 84 | --factor 3 \ 85 | --enc_in 862 \ 86 | --dec_in 862 \ 87 | --c_out 862 \ 88 | --gpu 3 \ 89 | --des 'Exp_h128_l2' \ 90 | --p_hidden_dims 128 128 \ 
91 | --p_hidden_layers 2 \ 92 | --itr 1 & 93 | -------------------------------------------------------------------------------- /scripts/Weather_script/Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/weather/ \ 4 | --data_path weather.csv \ 5 | --model_id weather_96_96 \ 6 | --model Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 21 \ 16 | --dec_in 21 \ 17 | --c_out 21 \ 18 | --des 'Exp' \ 19 | --itr 1 \ 20 | --gpu 4 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/weather/ \ 25 | --data_path weather.csv \ 26 | --model_id weather_96_192 \ 27 | --model Autoformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --gpu 5 \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/weather/ \ 46 | --data_path weather.csv \ 47 | --model_id weather_96_336 \ 48 | --model Autoformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 21 \ 58 | --dec_in 21 \ 59 | --c_out 21 \ 60 | --des 'Exp' \ 61 | --gpu 6 \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/weather/ \ 67 | --data_path weather.csv \ 68 | --model_id weather_96_720 \ 69 | --model Autoformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 21 \ 79 | --dec_in 21 \ 80 | --c_out 21 \ 81 | --des 'Exp' \ 82 | --gpu 7 \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/Weather_script/Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/weather/ \ 4 | --data_path weather.csv \ 5 | --model_id weather_96_96 \ 6 | --model Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 21 \ 16 | --dec_in 21 \ 17 | --c_out 21 \ 18 | --des 'Exp' \ 19 | --itr 1 \ 20 | --gpu 4 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/weather/ \ 25 | --data_path weather.csv \ 26 | --model_id weather_96_192 \ 27 | --model Informer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --gpu 5 \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/weather/ \ 46 | --data_path weather.csv \ 47 | --model_id weather_96_336 \ 48 | --model Informer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 21 \ 58 | --dec_in 21 \ 59 | --c_out 21 \ 60 | --des 'Exp' \ 61 | --gpu 6 \ 62 | --itr 1 & 63 | 
64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/weather/ \ 67 | --data_path weather.csv \ 68 | --model_id weather_96_720 \ 69 | --model Informer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 21 \ 79 | --dec_in 21 \ 80 | --c_out 21 \ 81 | --des 'Exp' \ 82 | --gpu 7 \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/Weather_script/Transformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/weather/ \ 4 | --data_path weather.csv \ 5 | --model_id weather_96_96 \ 6 | --model Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 21 \ 16 | --dec_in 21 \ 17 | --c_out 21 \ 18 | --des 'Exp' \ 19 | --itr 1 \ 20 | --gpu 4 & 21 | 22 | python -u run.py \ 23 | --is_training 1 \ 24 | --root_path ./dataset/weather/ \ 25 | --data_path weather.csv \ 26 | --model_id weather_96_192 \ 27 | --model Transformer \ 28 | --data custom \ 29 | --features M \ 30 | --seq_len 96 \ 31 | --label_len 48 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --d_layers 1 \ 35 | --factor 3 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --gpu 5 \ 41 | --itr 1 & 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/weather/ \ 46 | --data_path weather.csv \ 47 | --model_id weather_96_336 \ 48 | --model Transformer \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --label_len 48 \ 53 | --pred_len 336 \ 54 | --e_layers 2 \ 55 | --d_layers 1 \ 56 | --factor 3 \ 57 | --enc_in 21 \ 58 | --dec_in 21 \ 59 | --c_out 21 \ 60 | --des 'Exp' \ 61 | --gpu 6 \ 62 | --itr 1 & 63 | 64 | python -u run.py \ 65 | --is_training 1 \ 66 | --root_path ./dataset/weather/ \ 67 | --data_path weather.csv \ 68 | --model_id weather_96_720 \ 69 | --model Transformer \ 70 | --data custom \ 71 | --features M \ 72 | --seq_len 96 \ 73 | --label_len 48 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --d_layers 1 \ 77 | --factor 3 \ 78 | --enc_in 21 \ 79 | --dec_in 21 \ 80 | --c_out 21 \ 81 | --des 'Exp' \ 82 | --gpu 7 \ 83 | --itr 1 & 84 | -------------------------------------------------------------------------------- /scripts/Weather_script/ns_Autoformer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/weather/ \ 4 | --data_path weather.csv \ 5 | --model_id weather_96_96 \ 6 | --model ns_Autoformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 21 \ 16 | --dec_in 21 \ 17 | --c_out 21 \ 18 | --gpu 0 \ 19 | --des 'Exp_h256_l2' \ 20 | --p_hidden_dims 256 256 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/weather/ \ 27 | --data_path weather.csv \ 28 | --model_id weather_96_192 \ 29 | --model ns_Autoformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 21 \ 39 | --dec_in 21 \ 40 | --c_out 21 \ 41 | --gpu 1 \ 
42 | --des 'Exp_h128_l2' \ 43 | --p_hidden_dims 128 128 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/weather/ \ 50 | --data_path weather.csv \ 51 | --model_id weather_96_336 \ 52 | --model ns_Autoformer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --label_len 48 \ 57 | --pred_len 336 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 21 \ 62 | --dec_in 21 \ 63 | --c_out 21 \ 64 | --gpu 2 \ 65 | --des 'Exp_h128_l2' \ 66 | --p_hidden_dims 128 128 \ 67 | --p_hidden_layers 2 \ 68 | --itr 1 & 69 | 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path ./dataset/weather/ \ 73 | --data_path weather.csv \ 74 | --model_id weather_96_720 \ 75 | --model ns_Autoformer \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 96 \ 79 | --label_len 48 \ 80 | --pred_len 720 \ 81 | --e_layers 2 \ 82 | --d_layers 1 \ 83 | --factor 3 \ 84 | --enc_in 21 \ 85 | --dec_in 21 \ 86 | --c_out 21 \ 87 | --gpu 3 \ 88 | --des 'Exp_h128_l2' \ 89 | --p_hidden_dims 128 128 \ 90 | --p_hidden_layers 2 \ 91 | --itr 1 & 92 | -------------------------------------------------------------------------------- /scripts/Weather_script/ns_Informer.sh: -------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/weather/ \ 4 | --data_path weather.csv \ 5 | --model_id weather_96_96 \ 6 | --model ns_Informer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 21 \ 16 | --dec_in 21 \ 17 | --c_out 21 \ 18 | --gpu 0 \ 19 | --des 'Exp_h256_l2' \ 20 | --p_hidden_dims 256 256 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/weather/ \ 27 | --data_path weather.csv \ 28 | --model_id weather_96_192 \ 29 | --model ns_Informer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 21 \ 39 | --dec_in 21 \ 40 | --c_out 21 \ 41 | --gpu 1 \ 42 | --des 'Exp_h128_l2' \ 43 | --p_hidden_dims 128 128 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/weather/ \ 50 | --data_path weather.csv \ 51 | --model_id weather_96_336 \ 52 | --model ns_Informer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --label_len 48 \ 57 | --pred_len 336 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 21 \ 62 | --dec_in 21 \ 63 | --c_out 21 \ 64 | --gpu 2 \ 65 | --des 'Exp_h128_l2' \ 66 | --p_hidden_dims 128 128 \ 67 | --p_hidden_layers 2 \ 68 | --itr 1 & 69 | 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path ./dataset/weather/ \ 73 | --data_path weather.csv \ 74 | --model_id weather_96_720 \ 75 | --model ns_Informer \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 96 \ 79 | --label_len 48 \ 80 | --pred_len 720 \ 81 | --e_layers 2 \ 82 | --d_layers 1 \ 83 | --factor 3 \ 84 | --enc_in 21 \ 85 | --dec_in 21 \ 86 | --c_out 21 \ 87 | --gpu 3 \ 88 | --des 'Exp_h128_l2' \ 89 | --p_hidden_dims 128 128 \ 90 | --p_hidden_layers 2 \ 91 | --itr 1 & 92 | -------------------------------------------------------------------------------- /scripts/Weather_script/ns_Transformer.sh: 
-------------------------------------------------------------------------------- 1 | python -u run.py \ 2 | --is_training 1 \ 3 | --root_path ./dataset/weather/ \ 4 | --data_path weather.csv \ 5 | --model_id weather_96_96 \ 6 | --model ns_Transformer \ 7 | --data custom \ 8 | --features M \ 9 | --seq_len 96 \ 10 | --label_len 48 \ 11 | --pred_len 96 \ 12 | --e_layers 2 \ 13 | --d_layers 1 \ 14 | --factor 3 \ 15 | --enc_in 21 \ 16 | --dec_in 21 \ 17 | --c_out 21 \ 18 | --gpu 0 \ 19 | --des 'Exp_h256_l2' \ 20 | --p_hidden_dims 256 256 \ 21 | --p_hidden_layers 2 \ 22 | --itr 1 & 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/weather/ \ 27 | --data_path weather.csv \ 28 | --model_id weather_96_192 \ 29 | --model ns_Transformer \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --label_len 48 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --d_layers 1 \ 37 | --factor 3 \ 38 | --enc_in 21 \ 39 | --dec_in 21 \ 40 | --c_out 21 \ 41 | --gpu 1 \ 42 | --des 'Exp_h128_l2' \ 43 | --p_hidden_dims 128 128 \ 44 | --p_hidden_layers 2 \ 45 | --itr 1 & 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/weather/ \ 50 | --data_path weather.csv \ 51 | --model_id weather_96_336 \ 52 | --model ns_Transformer \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --label_len 48 \ 57 | --pred_len 336 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 21 \ 62 | --dec_in 21 \ 63 | --c_out 21 \ 64 | --gpu 2 \ 65 | --des 'Exp_h128_l2' \ 66 | --p_hidden_dims 128 128 \ 67 | --p_hidden_layers 2 \ 68 | --itr 1 & 69 | 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path ./dataset/weather/ \ 73 | --data_path weather.csv \ 74 | --model_id weather_96_720 \ 75 | --model ns_Transformer \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 96 \ 79 | --label_len 48 \ 80 | --pred_len 720 \ 81 | --e_layers 2 \ 82 | --d_layers 1 \ 83 | --factor 3 \ 84 | --enc_in 21 \ 85 | --dec_in 21 \ 86 | --c_out 21 \ 87 | --gpu 3 \ 88 | --des 'Exp_h128_l2' \ 89 | --p_hidden_dims 128 128 \ 90 | --p_hidden_layers 2 \ 91 | --itr 1 & 92 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/Nonstationary_Transformers/c4ec40675d11d50b3d9923657f408d0db6f90f56/utils/__init__.py -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /utils/metrics.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(pred - true)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((pred - true) ** 2) 20 | 21 | 22 | def RMSE(pred, true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | 26 | def MAPE(pred, true): 27 | return np.mean(np.abs((pred - true) / true)) 28 | 29 | 30 | def MSPE(pred, true): 31 | return np.mean(np.square((pred - true) / true)) 32 | 33 | 34 | def metric(pred, true): 35 | mae = MAE(pred, true) 36 | mse = MSE(pred, true) 37 | rmse = RMSE(pred, true) 38 | mape = MAPE(pred, true) 39 | mspe = MSPE(pred, true) 40 | 41 | return mae, mse, rmse, mape, mspe 42 | -------------------------------------------------------------------------------- /utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from pandas.tseries import offsets 6 | from pandas.tseries.frequencies import to_offset 7 | 8 | 9 | class TimeFeature: 10 | def __init__(self): 11 | pass 12 | 13 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 14 | pass 15 | 16 | def __repr__(self): 17 | return self.__class__.__name__ + "()" 18 | 19 | 20 | class SecondOfMinute(TimeFeature): 21 | """Second of minute encoded as value between [-0.5, 0.5]""" 22 | 23 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 24 | return index.second / 59.0 - 0.5 25 | 26 | 27 | class MinuteOfHour(TimeFeature): 28 | """Minute of hour encoded as value between [-0.5, 0.5]""" 29 | 30 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 31 | return index.minute / 59.0 - 0.5 32 | 33 | 34 | class HourOfDay(TimeFeature): 35 | """Hour of day encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.hour / 23.0 - 0.5 39 | 40 | 41 | class DayOfWeek(TimeFeature): 42 | """Day of week encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.dayofweek / 6.0 - 0.5 46 | 47 | 48 | class DayOfMonth(TimeFeature): 49 | """Day of month encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return (index.day - 1) / 30.0 - 0.5 53 | 54 | 55 | class DayOfYear(TimeFeature): 56 | """Day of year encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return (index.dayofyear - 1) / 365.0 - 0.5 60 | 61 | 62 | class MonthOfYear(TimeFeature): 63 | """Month of year encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.month - 1) / 11.0 - 0.5 67 | 68 | 69 | class WeekOfYear(TimeFeature): 70 | """Week of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.isocalendar().week - 1) / 52.0 - 0.5 74 | 75 | 76 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 77 | """ 78 | Returns a list of time features that will be
appropriate for the given frequency string. 79 | Parameters 80 | ---------- 81 | freq_str 82 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 83 | """ 84 | 85 | features_by_offsets = { 86 | offsets.YearEnd: [], 87 | offsets.QuarterEnd: [MonthOfYear], 88 | offsets.MonthEnd: [MonthOfYear], 89 | offsets.Week: [DayOfMonth, WeekOfYear], 90 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 91 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 92 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 93 | offsets.Minute: [ 94 | MinuteOfHour, 95 | HourOfDay, 96 | DayOfWeek, 97 | DayOfMonth, 98 | DayOfYear, 99 | ], 100 | offsets.Second: [ 101 | SecondOfMinute, 102 | MinuteOfHour, 103 | HourOfDay, 104 | DayOfWeek, 105 | DayOfMonth, 106 | DayOfYear, 107 | ], 108 | } 109 | 110 | offset = to_offset(freq_str) 111 | 112 | for offset_type, feature_classes in features_by_offsets.items(): 113 | if isinstance(offset, offset_type): 114 | return [cls() for cls in feature_classes] 115 | 116 | supported_freq_msg = f""" 117 | Unsupported frequency {freq_str} 118 | The following frequencies are supported: 119 | Y - yearly 120 | alias: A 121 | M - monthly 122 | W - weekly 123 | D - daily 124 | B - business days 125 | H - hourly 126 | T - minutely 127 | alias: min 128 | S - secondly 129 | """ 130 | raise RuntimeError(supported_freq_msg) 131 | 132 | 133 | def time_features(dates, freq='h'): 134 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 135 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | plt.switch_backend('agg') 6 | 7 | 8 | def adjust_learning_rate(optimizer, epoch, args): 9 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 10 | if args.lradj == 'type1': 11 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 12 | elif args.lradj == 'type2': 13 | lr_adjust = { 14 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 15 | 10: 5e-7, 15: 1e-7, 20: 5e-8 16 | } 17 | if epoch in lr_adjust.keys(): 18 | lr = lr_adjust[epoch] 19 | for param_group in optimizer.param_groups: 20 | param_group['lr'] = lr 21 | print('Updating learning rate to {}'.format(lr)) 22 | 23 | 24 | class EarlyStopping: 25 | def __init__(self, patience=7, verbose=False, delta=0): 26 | self.patience = patience 27 | self.verbose = verbose 28 | self.counter = 0 29 | self.best_score = None 30 | self.early_stop = False 31 | self.val_loss_min = np.Inf 32 | self.delta = delta 33 | 34 | def __call__(self, val_loss, model, path): 35 | score = -val_loss 36 | if self.best_score is None: 37 | self.best_score = score 38 | self.save_checkpoint(val_loss, model, path) 39 | elif score < self.best_score + self.delta: 40 | self.counter += 1 41 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 42 | if self.counter >= self.patience: 43 | self.early_stop = True 44 | else: 45 | self.best_score = score 46 | self.save_checkpoint(val_loss, model, path) 47 | self.counter = 0 48 | 49 | def save_checkpoint(self, val_loss, model, path): 50 | if self.verbose: 51 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 52 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 53 | self.val_loss_min = val_loss 54 | 55 | 56 | class dotdict(dict): 57 | """dot.notation access to dictionary attributes""" 58 | __getattr__ = dict.get 59 | __setattr__ = dict.__setitem__ 60 | __delattr__ = dict.__delitem__ 61 | 62 | 63 | class StandardScaler(): 64 | def __init__(self, mean, std): 65 | self.mean = mean 66 | self.std = std 67 | 68 | def transform(self, data): 69 | return (data - self.mean) / self.std 70 | 71 | def inverse_transform(self, data): 72 | return (data * self.std) + self.mean 73 | 74 | 75 | def visual(true, preds=None, name='./pic/test.pdf'): 76 | """ 77 | Results visualization 78 | """ 79 | plt.figure() 80 | plt.plot(true, label='GroundTruth', linewidth=2) 81 | if preds is not None: 82 | plt.plot(preds, label='Prediction', linewidth=2) 83 | plt.legend() 84 | plt.savefig(name, bbox_inches='tight') 85 | --------------------------------------------------------------------------------
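The utility modules above (utils/timefeatures.py, utils/metrics.py, utils/tools.py) are self-contained. The following is a minimal, hypothetical usage sketch, not a file in this repository, showing how time_features, StandardScaler, and metric are typically combined. It assumes it is run from the repository root so that the utils package is importable; the dates, array shapes, and random data are illustrative assumptions only.

import numpy as np
import pandas as pd

from utils.timefeatures import time_features
from utils.metrics import metric
from utils.tools import StandardScaler

# Hourly timestamps -> array of shape (num_time_features, num_steps), values in [-0.5, 0.5].
dates = pd.date_range("2022-01-01", periods=96, freq="h")
time_marks = time_features(dates, freq="h")  # HourOfDay, DayOfWeek, DayOfMonth, DayOfYear
print(time_marks.shape)  # (4, 96)

# Normalize a toy multivariate series with the lightweight scaler from utils/tools.py.
series = np.random.randn(96, 7)
scaler = StandardScaler(mean=series.mean(0), std=series.std(0))
normed = scaler.transform(series)
restored = scaler.inverse_transform(normed)

# Score a placeholder forecast against the ground truth with utils/metrics.py.
pred, true = restored[-24:], series[-24:]
mae, mse, rmse, mape, mspe = metric(pred, true)
print(f"MAE={mae:.4f} MSE={mse:.4f}")

In the training pipeline these helpers are typically invoked from the data loading and experiment code driven by run.py rather than called by hand as above.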