├── .gitignore ├── LICENSE ├── README.md ├── data_provider ├── __init__.py ├── data_factory.py └── data_loader.py ├── experiments ├── exp_basic.py ├── exp_long_term_forecasting.py └── exp_long_term_forecasting_partial.py ├── figures ├── ablations.png ├── algorithm.png ├── analysis.png ├── architecture.png ├── boosting.png ├── boosting_trm.png ├── datasets.png ├── datasets_mtsf.png ├── efficiency.png ├── efficient.png ├── formulations.png ├── generability.png ├── groups.png ├── increase_lookback.png ├── layernorm.png ├── main_results.png ├── main_results_alipay.png ├── motivation.png ├── pi.png ├── pt.png └── radar.png ├── layers ├── Embed.py ├── SelfAttention_Family.py ├── Transformer_EncDec.py └── __init__.py ├── model ├── Flashformer.py ├── Flowformer.py ├── Informer.py ├── Reformer.py ├── Transformer.py ├── __init__.py ├── iFlashformer.py ├── iFlowformer.py ├── iInformer.py ├── iReformer.py └── iTransformer.py ├── requirements.txt ├── run.py ├── scripts ├── boost_performance │ ├── ECL │ │ ├── iFlowformer.sh │ │ ├── iInformer.sh │ │ ├── iReformer.sh │ │ └── iTransformer.sh │ ├── README.md │ ├── Traffic │ │ ├── iFlowformer.sh │ │ ├── iInformer.sh │ │ ├── iReformer.sh │ │ └── iTransformer.sh │ └── Weather │ │ ├── iFlowformer.sh │ │ ├── iInformer.sh │ │ ├── iReformer.sh │ │ └── iTransformer.sh ├── increasing_lookback │ ├── ECL │ │ ├── iFlowformer.sh │ │ ├── iInformer.sh │ │ ├── iReformer.sh │ │ └── iTransformer.sh │ ├── README.md │ └── Traffic │ │ ├── iFlowformer.sh │ │ ├── iInformer.sh │ │ ├── iReformer.sh │ │ └── iTransformer.sh ├── model_efficiency │ ├── ECL │ │ └── iFlashTransformer.sh │ ├── README.md │ ├── Traffic │ │ └── iFlashTransformer.sh │ └── Weather │ │ └── iFlashTransformer.sh ├── multivariate_forecasting │ ├── ECL │ │ └── iTransformer.sh │ ├── ETT │ │ ├── iTransformer_ETTh1.sh │ │ ├── iTransformer_ETTh2.sh │ │ ├── iTransformer_ETTm1.sh │ │ └── iTransformer_ETTm2.sh │ ├── Exchange │ │ └── iTransformer.sh │ ├── PEMS │ │ ├── iTransformer_03.sh │ │ ├── iTransformer_04.sh │ │ ├── iTransformer_07.sh │ │ └── iTransformer_08.sh │ ├── README.md │ ├── SolarEnergy │ │ └── iTransformer.sh │ ├── Traffic │ │ └── iTransformer.sh │ └── Weather │ │ └── iTransformer.sh └── variate_generalization │ ├── ECL │ ├── iFlowformer.sh │ ├── iInformer.sh │ ├── iReformer.sh │ └── iTransformer.sh │ ├── README.md │ ├── SolarEnergy │ ├── iFlowformer.sh │ ├── iInformer.sh │ ├── iReformer.sh │ └── iTransformer.sh │ └── Traffic │ ├── iFlowformer.sh │ ├── iInformer.sh │ ├── iReformer.sh │ └── iTransformer.sh └── utils ├── __init__.py ├── masking.py ├── metrics.py ├── timefeatures.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | */.DS_Store 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # iTransformer 2 | 3 | The repo is the official implementation for the paper: [iTransformer: Inverted Transformers Are Effective for Time Series Forecasting](https://arxiv.org/abs/2310.06625). 
[[Slides]](https://cloud.tsinghua.edu.cn/f/175ff98f7e2d44fbbe8e/), [[Poster]](https://cloud.tsinghua.edu.cn/f/36a2ae6c132d44c0bd8c/). 4 | 5 | 6 | # Updates 7 | 8 | :triangular_flag_on_post: **News** (2024.10) [TimeXer](https://arxiv.org/abs/2402.19072), a Transformer for predicting with exogenous variables, is released. Code is available [here](https://github.com/thuml/TimeXer). 9 | 10 | :triangular_flag_on_post: **News** (2024.05) Many thanks to [lucidrains](https://github.com/lucidrains/iTransformer) for the great efforts. A pip package of iTransformer variants can be installed simply via ```pip install iTransformer```. 11 | 12 | :triangular_flag_on_post: **News** (2024.04) iTransformer has been included in [NeuralForecast](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/itransformer.py). Special thanks to the contributor @[Marco](https://github.com/marcopeix)! 13 | 14 | :triangular_flag_on_post: **News** (2024.03) An introduction to our work in [Chinese](https://mp.weixin.qq.com/s/-pvBnA1_NSloNxa6TYXTSg) is available. 15 | 16 | :triangular_flag_on_post: **News** (2024.02) iTransformer has been accepted as **ICLR 2024 Spotlight**. 17 | 18 | :triangular_flag_on_post: **News** (2023.12) iTransformer is available in [GluonTS](https://github.com/awslabs/gluonts/pull/3017) with a probabilistic head and support for static covariates. Notebook is available [here](https://github.com/awslabs/gluonts/blob/dev/examples/iTransformer.ipynb). 19 | 20 | :triangular_flag_on_post: **News** (2023.12) We have received many valuable suggestions. A [revised version](https://arxiv.org/pdf/2310.06625v2.pdf) (**24 Pages**) is now available. 21 | 22 | :triangular_flag_on_post: **News** (2023.10) iTransformer has been included in [[Time-Series-Library]](https://github.com/thuml/Time-Series-Library) and achieves state-of-the-art performance in Lookback-$96$ forecasting. 23 | 24 | :triangular_flag_on_post: **News** (2023.10) All the scripts for the experiments in our [paper](https://arxiv.org/pdf/2310.06625.pdf) are available. 25 | 26 | 27 | ## Introduction 28 | 29 | 🌟 Considering the characteristics of multivariate time series, iTransformer breaks the conventional structure without modifying any Transformer modules. **Inverted Transformer is all you need in MTSF**. 30 | 31 |

32 | 33 |

34 | 35 | 🏆 iTransformer achieves comprehensive state-of-the-art performance on challenging multivariate forecasting tasks and addresses several pain points of Transformers on extensive time series data. 36 | 37 |

38 | 39 |

40 | 41 | ## Overall Architecture 42 | 43 | 44 | iTransformer regards **independent time series as variate tokens**, **captures multivariate correlations with attention**, and **utilizes layernorm and feed-forward networks to learn series representations**. 45 | 46 |

47 | 48 |

49 | 50 | The pseudo-code of iTransformer is as simple as the following: 51 | 52 |

53 | 54 |
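For readers who prefer runnable code to pseudo-code, here is a minimal, self-contained sketch of the same inverted flow built from stock PyTorch layers. It is an illustrative simplification (the class name and hyperparameters are placeholders), not the repository implementation, which lives in `./model/iTransformer.py` and uses the `Encoder`/`AttentionLayer` modules under `./layers/`.

```
import torch
import torch.nn as nn


class InvertedForecasterSketch(nn.Module):
    """Toy inverted Transformer: variates become tokens, the lookback window becomes the feature axis."""

    def __init__(self, seq_len=96, pred_len=96, d_model=128, n_heads=8, e_layers=2):
        super().__init__()
        # Embed each variate's whole lookback series into one token (cf. DataEmbedding_inverted).
        self.embedding = nn.Linear(seq_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads,
                                           dim_feedforward=4 * d_model, batch_first=True)
        # Self-attention now runs over variate tokens, i.e. it models multivariate correlations.
        self.encoder = nn.TransformerEncoder(layer, num_layers=e_layers)
        # Project each variate token back onto the forecast horizon.
        self.projector = nn.Linear(d_model, pred_len)

    def forward(self, x):  # x: [Batch, Time, Variate]
        # Series-wise normalization (Non-stationary Transformer style), undone after projection.
        means = x.mean(1, keepdim=True)
        stdev = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + 1e-5)
        x = (x - means) / stdev
        tokens = self.embedding(x.permute(0, 2, 1))   # [Batch, Variate, d_model]
        tokens = self.encoder(tokens)                 # attention across variates
        y = self.projector(tokens).permute(0, 2, 1)   # [Batch, pred_len, Variate]
        return y * stdev + means


model = InvertedForecasterSketch()
print(model(torch.randn(4, 96, 21)).shape)  # torch.Size([4, 96, 21])
```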

55 | 56 | ## Usage 57 | 58 | 1. Install PyTorch and the necessary dependencies. 59 | 60 | ``` 61 | pip install -r requirements.txt 62 | ``` 63 | 64 | 2. The datasets can be obtained from [Google Drive](https://drive.google.com/file/d/1l51QsKvQPcqILT3DwfjCgx8Dsg2rpjot/view?usp=drive_link) or [Baidu Cloud](https://pan.baidu.com/s/11AWXg1Z6UwjHzmto4hesAA?pwd=9qjr). 65 | 66 | 3. Train and evaluate the model. We provide the experiment scripts for all the above tasks under the folder ./scripts/. You can reproduce the results as in the following examples: 67 | 68 | ``` 69 | # Multivariate forecasting with iTransformer 70 | bash ./scripts/multivariate_forecasting/Traffic/iTransformer.sh 71 | 72 | # Compare the performance of Transformer and iTransformer 73 | bash ./scripts/boost_performance/Weather/iTransformer.sh 74 | 75 | # Train the model with partial variates, and generalize to the unseen variates 76 | bash ./scripts/variate_generalization/ECL/iTransformer.sh 77 | 78 | # Test the performance on the enlarged lookback window 79 | bash ./scripts/increasing_lookback/Traffic/iTransformer.sh 80 | 81 | # Utilize FlashAttention for acceleration 82 | bash ./scripts/model_efficiency/ECL/iFlashTransformer.sh 83 | ``` 84 | 85 | ## Main Result of Multivariate Forecasting 86 | 87 | We evaluate iTransformer on challenging multivariate forecasting benchmarks (**generally hundreds of variates**). It achieves **comprehensively strong performance** (MSE/MAE $\downarrow$). 88 | 89 | 90 | 91 | ### Online Transaction Load Prediction of Alipay Trading Platform (Avg Results) 92 | 93 |

94 | 95 |

96 | 97 | ## General Performance Boosting on Transformers 98 | 99 | By introducing the proposed framework, Transformer and its variants achieve **significant performance improvement**, demonstrating the **generality of the iTransformer approach** and **benefiting from efficient attention mechanisms**. 100 | 101 |

102 | 103 |
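Concretely, the i-variants in `./model/` share one encoder-only inverted skeleton and differ mainly in which attention module is plugged into it. The sketch below (a hypothetical helper, not part of the repo; mask flags and factor options are omitted for brevity) builds such encoders from the components defined in `./layers/`.

```
import torch.nn as nn

from layers.SelfAttention_Family import AttentionLayer, FlashAttention, FlowAttention, FullAttention
from layers.Transformer_EncDec import Encoder, EncoderLayer


def build_inverted_encoder(attention_cls, d_model=512, n_heads=8, d_ff=2048, e_layers=2, dropout=0.1):
    # The inverted skeleton stays fixed; only the attention module over variate tokens changes.
    return Encoder(
        [
            EncoderLayer(
                AttentionLayer(attention_cls(attention_dropout=dropout), d_model, n_heads),
                d_model,
                d_ff,
                dropout=dropout,
                activation='gelu',
            )
            for _ in range(e_layers)
        ],
        norm_layer=nn.LayerNorm(d_model),
    )


full_encoder = build_inverted_encoder(FullAttention)    # iTransformer-style
flow_encoder = build_inverted_encoder(FlowAttention)    # iFlowformer-style
flash_encoder = build_inverted_encoder(FlashAttention)  # iFlashformer-style
```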

104 | 105 | ## Zero-Shot Generalization on Variates 106 | 107 | **Technically, iTransformer is able to forecast with arbitrary numbers of variables**. We train iTransformers on partial variates and forecast unseen variates with good generalizability. 108 | 109 |

110 | 111 |
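This property follows from the architecture rather than from any special training trick: the embedding and projection are applied per variate with shared weights, and attention accepts any number of tokens, so the variate count is free to change between training and inference. A toy check of the shape behaviour (illustrative only, using plain PyTorch modules rather than the repository's classes):

```
import torch
import torch.nn as nn

seq_len, pred_len, d_model = 96, 96, 64
embed = nn.Linear(seq_len, d_model)      # shared across variates: series -> variate token
attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
project = nn.Linear(d_model, pred_len)   # shared across variates: token -> forecast

for n_variates in (20, 137, 862):        # e.g. partial vs. full variate sets
    x = torch.randn(2, seq_len, n_variates)
    tokens = embed(x.permute(0, 2, 1))           # [2, n_variates, d_model]
    tokens, _ = attn(tokens, tokens, tokens)     # the variate count is just the token count
    y = project(tokens).permute(0, 2, 1)         # [2, pred_len, n_variates]
    print(n_variates, tuple(y.shape))
```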

112 | 113 | ## Model Analysis 114 | 115 | Benefiting from inverted Transformer modules: 116 | 117 | - (Left) Inverted Transformers learn **better time series representations** (higher [CKA](https://github.com/jayroxis/CKA-similarity) similarity) that are favored by forecasting. 118 | - (Right) The inverted self-attention module learns **interpretable multivariate correlations**. 119 | 120 |

121 | 122 |
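For reference, the CKA similarity used in this analysis can be computed between two sets of layer representations as follows (a standalone linear-CKA helper in the spirit of the linked repository; it is not part of this codebase):

```
import torch


def linear_cka(x, y):
    # x, y: [n_samples, n_features] representations from two layers (or two models).
    x = x - x.mean(0, keepdim=True)
    y = y - y.mean(0, keepdim=True)
    cross = torch.norm(y.t() @ x, p='fro') ** 2
    return cross / (torch.norm(x.t() @ x, p='fro') * torch.norm(y.t() @ y, p='fro'))


# Values closer to 1 indicate more similar representations.
first_layer, last_layer = torch.randn(256, 64), torch.randn(256, 64)
print(float(linear_cka(first_layer, last_layer)))
```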

123 | 124 | ## Citation 125 | 126 | If you find this repo helpful, please cite our paper. 127 | 128 | ``` 129 | @article{liu2023itransformer, 130 | title={iTransformer: Inverted Transformers Are Effective for Time Series Forecasting}, 131 | author={Liu, Yong and Hu, Tengge and Zhang, Haoran and Wu, Haixu and Wang, Shiyu and Ma, Lintao and Long, Mingsheng}, 132 | journal={arXiv preprint arXiv:2310.06625}, 133 | year={2023} 134 | } 135 | ``` 136 | 137 | ## Acknowledgement 138 | 139 | We appreciate the following GitHub repos a lot for their valuable code and efforts. 140 | - Reformer (https://github.com/lucidrains/reformer-pytorch) 141 | - Informer (https://github.com/zhouhaoyi/Informer2020) 142 | - FlashAttention (https://github.com/shreyansh26/FlashAttention-PyTorch) 143 | - Autoformer (https://github.com/thuml/Autoformer) 144 | - Stationary (https://github.com/thuml/Nonstationary_Transformers) 145 | - Time-Series-Library (https://github.com/thuml/Time-Series-Library) 146 | - lucidrains (https://github.com/lucidrains/iTransformer) 147 | 148 | This work was supported by Ant Group through the CCF-Ant Research Fund and awarded as [Outstanding Projects of CCF Fund](https://mp.weixin.qq.com/s/PDLNbibZD3kqhcUoNejLfA). 149 | 150 | ## Contact 151 | 152 | If you have any questions or want to use the code, feel free to contact: 153 | * Yong Liu (liuyong21@mails.tsinghua.edu.cn) 154 | * Haoran Zhang (z-hr20@mails.tsinghua.edu.cn) 155 | * Tengge Hu (htg21@mails.tsinghua.edu.cn) 156 | -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/data_provider/__init__.py -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Solar, Dataset_PEMS, \ 2 | Dataset_Pred 3 | from torch.utils.data import DataLoader 4 | 5 | data_dict = { 6 | 'ETTh1': Dataset_ETT_hour, 7 | 'ETTh2': Dataset_ETT_hour, 8 | 'ETTm1': Dataset_ETT_minute, 9 | 'ETTm2': Dataset_ETT_minute, 10 | 'Solar': Dataset_Solar, 11 | 'PEMS': Dataset_PEMS, 12 | 'custom': Dataset_Custom, 13 | } 14 | 15 | 16 | def data_provider(args, flag): 17 | Data = data_dict[args.data] 18 | timeenc = 0 if args.embed != 'timeF' else 1 19 | 20 | if flag == 'test': 21 | shuffle_flag = False 22 | drop_last = True 23 | batch_size = 1 # bsz=1 for evaluation 24 | freq = args.freq 25 | elif flag == 'pred': 26 | shuffle_flag = False 27 | drop_last = False 28 | batch_size = 1 29 | freq = args.freq 30 | Data = Dataset_Pred 31 | else: 32 | shuffle_flag = True 33 | drop_last = True 34 | batch_size = args.batch_size # bsz for train and valid 35 | freq = args.freq 36 | 37 | data_set = Data( 38 | root_path=args.root_path, 39 | data_path=args.data_path, 40 | flag=flag, 41 | size=[args.seq_len, args.label_len, args.pred_len], 42 | features=args.features, 43 | target=args.target, 44 | timeenc=timeenc, 45 | freq=freq, 46 | ) 47 | print(flag, len(data_set)) 48 | data_loader = DataLoader( 49 | data_set, 50 | batch_size=batch_size, 51 | shuffle=shuffle_flag, 52 | num_workers=args.num_workers, 53 | drop_last=drop_last) 54 | return data_set, data_loader 55 | 
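As a usage note for the factory above: `data_provider(args, flag)` only needs an `args` object exposing the fields it reads. Below is a hedged sketch of a direct call; the field values (paths, dataset name, lengths) are placeholders, and `run.py` normally assembles this object from its command-line flags.

```
from argparse import Namespace

from data_provider.data_factory import data_provider

# Placeholder configuration; adjust paths and lengths to the dataset actually downloaded.
args = Namespace(
    data='ETTh1', root_path='./dataset/ETT-small/', data_path='ETTh1.csv',
    features='M', target='OT', embed='timeF', freq='h',
    seq_len=96, label_len=48, pred_len=96,
    batch_size=32, num_workers=4,
)

train_set, train_loader = data_provider(args, flag='train')
# Each batch is assumed to follow the usual (x, y, x_mark, y_mark) convention of data_loader.py.
for batch_x, batch_y, batch_x_mark, batch_y_mark in train_loader:
    print(batch_x.shape, batch_y.shape)  # [batch, seq_len, variates], [batch, label_len + pred_len, variates]
    break
```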
-------------------------------------------------------------------------------- /experiments/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from model import Transformer, Informer, Reformer, Flowformer, Flashformer, \ 4 | iTransformer, iInformer, iReformer, iFlowformer, iFlashformer 5 | 6 | 7 | class Exp_Basic(object): 8 | def __init__(self, args): 9 | self.args = args 10 | self.model_dict = { 11 | 'Transformer': Transformer, 12 | 'Informer': Informer, 13 | 'Reformer': Reformer, 14 | 'Flowformer': Flowformer, 15 | 'Flashformer': Flashformer, 16 | 'iTransformer': iTransformer, 17 | 'iInformer': iInformer, 18 | 'iReformer': iReformer, 19 | 'iFlowformer': iFlowformer, 20 | 'iFlashformer': iFlashformer, 21 | } 22 | self.device = self._acquire_device() 23 | self.model = self._build_model().to(self.device) 24 | 25 | def _build_model(self): 26 | raise NotImplementedError 27 | return None 28 | 29 | def _acquire_device(self): 30 | if self.args.use_gpu: 31 | os.environ["CUDA_VISIBLE_DEVICES"] = str( 32 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 33 | device = torch.device('cuda:{}'.format(self.args.gpu)) 34 | print('Use GPU: cuda:{}'.format(self.args.gpu)) 35 | else: 36 | device = torch.device('cpu') 37 | print('Use CPU') 38 | return device 39 | 40 | def _get_data(self): 41 | pass 42 | 43 | def vali(self): 44 | pass 45 | 46 | def train(self): 47 | pass 48 | 49 | def test(self): 50 | pass 51 | -------------------------------------------------------------------------------- /figures/ablations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/ablations.png -------------------------------------------------------------------------------- /figures/algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/algorithm.png -------------------------------------------------------------------------------- /figures/analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/analysis.png -------------------------------------------------------------------------------- /figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/architecture.png -------------------------------------------------------------------------------- /figures/boosting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/boosting.png -------------------------------------------------------------------------------- /figures/boosting_trm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/boosting_trm.png -------------------------------------------------------------------------------- /figures/datasets.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/datasets.png -------------------------------------------------------------------------------- /figures/datasets_mtsf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/datasets_mtsf.png -------------------------------------------------------------------------------- /figures/efficiency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/efficiency.png -------------------------------------------------------------------------------- /figures/efficient.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/efficient.png -------------------------------------------------------------------------------- /figures/formulations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/formulations.png -------------------------------------------------------------------------------- /figures/generability.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/generability.png -------------------------------------------------------------------------------- /figures/groups.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/groups.png -------------------------------------------------------------------------------- /figures/increase_lookback.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/increase_lookback.png -------------------------------------------------------------------------------- /figures/layernorm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/layernorm.png -------------------------------------------------------------------------------- /figures/main_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/main_results.png -------------------------------------------------------------------------------- /figures/main_results_alipay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/main_results_alipay.png -------------------------------------------------------------------------------- /figures/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/motivation.png 
-------------------------------------------------------------------------------- /figures/pi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/pi.png -------------------------------------------------------------------------------- /figures/pt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/pt.png -------------------------------------------------------------------------------- /figures/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/radar.png -------------------------------------------------------------------------------- /layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | def __init__(self, d_model, max_len=5000): 8 | super(PositionalEmbedding, self).__init__() 9 | # Compute the positional encodings once in log space. 10 | pe = torch.zeros(max_len, d_model).float() 11 | pe.require_grad = False 12 | 13 | position = torch.arange(0, max_len).float().unsqueeze(1) 14 | div_term = (torch.arange(0, d_model, 2).float() 15 | * -(math.log(10000.0) / d_model)).exp() 16 | 17 | pe[:, 0::2] = torch.sin(position * div_term) 18 | pe[:, 1::2] = torch.cos(position * div_term) 19 | 20 | pe = pe.unsqueeze(0) 21 | self.register_buffer('pe', pe) 22 | 23 | def forward(self, x): 24 | return self.pe[:, :x.size(1)] 25 | 26 | 27 | class TokenEmbedding(nn.Module): 28 | def __init__(self, c_in, d_model): 29 | super(TokenEmbedding, self).__init__() 30 | padding = 1 if torch.__version__ >= '1.5.0' else 2 31 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 32 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv1d): 35 | nn.init.kaiming_normal_( 36 | m.weight, mode='fan_in', nonlinearity='leaky_relu') 37 | 38 | def forward(self, x): 39 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 40 | return x 41 | 42 | 43 | class FixedEmbedding(nn.Module): 44 | def __init__(self, c_in, d_model): 45 | super(FixedEmbedding, self).__init__() 46 | 47 | w = torch.zeros(c_in, d_model).float() 48 | w.require_grad = False 49 | 50 | position = torch.arange(0, c_in).float().unsqueeze(1) 51 | div_term = (torch.arange(0, d_model, 2).float() 52 | * -(math.log(10000.0) / d_model)).exp() 53 | 54 | w[:, 0::2] = torch.sin(position * div_term) 55 | w[:, 1::2] = torch.cos(position * div_term) 56 | 57 | self.emb = nn.Embedding(c_in, d_model) 58 | self.emb.weight = nn.Parameter(w, requires_grad=False) 59 | 60 | def forward(self, x): 61 | return self.emb(x).detach() 62 | 63 | 64 | class TemporalEmbedding(nn.Module): 65 | def __init__(self, d_model, embed_type='fixed', freq='h'): 66 | super(TemporalEmbedding, self).__init__() 67 | 68 | minute_size = 4 69 | hour_size = 24 70 | weekday_size = 7 71 | day_size = 32 72 | month_size = 13 73 | 74 | Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding 75 | if freq == 't': 76 | self.minute_embed = Embed(minute_size, d_model) 77 | self.hour_embed = Embed(hour_size, d_model) 78 | self.weekday_embed = Embed(weekday_size, 
d_model) 79 | self.day_embed = Embed(day_size, d_model) 80 | self.month_embed = Embed(month_size, d_model) 81 | 82 | def forward(self, x): 83 | x = x.long() 84 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr( 85 | self, 'minute_embed') else 0. 86 | hour_x = self.hour_embed(x[:, :, 3]) 87 | weekday_x = self.weekday_embed(x[:, :, 2]) 88 | day_x = self.day_embed(x[:, :, 1]) 89 | month_x = self.month_embed(x[:, :, 0]) 90 | 91 | return hour_x + weekday_x + day_x + month_x + minute_x 92 | 93 | 94 | class TimeFeatureEmbedding(nn.Module): 95 | def __init__(self, d_model, embed_type='timeF', freq='h'): 96 | super(TimeFeatureEmbedding, self).__init__() 97 | 98 | freq_map = {'h': 4, 't': 5, 's': 6, 99 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 100 | d_inp = freq_map[freq] 101 | self.embed = nn.Linear(d_inp, d_model, bias=False) 102 | 103 | def forward(self, x): 104 | return self.embed(x) 105 | 106 | 107 | class DataEmbedding(nn.Module): 108 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 109 | super(DataEmbedding, self).__init__() 110 | 111 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 112 | self.position_embedding = PositionalEmbedding(d_model=d_model) 113 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 114 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 115 | d_model=d_model, embed_type=embed_type, freq=freq) 116 | self.dropout = nn.Dropout(p=dropout) 117 | 118 | def forward(self, x, x_mark): 119 | if x_mark is None: 120 | x = self.value_embedding(x) + self.position_embedding(x) 121 | else: 122 | x = self.value_embedding( 123 | x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 124 | return self.dropout(x) 125 | 126 | 127 | class DataEmbedding_inverted(nn.Module): 128 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 129 | super(DataEmbedding_inverted, self).__init__() 130 | self.value_embedding = nn.Linear(c_in, d_model) 131 | self.dropout = nn.Dropout(p=dropout) 132 | 133 | def forward(self, x, x_mark): 134 | x = x.permute(0, 2, 1) 135 | # x: [Batch Variate Time] 136 | if x_mark is None: 137 | x = self.value_embedding(x) 138 | else: 139 | # the potential to take covariates (e.g. 
timestamps) as tokens 140 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) 141 | # x: [Batch Variate d_model] 142 | return self.dropout(x) 143 | 144 | -------------------------------------------------------------------------------- /layers/SelfAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from math import sqrt 5 | from utils.masking import TriangularCausalMask, ProbMask 6 | from reformer_pytorch import LSHSelfAttention 7 | from einops import rearrange 8 | 9 | 10 | # Code implementation from https://github.com/thuml/Flowformer 11 | class FlowAttention(nn.Module): 12 | def __init__(self, attention_dropout=0.1): 13 | super(FlowAttention, self).__init__() 14 | self.dropout = nn.Dropout(attention_dropout) 15 | 16 | def kernel_method(self, x): 17 | return torch.sigmoid(x) 18 | 19 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 20 | queries = queries.transpose(1, 2) 21 | keys = keys.transpose(1, 2) 22 | values = values.transpose(1, 2) 23 | # kernel 24 | queries = self.kernel_method(queries) 25 | keys = self.kernel_method(keys) 26 | # incoming and outgoing 27 | normalizer_row = 1.0 / (torch.einsum("nhld,nhd->nhl", queries + 1e-6, keys.sum(dim=2) + 1e-6)) 28 | normalizer_col = 1.0 / (torch.einsum("nhsd,nhd->nhs", keys + 1e-6, queries.sum(dim=2) + 1e-6)) 29 | # reweighting 30 | normalizer_row_refine = ( 31 | torch.einsum("nhld,nhd->nhl", queries + 1e-6, (keys * normalizer_col[:, :, :, None]).sum(dim=2) + 1e-6)) 32 | normalizer_col_refine = ( 33 | torch.einsum("nhsd,nhd->nhs", keys + 1e-6, (queries * normalizer_row[:, :, :, None]).sum(dim=2) + 1e-6)) 34 | # competition and allocation 35 | normalizer_row_refine = torch.sigmoid( 36 | normalizer_row_refine * (float(queries.shape[2]) / float(keys.shape[2]))) 37 | normalizer_col_refine = torch.softmax(normalizer_col_refine, dim=-1) * keys.shape[2] # B h L vis 38 | # multiply 39 | kv = keys.transpose(-2, -1) @ (values * normalizer_col_refine[:, :, :, None]) 40 | x = (((queries @ kv) * normalizer_row[:, :, :, None]) * normalizer_row_refine[:, :, :, None]).transpose(1, 41 | 2).contiguous() 42 | return x, None 43 | 44 | 45 | # Code implementation from https://github.com/shreyansh26/FlashAttention-PyTorch 46 | class FlashAttention(nn.Module): 47 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 48 | super(FlashAttention, self).__init__() 49 | self.scale = scale 50 | self.mask_flag = mask_flag 51 | self.output_attention = output_attention 52 | self.dropout = nn.Dropout(attention_dropout) 53 | 54 | def flash_attention_forward(self, Q, K, V, mask=None): 55 | BLOCK_SIZE = 32 56 | NEG_INF = -1e10 # -infinity 57 | EPSILON = 1e-10 58 | # mask = torch.randint(0, 2, (128, 8)).to(device='cuda') 59 | O = torch.zeros_like(Q, requires_grad=True) 60 | l = torch.zeros(Q.shape[:-1])[..., None] 61 | m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF 62 | 63 | O = O.to(device='cuda') 64 | l = l.to(device='cuda') 65 | m = m.to(device='cuda') 66 | 67 | Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1]) 68 | KV_BLOCK_SIZE = BLOCK_SIZE 69 | 70 | Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2) 71 | K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2) 72 | V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2) 73 | if mask is not None: 74 | mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1)) 75 | 76 | Tr = len(Q_BLOCKS) 77 | Tc = len(K_BLOCKS) 78 | 79 | O_BLOCKS = 
list(torch.split(O, Q_BLOCK_SIZE, dim=2)) 80 | l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2)) 81 | m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2)) 82 | 83 | for j in range(Tc): 84 | Kj = K_BLOCKS[j] 85 | Vj = V_BLOCKS[j] 86 | if mask is not None: 87 | maskj = mask_BLOCKS[j] 88 | 89 | for i in range(Tr): 90 | Qi = Q_BLOCKS[i] 91 | Oi = O_BLOCKS[i] 92 | li = l_BLOCKS[i] 93 | mi = m_BLOCKS[i] 94 | 95 | scale = 1 / np.sqrt(Q.shape[-1]) 96 | Qi_scaled = Qi * scale 97 | 98 | S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj) 99 | if mask is not None: 100 | # Masking 101 | maskj_temp = rearrange(maskj, 'b j -> b 1 1 j') 102 | S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF) 103 | 104 | m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True) 105 | P_ij = torch.exp(S_ij - m_block_ij) 106 | if mask is not None: 107 | # Masking 108 | P_ij = torch.where(maskj_temp > 0, P_ij, 0.) 109 | 110 | l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON 111 | 112 | P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj) 113 | 114 | mi_new = torch.maximum(m_block_ij, mi) 115 | li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij 116 | 117 | O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi + ( 118 | torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj 119 | l_BLOCKS[i] = li_new 120 | m_BLOCKS[i] = mi_new 121 | 122 | O = torch.cat(O_BLOCKS, dim=2) 123 | l = torch.cat(l_BLOCKS, dim=2) 124 | m = torch.cat(m_BLOCKS, dim=2) 125 | return O, l, m 126 | 127 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 128 | res = \ 129 | self.flash_attention_forward(queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3), values.permute(0, 2, 1, 3), 130 | attn_mask)[0] 131 | return res.permute(0, 2, 1, 3).contiguous(), None 132 | 133 | 134 | class FullAttention(nn.Module): 135 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 136 | super(FullAttention, self).__init__() 137 | self.scale = scale 138 | self.mask_flag = mask_flag 139 | self.output_attention = output_attention 140 | self.dropout = nn.Dropout(attention_dropout) 141 | 142 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 143 | B, L, H, E = queries.shape 144 | _, S, _, D = values.shape 145 | scale = self.scale or 1. 
/ sqrt(E) 146 | 147 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 148 | 149 | if self.mask_flag: 150 | if attn_mask is None: 151 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 152 | 153 | scores.masked_fill_(attn_mask.mask, -np.inf) 154 | 155 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 156 | V = torch.einsum("bhls,bshd->blhd", A, values) 157 | 158 | if self.output_attention: 159 | return (V.contiguous(), A) 160 | else: 161 | return (V.contiguous(), None) 162 | 163 | 164 | # Code implementation from https://github.com/zhouhaoyi/Informer2020 165 | class ProbAttention(nn.Module): 166 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 167 | super(ProbAttention, self).__init__() 168 | self.factor = factor 169 | self.scale = scale 170 | self.mask_flag = mask_flag 171 | self.output_attention = output_attention 172 | self.dropout = nn.Dropout(attention_dropout) 173 | 174 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 175 | # Q [B, H, L, D] 176 | B, H, L_K, E = K.shape 177 | _, _, L_Q, _ = Q.shape 178 | 179 | # calculate the sampled Q_K 180 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 181 | # real U = U_part(factor*ln(L_k))*L_q 182 | index_sample = torch.randint(L_K, (L_Q, sample_k)) 183 | K_sample = K_expand[:, :, torch.arange( 184 | L_Q).unsqueeze(1), index_sample, :] 185 | Q_K_sample = torch.matmul( 186 | Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 187 | 188 | # find the Top_k query with sparisty measurement 189 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 190 | M_top = M.topk(n_top, sorted=False)[1] 191 | 192 | # use the reduced Q to calculate Q_K 193 | Q_reduce = Q[torch.arange(B)[:, None, None], 194 | torch.arange(H)[None, :, None], 195 | M_top, :] # factor*ln(L_q) 196 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 197 | 198 | return Q_K, M_top 199 | 200 | def _get_initial_context(self, V, L_Q): 201 | B, H, L_V, D = V.shape 202 | if not self.mask_flag: 203 | # V_sum = V.sum(dim=-2) 204 | V_sum = V.mean(dim=-2) 205 | contex = V_sum.unsqueeze(-2).expand(B, H, 206 | L_Q, V_sum.shape[-1]).clone() 207 | else: # use mask 208 | # requires that L_Q == L_V, i.e. 
for self-attention only 209 | assert (L_Q == L_V) 210 | contex = V.cumsum(dim=-2) 211 | return contex 212 | 213 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 214 | B, H, L_V, D = V.shape 215 | 216 | if self.mask_flag: 217 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 218 | scores.masked_fill_(attn_mask.mask, -np.inf) 219 | 220 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 221 | 222 | context_in[torch.arange(B)[:, None, None], 223 | torch.arange(H)[None, :, None], 224 | index, :] = torch.matmul(attn, V).type_as(context_in) 225 | if self.output_attention: 226 | attns = (torch.ones([B, H, L_V, L_V]) / 227 | L_V).type_as(attn).to(attn.device) 228 | attns[torch.arange(B)[:, None, None], torch.arange(H)[ 229 | None, :, None], index, :] = attn 230 | return (context_in, attns) 231 | else: 232 | return (context_in, None) 233 | 234 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 235 | B, L_Q, H, D = queries.shape 236 | _, L_K, _, _ = keys.shape 237 | 238 | queries = queries.transpose(2, 1) 239 | keys = keys.transpose(2, 1) 240 | values = values.transpose(2, 1) 241 | 242 | U_part = self.factor * \ 243 | np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 244 | u = self.factor * \ 245 | np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 246 | 247 | U_part = U_part if U_part < L_K else L_K 248 | u = u if u < L_Q else L_Q 249 | 250 | scores_top, index = self._prob_QK( 251 | queries, keys, sample_k=U_part, n_top=u) 252 | 253 | # add scale factor 254 | scale = self.scale or 1. / sqrt(D) 255 | if scale is not None: 256 | scores_top = scores_top * scale 257 | # get the context 258 | context = self._get_initial_context(values, L_Q) 259 | # update the context with selected top_k queries 260 | context, attn = self._update_context( 261 | context, values, scores_top, index, L_Q, attn_mask) 262 | 263 | return context.contiguous(), attn 264 | 265 | 266 | class AttentionLayer(nn.Module): 267 | def __init__(self, attention, d_model, n_heads, d_keys=None, 268 | d_values=None): 269 | super(AttentionLayer, self).__init__() 270 | 271 | d_keys = d_keys or (d_model // n_heads) 272 | d_values = d_values or (d_model // n_heads) 273 | 274 | self.inner_attention = attention 275 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 276 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 277 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 278 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 279 | self.n_heads = n_heads 280 | 281 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): 282 | B, L, _ = queries.shape 283 | _, S, _ = keys.shape 284 | H = self.n_heads 285 | 286 | queries = self.query_projection(queries).view(B, L, H, -1) 287 | keys = self.key_projection(keys).view(B, S, H, -1) 288 | values = self.value_projection(values).view(B, S, H, -1) 289 | 290 | out, attn = self.inner_attention( 291 | queries, 292 | keys, 293 | values, 294 | attn_mask, 295 | tau=tau, 296 | delta=delta 297 | ) 298 | out = out.view(B, L, -1) 299 | 300 | return self.out_projection(out), attn 301 | 302 | 303 | class ReformerLayer(nn.Module): 304 | def __init__(self, attention, d_model, n_heads, d_keys=None, 305 | d_values=None, causal=False, bucket_size=4, n_hashes=4): 306 | super().__init__() 307 | self.bucket_size = bucket_size 308 | self.attn = LSHSelfAttention( 309 | dim=d_model, 310 | heads=n_heads, 311 | bucket_size=bucket_size, 312 | n_hashes=n_hashes, 313 | 
causal=causal 314 | ) 315 | 316 | def fit_length(self, queries): 317 | # inside reformer: assert N % (bucket_size * 2) == 0 318 | B, N, C = queries.shape 319 | if N % (self.bucket_size * 2) == 0: 320 | return queries 321 | else: 322 | # fill the time series 323 | fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2)) 324 | return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1) 325 | 326 | def forward(self, queries, keys, values, attn_mask, tau, delta): 327 | # in Reformer: defalut queries=keys 328 | B, N, C = queries.shape 329 | queries = self.attn(self.fit_length(queries))[:, :N, :] 330 | return queries, None 331 | 332 | -------------------------------------------------------------------------------- /layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ConvLayer(nn.Module): 6 | def __init__(self, c_in): 7 | super(ConvLayer, self).__init__() 8 | self.downConv = nn.Conv1d(in_channels=c_in, 9 | out_channels=c_in, 10 | kernel_size=3, 11 | padding=2, 12 | padding_mode='circular') 13 | self.norm = nn.BatchNorm1d(c_in) 14 | self.activation = nn.ELU() 15 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 16 | 17 | def forward(self, x): 18 | x = self.downConv(x.permute(0, 2, 1)) 19 | x = self.norm(x) 20 | x = self.activation(x) 21 | x = self.maxPool(x) 22 | x = x.transpose(1, 2) 23 | return x 24 | 25 | 26 | class EncoderLayer(nn.Module): 27 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 28 | super(EncoderLayer, self).__init__() 29 | d_ff = d_ff or 4 * d_model 30 | self.attention = attention 31 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 32 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 33 | self.norm1 = nn.LayerNorm(d_model) 34 | self.norm2 = nn.LayerNorm(d_model) 35 | self.dropout = nn.Dropout(dropout) 36 | self.activation = F.relu if activation == "relu" else F.gelu 37 | 38 | def forward(self, x, attn_mask=None, tau=None, delta=None): 39 | new_x, attn = self.attention( 40 | x, x, x, 41 | attn_mask=attn_mask, 42 | tau=tau, delta=delta 43 | ) 44 | x = x + self.dropout(new_x) 45 | 46 | y = x = self.norm1(x) 47 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 48 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 49 | 50 | return self.norm2(x + y), attn 51 | 52 | 53 | class Encoder(nn.Module): 54 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 55 | super(Encoder, self).__init__() 56 | self.attn_layers = nn.ModuleList(attn_layers) 57 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 58 | self.norm = norm_layer 59 | 60 | def forward(self, x, attn_mask=None, tau=None, delta=None): 61 | # x [B, L, D] 62 | attns = [] 63 | if self.conv_layers is not None: 64 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 65 | delta = delta if i == 0 else None 66 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 67 | x = conv_layer(x) 68 | attns.append(attn) 69 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 70 | attns.append(attn) 71 | else: 72 | for attn_layer in self.attn_layers: 73 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 74 | attns.append(attn) 75 | 76 | if self.norm is not None: 77 | x = self.norm(x) 78 | 79 | return x, attns 80 | 81 | 82 | class 
DecoderLayer(nn.Module): 83 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 84 | dropout=0.1, activation="relu"): 85 | super(DecoderLayer, self).__init__() 86 | d_ff = d_ff or 4 * d_model 87 | self.self_attention = self_attention 88 | self.cross_attention = cross_attention 89 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 90 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 91 | self.norm1 = nn.LayerNorm(d_model) 92 | self.norm2 = nn.LayerNorm(d_model) 93 | self.norm3 = nn.LayerNorm(d_model) 94 | self.dropout = nn.Dropout(dropout) 95 | self.activation = F.relu if activation == "relu" else F.gelu 96 | 97 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 98 | x = x + self.dropout(self.self_attention( 99 | x, x, x, 100 | attn_mask=x_mask, 101 | tau=tau, delta=None 102 | )[0]) 103 | x = self.norm1(x) 104 | 105 | x = x + self.dropout(self.cross_attention( 106 | x, cross, cross, 107 | attn_mask=cross_mask, 108 | tau=tau, delta=delta 109 | )[0]) 110 | 111 | y = x = self.norm2(x) 112 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 113 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 114 | 115 | return self.norm3(x + y) 116 | 117 | 118 | class Decoder(nn.Module): 119 | def __init__(self, layers, norm_layer=None, projection=None): 120 | super(Decoder, self).__init__() 121 | self.layers = nn.ModuleList(layers) 122 | self.norm = norm_layer 123 | self.projection = projection 124 | 125 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 126 | for layer in self.layers: 127 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 128 | 129 | if self.norm is not None: 130 | x = self.norm(x) 131 | 132 | if self.projection is not None: 133 | x = self.projection(x) 134 | return x 135 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/layers/__init__.py -------------------------------------------------------------------------------- /model/Flashformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 5 | from layers.SelfAttention_Family import FlashAttention, AttentionLayer, FullAttention 6 | from layers.Embed import DataEmbedding 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | super(Model, self).__init__() 19 | self.pred_len = configs.pred_len 20 | self.output_attention = configs.output_attention 21 | # Embedding 22 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 23 | configs.dropout) 24 | # Encoder 25 | self.encoder = Encoder( 26 | [ 27 | EncoderLayer( 28 | AttentionLayer( 29 | FlashAttention(False, configs.factor, attention_dropout=configs.dropout, 30 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 31 | configs.d_model, 32 | configs.d_ff, 33 | dropout=configs.dropout, 
34 | activation=configs.activation 35 | ) for l in range(configs.e_layers) 36 | ], 37 | norm_layer=torch.nn.LayerNorm(configs.d_model) 38 | ) 39 | # Decoder 40 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 41 | configs.dropout) 42 | self.decoder = Decoder( 43 | [ 44 | DecoderLayer( 45 | AttentionLayer( 46 | FullAttention(True, configs.factor, attention_dropout=configs.dropout, 47 | output_attention=False), 48 | configs.d_model, configs.n_heads), 49 | AttentionLayer( 50 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 51 | output_attention=False), 52 | configs.d_model, configs.n_heads), 53 | configs.d_model, 54 | configs.d_ff, 55 | dropout=configs.dropout, 56 | activation=configs.activation, 57 | ) 58 | for l in range(configs.d_layers) 59 | ], 60 | norm_layer=torch.nn.LayerNorm(configs.d_model), 61 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 62 | ) 63 | 64 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 65 | # Embedding 66 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 67 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 68 | 69 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 70 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None) 71 | return dec_out 72 | 73 | 74 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 75 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 76 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 77 | -------------------------------------------------------------------------------- /model/Flowformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer, FlowAttention 6 | from layers.Embed import DataEmbedding 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | super(Model, self).__init__() 19 | self.pred_len = configs.pred_len 20 | self.output_attention = configs.output_attention 21 | 22 | if configs.channel_independence: 23 | self.enc_in = 1 24 | self.dec_in = 1 25 | self.c_out = 1 26 | else: 27 | self.enc_in = configs.enc_in 28 | self.dec_in = configs.dec_in 29 | self.c_out = configs.c_out 30 | 31 | # Embedding 32 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq, 33 | configs.dropout) 34 | # Encoder 35 | self.encoder = Encoder( 36 | [ 37 | EncoderLayer( 38 | AttentionLayer( 39 | FlowAttention(attention_dropout=configs.dropout), configs.d_model, configs.n_heads), 40 | configs.d_model, 41 | configs.d_ff, 42 | dropout=configs.dropout, 43 | activation=configs.activation 44 | ) for l in range(configs.e_layers) 45 | ], 46 | norm_layer=torch.nn.LayerNorm(configs.d_model) 47 | ) 48 | # Decoder 49 | self.dec_embedding = DataEmbedding(self.dec_in, configs.d_model, configs.embed, configs.freq, 50 | configs.dropout) 51 | self.decoder = Decoder( 52 | [ 53 | DecoderLayer( 54 | AttentionLayer( 55 | FullAttention(True, configs.factor, attention_dropout=configs.dropout, 56 | output_attention=False), 57 | configs.d_model, configs.n_heads), 58 | AttentionLayer( 59 
| FullAttention(False, configs.factor, attention_dropout=configs.dropout, 60 | output_attention=False), 61 | configs.d_model, configs.n_heads), 62 | configs.d_model, 63 | configs.d_ff, 64 | dropout=configs.dropout, 65 | activation=configs.activation, 66 | ) 67 | for l in range(configs.d_layers) 68 | ], 69 | norm_layer=torch.nn.LayerNorm(configs.d_model), 70 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 71 | ) 72 | 73 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 74 | # Embedding 75 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 76 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 77 | 78 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 79 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None) 80 | return dec_out 81 | 82 | 83 | 84 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 85 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 86 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 87 | -------------------------------------------------------------------------------- /model/Informer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 5 | from layers.SelfAttention_Family import ProbAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding 7 | 8 | 9 | class Model(nn.Module): 10 | """ 11 | Informer with Propspare attention in O(LlogL) complexity 12 | Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132 13 | """ 14 | 15 | def __init__(self, configs): 16 | super(Model, self).__init__() 17 | 18 | self.pred_len = configs.pred_len 19 | self.label_len = configs.label_len 20 | 21 | if configs.channel_independence: 22 | self.enc_in = 1 23 | self.dec_in = 1 24 | self.c_out = 1 25 | else: 26 | self.enc_in = configs.enc_in 27 | self.dec_in = configs.dec_in 28 | self.c_out = configs.c_out 29 | 30 | # Embedding 31 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq, 32 | configs.dropout) 33 | self.dec_embedding = DataEmbedding(self.dec_in, configs.d_model, configs.embed, configs.freq, 34 | configs.dropout) 35 | 36 | # Encoder 37 | self.encoder = Encoder( 38 | [ 39 | EncoderLayer( 40 | AttentionLayer( 41 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, 42 | output_attention=configs.output_attention), 43 | configs.d_model, configs.n_heads), 44 | configs.d_model, 45 | configs.d_ff, 46 | dropout=configs.dropout, 47 | activation=configs.activation 48 | ) for l in range(configs.e_layers) 49 | ], 50 | [ 51 | ConvLayer( 52 | configs.d_model 53 | ) for l in range(configs.e_layers - 1) 54 | ] if configs.distil else None, 55 | norm_layer=torch.nn.LayerNorm(configs.d_model) 56 | ) 57 | # Decoder 58 | self.decoder = Decoder( 59 | [ 60 | DecoderLayer( 61 | AttentionLayer( 62 | ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 63 | configs.d_model, configs.n_heads), 64 | AttentionLayer( 65 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 66 | configs.d_model, configs.n_heads), 67 | configs.d_model, 68 | configs.d_ff, 69 | dropout=configs.dropout, 70 | activation=configs.activation, 71 | ) 72 | for l in range(configs.d_layers) 73 | ], 74 | norm_layer=torch.nn.LayerNorm(configs.d_model), 75 | 
projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 76 | ) 77 | 78 | 79 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 80 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 81 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 82 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 83 | 84 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None) 85 | 86 | return dec_out # [B, L, D] 87 | 88 | 89 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 90 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 91 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 92 | -------------------------------------------------------------------------------- /model/Reformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import ReformerLayer 6 | from layers.Embed import DataEmbedding 7 | 8 | 9 | class Model(nn.Module): 10 | """ 11 | Reformer with O(LlogL) complexity 12 | Paper link: https://openreview.net/forum?id=rkgNKkHtvB 13 | """ 14 | 15 | def __init__(self, configs, bucket_size=4, n_hashes=4): 16 | """ 17 | bucket_size: int, 18 | n_hashes: int, 19 | """ 20 | super(Model, self).__init__() 21 | self.pred_len = configs.pred_len 22 | self.seq_len = configs.seq_len 23 | 24 | if configs.channel_independence: 25 | self.enc_in = 1 26 | self.dec_in = 1 27 | self.c_out = 1 28 | else: 29 | self.enc_in = configs.enc_in 30 | self.dec_in = configs.dec_in 31 | self.c_out = configs.c_out 32 | 33 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq, 34 | configs.dropout) 35 | # Encoder 36 | self.encoder = Encoder( 37 | [ 38 | EncoderLayer( 39 | ReformerLayer(None, configs.d_model, configs.n_heads, 40 | bucket_size=bucket_size, n_hashes=n_hashes), 41 | configs.d_model, 42 | configs.d_ff, 43 | dropout=configs.dropout, 44 | activation=configs.activation 45 | ) for l in range(configs.e_layers) 46 | ], 47 | norm_layer=torch.nn.LayerNorm(configs.d_model) 48 | ) 49 | 50 | self.projection = nn.Linear( 51 | configs.d_model, configs.c_out, bias=True) 52 | 53 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 54 | # add placeholder 55 | x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1) 56 | if x_mark_enc is not None: 57 | x_mark_enc = torch.cat( 58 | [x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1) 59 | 60 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C] 61 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 62 | dec_out = self.projection(enc_out) 63 | 64 | return dec_out # [B, L, D] 65 | 66 | 67 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 68 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 69 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 70 | -------------------------------------------------------------------------------- /model/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla 
Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | super(Model, self).__init__() 19 | self.pred_len = configs.pred_len 20 | self.output_attention = configs.output_attention 21 | 22 | if configs.channel_independence: 23 | self.enc_in = 1 24 | self.dec_in = 1 25 | self.c_out = 1 26 | else: 27 | self.enc_in = configs.enc_in 28 | self.dec_in = configs.dec_in 29 | self.c_out = configs.c_out 30 | 31 | # Embedding 32 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq, 33 | configs.dropout) 34 | # Encoder 35 | self.encoder = Encoder( 36 | [ 37 | EncoderLayer( 38 | AttentionLayer( 39 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 40 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 41 | configs.d_model, 42 | configs.d_ff, 43 | dropout=configs.dropout, 44 | activation=configs.activation 45 | ) for l in range(configs.e_layers) 46 | ], 47 | norm_layer=torch.nn.LayerNorm(configs.d_model) 48 | ) 49 | # Decoder 50 | self.dec_embedding = DataEmbedding(self.dec_in, configs.d_model, configs.embed, configs.freq, 51 | configs.dropout) 52 | self.decoder = Decoder( 53 | [ 54 | DecoderLayer( 55 | AttentionLayer( 56 | FullAttention(True, configs.factor, attention_dropout=configs.dropout, 57 | output_attention=False), 58 | configs.d_model, configs.n_heads), 59 | AttentionLayer( 60 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 61 | output_attention=False), 62 | configs.d_model, configs.n_heads), 63 | configs.d_model, 64 | configs.d_ff, 65 | dropout=configs.dropout, 66 | activation=configs.activation, 67 | ) 68 | for l in range(configs.d_layers) 69 | ], 70 | norm_layer=torch.nn.LayerNorm(configs.d_model), 71 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 72 | ) 73 | 74 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 75 | # Embedding 76 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 77 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 78 | 79 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 80 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None) 81 | return dec_out 82 | 83 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 84 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 85 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 86 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/model/__init__.py -------------------------------------------------------------------------------- /model/iFlashformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import FlashAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | 
super(Model, self).__init__() 19 | self.seq_len = configs.seq_len 20 | self.pred_len = configs.pred_len 21 | self.output_attention = configs.output_attention 22 | # Embedding 23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 24 | configs.dropout) 25 | # Encoder-only architecture 26 | self.encoder = Encoder( 27 | [ 28 | EncoderLayer( 29 | AttentionLayer( 30 | FlashAttention(False, configs.factor, attention_dropout=configs.dropout, 31 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 32 | configs.d_model, 33 | configs.d_ff, 34 | dropout=configs.dropout, 35 | activation=configs.activation 36 | ) for l in range(configs.e_layers) 37 | ], 38 | norm_layer=torch.nn.LayerNorm(configs.d_model) 39 | ) 40 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True) 41 | 42 | 43 | 44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 45 | # Normalization from Non-stationary Transformer 46 | means = x_enc.mean(1, keepdim=True).detach() 47 | x_enc = x_enc - means 48 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 49 | x_enc /= stdev 50 | 51 | _, _, N = x_enc.shape 52 | 53 | # Embedding 54 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 55 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 56 | 57 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] 58 | # De-Normalization from Non-stationary Transformer 59 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 60 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 61 | return dec_out, attns 62 | 63 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 64 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 65 | 66 | if self.output_attention: 67 | return dec_out[:, -self.pred_len:, :], attns 68 | else: 69 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 70 | -------------------------------------------------------------------------------- /model/iFlowformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import FlowAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | super(Model, self).__init__() 19 | self.seq_len = configs.seq_len 20 | self.pred_len = configs.pred_len 21 | self.output_attention = configs.output_attention 22 | # Embedding 23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 24 | configs.dropout) 25 | # Encoder-only architecture 26 | self.encoder = Encoder( 27 | [ 28 | EncoderLayer( 29 | AttentionLayer( 30 | FlowAttention(attention_dropout=configs.dropout), configs.d_model, configs.n_heads), 31 | configs.d_model, 32 | configs.d_ff, 33 | dropout=configs.dropout, 34 | activation=configs.activation 35 | ) for l in range(configs.e_layers) 36 | ], 37 | norm_layer=torch.nn.LayerNorm(configs.d_model) 38 | ) 39 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True) 40 | 41 | 42 | 43 | def forecast(self, 
x_enc, x_mark_enc, x_dec, x_mark_dec): 44 | # Normalization from Non-stationary Transformer 45 | means = x_enc.mean(1, keepdim=True).detach() 46 | x_enc = x_enc - means 47 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 48 | x_enc /= stdev 49 | 50 | _, _, N = x_enc.shape 51 | 52 | # Embedding 53 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 54 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 55 | 56 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] 57 | # De-Normalization from Non-stationary Transformer 58 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 59 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 60 | return dec_out, attns 61 | 62 | 63 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 64 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 65 | 66 | if self.output_attention: 67 | return dec_out[:, -self.pred_len:, :], attns 68 | else: 69 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 70 | -------------------------------------------------------------------------------- /model/iInformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import ProbAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | super(Model, self).__init__() 19 | self.seq_len = configs.seq_len 20 | self.pred_len = configs.pred_len 21 | self.output_attention = configs.output_attention 22 | # Embedding 23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 24 | configs.dropout) 25 | # Encoder-only architecture 26 | self.encoder = Encoder( 27 | [ 28 | EncoderLayer( 29 | AttentionLayer( 30 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, 31 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 32 | configs.d_model, 33 | configs.d_ff, 34 | dropout=configs.dropout, 35 | activation=configs.activation 36 | ) for l in range(configs.e_layers) 37 | ], 38 | norm_layer=torch.nn.LayerNorm(configs.d_model) 39 | ) 40 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True) 41 | 42 | 43 | 44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 45 | # Normalization from Non-stationary Transformer 46 | means = x_enc.mean(1, keepdim=True).detach() 47 | x_enc = x_enc - means 48 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 49 | x_enc /= stdev 50 | 51 | _, _, N = x_enc.shape 52 | 53 | # Embedding 54 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 55 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 56 | 57 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] 58 | # De-Normalization from Non-stationary Transformer 59 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 60 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 61 | return dec_out, attns 62 | 63 | 64 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 
mask=None): 65 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 66 | 67 | if self.output_attention: 68 | return dec_out[:, -self.pred_len:, :], attns 69 | else: 70 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 71 | -------------------------------------------------------------------------------- /model/iReformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import ReformerLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | super(Model, self).__init__() 19 | self.seq_len = configs.seq_len 20 | self.pred_len = configs.pred_len 21 | self.output_attention = configs.output_attention 22 | # Embedding 23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 24 | configs.dropout) 25 | # Encoder-only architecture 26 | self.encoder = Encoder( 27 | [ 28 | EncoderLayer( 29 | ReformerLayer(None, configs.d_model, configs.n_heads, 30 | bucket_size=4, n_hashes=4), 31 | configs.d_model, 32 | configs.d_ff, 33 | dropout=configs.dropout, 34 | activation=configs.activation 35 | ) for l in range(configs.e_layers) 36 | ], 37 | norm_layer=torch.nn.LayerNorm(configs.d_model) 38 | ) 39 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True) 40 | 41 | 42 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 43 | # Normalization from Non-stationary Transformer 44 | means = x_enc.mean(1, keepdim=True).detach() 45 | x_enc = x_enc - means 46 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 47 | x_enc /= stdev 48 | 49 | _, _, N = x_enc.shape 50 | 51 | # Embedding 52 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 53 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 54 | 55 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] 56 | # De-Normalization from Non-stationary Transformer 57 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 58 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 59 | return dec_out, attns 60 | 61 | 62 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 63 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 64 | 65 | if self.output_attention: 66 | return dec_out[:, -self.pred_len:, :], attns 67 | else: 68 | return dec_out[:, -self.pred_len:, :] # [B, L, D] -------------------------------------------------------------------------------- /model/iTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Paper link: https://arxiv.org/abs/2310.06625 13 | """ 14 | 15 | def __init__(self, configs): 16 | super(Model, self).__init__() 17 | self.seq_len = configs.seq_len 18 | 
self.pred_len = configs.pred_len 19 | self.output_attention = configs.output_attention 20 | self.use_norm = configs.use_norm 21 | # Embedding 22 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 23 | configs.dropout) 24 | self.class_strategy = configs.class_strategy 25 | # Encoder-only architecture 26 | self.encoder = Encoder( 27 | [ 28 | EncoderLayer( 29 | AttentionLayer( 30 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 31 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 32 | configs.d_model, 33 | configs.d_ff, 34 | dropout=configs.dropout, 35 | activation=configs.activation 36 | ) for l in range(configs.e_layers) 37 | ], 38 | norm_layer=torch.nn.LayerNorm(configs.d_model) 39 | ) 40 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True) 41 | 42 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 43 | if self.use_norm: 44 | # Normalization from Non-stationary Transformer 45 | means = x_enc.mean(1, keepdim=True).detach() 46 | x_enc = x_enc - means 47 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 48 | x_enc /= stdev 49 | 50 | _, _, N = x_enc.shape # B L N 51 | # B: batch_size; E: d_model; 52 | # L: seq_len; S: pred_len; 53 | # N: number of variate (tokens), can also includes covariates 54 | 55 | # Embedding 56 | # B L N -> B N E (B L N -> B L E in the vanilla Transformer) 57 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # covariates (e.g timestamp) can be also embedded as tokens 58 | 59 | # B N E -> B N E (B L E -> B L E in the vanilla Transformer) 60 | # the dimensions of embedded time series has been inverted, and then processed by native attn, layernorm and ffn modules 61 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 62 | 63 | # B N E -> B N S -> B S N 64 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates 65 | 66 | if self.use_norm: 67 | # De-Normalization from Non-stationary Transformer 68 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 69 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 70 | 71 | return dec_out, attns 72 | 73 | 74 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 75 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 76 | 77 | if self.output_attention: 78 | return dec_out[:, -self.pred_len:, :], attns 79 | else: 80 | return dec_out[:, -self.pred_len:, :] # [B, L, D] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.5.3 2 | scikit-learn==1.2.2 3 | numpy==1.23.5 4 | matplotlib==3.7.0 5 | torch==2.0.0 6 | reformer-pytorch==1.4.4 7 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from experiments.exp_long_term_forecasting import Exp_Long_Term_Forecast 4 | from experiments.exp_long_term_forecasting_partial import Exp_Long_Term_Forecast_Partial 5 | import random 6 | import numpy as np 7 | 8 | if __name__ == '__main__': 9 | fix_seed = 2023 10 | random.seed(fix_seed) 11 | torch.manual_seed(fix_seed) 12 | np.random.seed(fix_seed) 13 | 14 | parser = argparse.ArgumentParser(description='iTransformer') 15 | 16 | # basic config 17 | 
parser.add_argument('--is_training', type=int, required=True, default=1, help='status') 18 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') 19 | parser.add_argument('--model', type=str, required=True, default='iTransformer', 20 | help='model name, options: [iTransformer, iInformer, iReformer, iFlowformer, iFlashformer]') 21 | 22 | # data loader 23 | parser.add_argument('--data', type=str, required=True, default='custom', help='dataset type') 24 | parser.add_argument('--root_path', type=str, default='./data/electricity/', help='root path of the data file') 25 | parser.add_argument('--data_path', type=str, default='electricity.csv', help='data csv file') 26 | parser.add_argument('--features', type=str, default='M', 27 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate') 28 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task') 29 | parser.add_argument('--freq', type=str, default='h', 30 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') 31 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints') 32 | 33 | # forecasting task 34 | parser.add_argument('--seq_len', type=int, default=96, help='input sequence length') 35 | parser.add_argument('--label_len', type=int, default=48, help='start token length') # no longer needed in inverted Transformers 36 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length') 37 | 38 | # model define 39 | parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') 40 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size') 41 | parser.add_argument('--c_out', type=int, default=7, help='output size') # applicable on arbitrary number of variates in inverted Transformers 42 | parser.add_argument('--d_model', type=int, default=512, help='dimension of model') 43 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads') 44 | parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers') 45 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers') 46 | parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn') 47 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average') 48 | parser.add_argument('--factor', type=int, default=1, help='attn factor') 49 | parser.add_argument('--distil', action='store_false', 50 | help='whether to use distilling in encoder, using this argument means not using distilling', 51 | default=True) 52 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout') 53 | parser.add_argument('--embed', type=str, default='timeF', 54 | help='time features encoding, options:[timeF, fixed, learned]') 55 | parser.add_argument('--activation', type=str, default='gelu', help='activation') 56 | parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder') 57 | parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data') 58 | 59 | # optimization 60 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') 61 | parser.add_argument('--itr', type=int, 
default=1, help='experiments times') 62 | parser.add_argument('--train_epochs', type=int, default=10, help='train epochs') 63 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data') 64 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience') 65 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate') 66 | parser.add_argument('--des', type=str, default='test', help='exp description') 67 | parser.add_argument('--loss', type=str, default='MSE', help='loss function') 68 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate') 69 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False) 70 | 71 | # GPU 72 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu') 73 | parser.add_argument('--gpu', type=int, default=0, help='gpu') 74 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False) 75 | parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus') 76 | 77 | # iTransformer 78 | parser.add_argument('--exp_name', type=str, required=False, default='MTSF', 79 | help='experiemnt name, options:[MTSF, partial_train]') 80 | parser.add_argument('--channel_independence', type=bool, default=False, help='whether to use channel_independence mechanism') 81 | parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False) 82 | parser.add_argument('--class_strategy', type=str, default='projection', help='projection/average/cls_token') 83 | parser.add_argument('--target_root_path', type=str, default='./data/electricity/', help='root path of the data file') 84 | parser.add_argument('--target_data_path', type=str, default='electricity.csv', help='data file') 85 | parser.add_argument('--efficient_training', type=bool, default=False, help='whether to use efficient_training (exp_name should be partial train)') # See Figure 8 of our paper for the detail 86 | parser.add_argument('--use_norm', type=int, default=True, help='use norm and denorm') 87 | parser.add_argument('--partial_start_index', type=int, default=0, help='the start index of variates for partial training, ' 88 | 'you can select [partial_start_index, min(enc_in + partial_start_index, N)]') 89 | 90 | args = parser.parse_args() 91 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False 92 | 93 | if args.use_gpu and args.use_multi_gpu: 94 | args.devices = args.devices.replace(' ', '') 95 | device_ids = args.devices.split(',') 96 | args.device_ids = [int(id_) for id_ in device_ids] 97 | args.gpu = args.device_ids[0] 98 | 99 | print('Args in experiment:') 100 | print(args) 101 | 102 | if args.exp_name == 'partial_train': # See Figure 8 of our paper, for the detail 103 | Exp = Exp_Long_Term_Forecast_Partial 104 | else: # MTSF: multivariate time series forecasting 105 | Exp = Exp_Long_Term_Forecast 106 | 107 | 108 | if args.is_training: 109 | for ii in range(args.itr): 110 | # setting record of experiments 111 | setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 112 | args.model_id, 113 | args.model, 114 | args.data, 115 | args.features, 116 | args.seq_len, 117 | args.label_len, 118 | args.pred_len, 119 | args.d_model, 120 | args.n_heads, 121 | args.e_layers, 122 | args.d_layers, 123 | args.d_ff, 124 | args.factor, 125 | args.embed, 126 | args.distil, 
127 | args.des, 128 | args.class_strategy, ii) 129 | 130 | exp = Exp(args) # set experiments 131 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) 132 | exp.train(setting) 133 | 134 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 135 | exp.test(setting) 136 | 137 | if args.do_predict: 138 | print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 139 | exp.predict(setting, True) 140 | 141 | torch.cuda.empty_cache() 142 | else: 143 | ii = 0 144 | setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 145 | args.model_id, 146 | args.model, 147 | args.data, 148 | args.features, 149 | args.seq_len, 150 | args.label_len, 151 | args.pred_len, 152 | args.d_model, 153 | args.n_heads, 154 | args.e_layers, 155 | args.d_layers, 156 | args.d_ff, 157 | args.factor, 158 | args.embed, 159 | args.distil, 160 | args.des, 161 | args.class_strategy, ii) 162 | 163 | exp = Exp(args) # set experiments 164 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 165 | exp.test(setting, test=1) 166 | torch.cuda.empty_cache() 167 | -------------------------------------------------------------------------------- /scripts/boost_performance/ECL/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=iFlowformer 4 | # model_name=Flowformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --itr 1 22 | 23 | python -u run.py \ 24 | --is_training 1 \ 25 | --root_path ./dataset/electricity/ \ 26 | --data_path electricity.csv \ 27 | --model_id ECL_96_192 \ 28 | --model $model_name \ 29 | --data custom \ 30 | --features M \ 31 | --seq_len 96 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --enc_in 321 \ 35 | --dec_in 321 \ 36 | --c_out 321 \ 37 | --des 'Exp' \ 38 | --itr 1 39 | 40 | python -u run.py \ 41 | --is_training 1 \ 42 | --root_path ./dataset/electricity/ \ 43 | --data_path electricity.csv \ 44 | --model_id ECL_96_336 \ 45 | --model $model_name \ 46 | --data custom \ 47 | --features M \ 48 | --seq_len 96 \ 49 | --pred_len 336 \ 50 | --e_layers 2 \ 51 | --enc_in 321 \ 52 | --dec_in 321 \ 53 | --c_out 321 \ 54 | --des 'Exp' \ 55 | --itr 1 56 | 57 | python -u run.py \ 58 | --is_training 1 \ 59 | --root_path ./dataset/electricity/ \ 60 | --data_path electricity.csv \ 61 | --model_id ECL_96_720 \ 62 | --model $model_name \ 63 | --data custom \ 64 | --features M \ 65 | --seq_len 96 \ 66 | --pred_len 720 \ 67 | --e_layers 2 \ 68 | --enc_in 321 \ 69 | --dec_in 321 \ 70 | --c_out 321 \ 71 | --des 'Exp' \ 72 | --itr 1 73 | -------------------------------------------------------------------------------- /scripts/boost_performance/ECL/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=iInformer 4 | # model_name=Informer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | 
--e_layers 2 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --itr 1 22 | 23 | python -u run.py \ 24 | --is_training 1 \ 25 | --root_path ./dataset/electricity/ \ 26 | --data_path electricity.csv \ 27 | --model_id ECL_96_192 \ 28 | --model $model_name \ 29 | --data custom \ 30 | --features M \ 31 | --seq_len 96 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --enc_in 321 \ 35 | --dec_in 321 \ 36 | --c_out 321 \ 37 | --des 'Exp' \ 38 | --itr 1 39 | 40 | python -u run.py \ 41 | --is_training 1 \ 42 | --root_path ./dataset/electricity/ \ 43 | --data_path electricity.csv \ 44 | --model_id ECL_96_336 \ 45 | --model $model_name \ 46 | --data custom \ 47 | --features M \ 48 | --seq_len 96 \ 49 | --pred_len 336 \ 50 | --e_layers 2 \ 51 | --enc_in 321 \ 52 | --dec_in 321 \ 53 | --c_out 321 \ 54 | --des 'Exp' \ 55 | --itr 1 56 | 57 | python -u run.py \ 58 | --is_training 1 \ 59 | --root_path ./dataset/electricity/ \ 60 | --data_path electricity.csv \ 61 | --model_id ECL_96_720 \ 62 | --model $model_name \ 63 | --data custom \ 64 | --features M \ 65 | --seq_len 96 \ 66 | --pred_len 720 \ 67 | --e_layers 2 \ 68 | --enc_in 321 \ 69 | --dec_in 321 \ 70 | --c_out 321 \ 71 | --des 'Exp' \ 72 | --itr 1 73 | -------------------------------------------------------------------------------- /scripts/boost_performance/ECL/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=iReformer 4 | # model_name=Reformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --itr 1 22 | 23 | python -u run.py \ 24 | --is_training 1 \ 25 | --root_path ./dataset/electricity/ \ 26 | --data_path electricity.csv \ 27 | --model_id ECL_96_192 \ 28 | --model $model_name \ 29 | --data custom \ 30 | --features M \ 31 | --seq_len 96 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --enc_in 321 \ 35 | --dec_in 321 \ 36 | --c_out 321 \ 37 | --des 'Exp' \ 38 | --itr 1 39 | 40 | python -u run.py \ 41 | --is_training 1 \ 42 | --root_path ./dataset/electricity/ \ 43 | --data_path electricity.csv \ 44 | --model_id ECL_96_336 \ 45 | --model $model_name \ 46 | --data custom \ 47 | --features M \ 48 | --seq_len 96 \ 49 | --pred_len 336 \ 50 | --e_layers 2 \ 51 | --enc_in 321 \ 52 | --dec_in 321 \ 53 | --c_out 321 \ 54 | --des 'Exp' \ 55 | --itr 1 56 | 57 | python -u run.py \ 58 | --is_training 1 \ 59 | --root_path ./dataset/electricity/ \ 60 | --data_path electricity.csv \ 61 | --model_id ECL_96_720 \ 62 | --model $model_name \ 63 | --data custom \ 64 | --features M \ 65 | --seq_len 96 \ 66 | --pred_len 720 \ 67 | --e_layers 2 \ 68 | --enc_in 321 \ 69 | --dec_in 321 \ 70 | --c_out 321 \ 71 | --des 'Exp' \ 72 | --itr 1 73 | -------------------------------------------------------------------------------- /scripts/boost_performance/ECL/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=iTransformer 4 | # model_name=Transformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_96_96 \ 11 | --model $model_name \ 12 | --data 
custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --itr 1 22 | 23 | python -u run.py \ 24 | --is_training 1 \ 25 | --root_path ./dataset/electricity/ \ 26 | --data_path electricity.csv \ 27 | --model_id ECL_96_192 \ 28 | --model $model_name \ 29 | --data custom \ 30 | --features M \ 31 | --seq_len 96 \ 32 | --pred_len 192 \ 33 | --e_layers 2 \ 34 | --enc_in 321 \ 35 | --dec_in 321 \ 36 | --c_out 321 \ 37 | --des 'Exp' \ 38 | --itr 1 39 | 40 | python -u run.py \ 41 | --is_training 1 \ 42 | --root_path ./dataset/electricity/ \ 43 | --data_path electricity.csv \ 44 | --model_id ECL_96_336 \ 45 | --model $model_name \ 46 | --data custom \ 47 | --features M \ 48 | --seq_len 96 \ 49 | --pred_len 336 \ 50 | --e_layers 2 \ 51 | --enc_in 321 \ 52 | --dec_in 321 \ 53 | --c_out 321 \ 54 | --des 'Exp' \ 55 | --itr 1 56 | 57 | python -u run.py \ 58 | --is_training 1 \ 59 | --root_path ./dataset/electricity/ \ 60 | --data_path electricity.csv \ 61 | --model_id ECL_96_720 \ 62 | --model $model_name \ 63 | --data custom \ 64 | --features M \ 65 | --seq_len 96 \ 66 | --pred_len 720 \ 67 | --e_layers 2 \ 68 | --enc_in 321 \ 69 | --dec_in 321 \ 70 | --c_out 321 \ 71 | --des 'Exp' \ 72 | --itr 1 73 | -------------------------------------------------------------------------------- /scripts/boost_performance/README.md: -------------------------------------------------------------------------------- 1 | # Inverted Transformers Work Better for Time Series Forecasting 2 | 3 | This folder contains the comparison of the vanilla Transformer-based forecasters and the inverted versions. If you are new to this repo, we recommend having a look at this [README](../multivariate_forecasting/README.md) first. 4 | 5 | ## Scripts 6 | 7 | In each folder named after the dataset, we compare the performance of iTransformers and the vanilla Transformers. 8 | 9 | ``` 10 | # iTransformer on the Traffic Dataset. 11 | 12 | bash ./scripts/boost_performance/Traffic/iTransformer.sh 13 | ``` 14 | 15 | You can change the ```model_name``` in the script to switch between a Transformer variant and its inverted version. 16 | 17 | ## Results 18 | We compare the performance of Transformer and iTransformer on all six datasets, indicating that applying the attention and feed-forward network on the 19 | inverted dimensions greatly empowers Transformers in multivariate time series forecasting. 20 |
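To make "inverted dimensions" concrete, below is a minimal sketch (not part of the official scripts) of the embedding step that the inverted models in this repo share: each variate's whole lookback series becomes one token, so attention and the feed-forward network afterwards operate across variates instead of across time steps. The class and argument order follow ```layers/Embed.py``` and the model files; the batch size, variate count, and the ```None``` covariate argument are illustrative assumptions.

```python
import torch
from layers.Embed import DataEmbedding_inverted

x_enc = torch.randn(4, 96, 321)                # [batch, lookback L=96, variates N=321], e.g. ECL
embedding = DataEmbedding_inverted(96, 512, 'timeF', 'h', 0.1)

variate_tokens = embedding(x_enc, None)        # [4, 321, 512]: one token per variate
# Attention then mixes the 321 variate tokens (multivariate correlations), while the
# feed-forward network refines each variate's representation independently.
print(variate_tokens.shape)
```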


24 | 25 | We apply the proposed inverted framework to the vanilla Transformer and its variants. The results demonstrate that the inverted framework consistently improves these Transformer variants 26 | and lets them take advantage of efficient attention mechanisms. 27 |
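As a rough sketch of why the framework transfers so easily: the inverted encoder block is agnostic to the attention module it wraps, so iTransformer, iInformer, and iFlowformer differ only in the attention class passed in. Class names and constructor arguments below are taken from the model files in this repository; the hyperparameter namespace is a made-up stand-in for the argparse configs built by ```run.py```.

```python
import torch
from types import SimpleNamespace
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, ProbAttention, FlowAttention, AttentionLayer

# Stand-in for the argparse namespace (values mirror the run.py defaults).
cfg = SimpleNamespace(d_model=512, n_heads=8, e_layers=2, d_ff=2048,
                      dropout=0.1, activation='gelu', factor=1)

def inverted_encoder(attention, cfg):
    # The same encoder stack is reused by every inverted variant.
    return Encoder(
        [
            EncoderLayer(
                AttentionLayer(attention, cfg.d_model, cfg.n_heads),
                cfg.d_model,
                cfg.d_ff,
                dropout=cfg.dropout,
                activation=cfg.activation,
            ) for _ in range(cfg.e_layers)
        ],
        norm_layer=torch.nn.LayerNorm(cfg.d_model),
    )

# Swapping the attention class is the only change between the variants.
itransformer_enc = inverted_encoder(
    FullAttention(False, cfg.factor, attention_dropout=cfg.dropout, output_attention=False), cfg)
iinformer_enc = inverted_encoder(
    ProbAttention(False, cfg.factor, attention_dropout=cfg.dropout, output_attention=False), cfg)
iflowformer_enc = inverted_encoder(
    FlowAttention(attention_dropout=cfg.dropout), cfg)
```

Any efficient attention module exposing the same interface could be plugged in the same way.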


30 | -------------------------------------------------------------------------------- /scripts/boost_performance/Traffic/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iFlowformer 4 | # model_name=Flowformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --train_epochs 3 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/traffic/ \ 27 | --data_path traffic.csv \ 28 | --model_id traffic_96_192 \ 29 | --model $model_name \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 862 \ 36 | --dec_in 862 \ 37 | --c_out 862 \ 38 | --des 'Exp' \ 39 | --itr 1 \ 40 | --train_epochs 3 41 | 42 | python -u run.py \ 43 | --is_training 1 \ 44 | --root_path ./dataset/traffic/ \ 45 | --data_path traffic.csv \ 46 | --model_id traffic_96_336 \ 47 | --model $model_name \ 48 | --data custom \ 49 | --features M \ 50 | --seq_len 96 \ 51 | --pred_len 336 \ 52 | --e_layers 2 \ 53 | --enc_in 862 \ 54 | --dec_in 862 \ 55 | --c_out 862 \ 56 | --des 'Exp' \ 57 | --itr 1 \ 58 | --train_epochs 3 59 | 60 | python -u run.py \ 61 | --is_training 1 \ 62 | --root_path ./dataset/traffic/ \ 63 | --data_path traffic.csv \ 64 | --model_id traffic_96_720 \ 65 | --model $model_name \ 66 | --data custom \ 67 | --features M \ 68 | --seq_len 96 \ 69 | --pred_len 720 \ 70 | --e_layers 2 \ 71 | --enc_in 862 \ 72 | --dec_in 862 \ 73 | --c_out 862 \ 74 | --des 'Exp' \ 75 | --itr 1 \ 76 | --train_epochs 3 77 | -------------------------------------------------------------------------------- /scripts/boost_performance/Traffic/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | # model_name=iInformer 4 | model_name=Informer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --train_epochs 3 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/traffic/ \ 27 | --data_path traffic.csv \ 28 | --model_id traffic_96_192 \ 29 | --model $model_name \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 862 \ 36 | --dec_in 862 \ 37 | --c_out 862 \ 38 | --des 'Exp' \ 39 | --itr 1 \ 40 | --train_epochs 3 41 | 42 | python -u run.py \ 43 | --is_training 1 \ 44 | --root_path ./dataset/traffic/ \ 45 | --data_path traffic.csv \ 46 | --model_id traffic_96_336 \ 47 | --model $model_name \ 48 | --data custom \ 49 | --features M \ 50 | --seq_len 96 \ 51 | --pred_len 336 \ 52 | --e_layers 2 \ 53 | --enc_in 862 \ 54 | --dec_in 862 \ 55 | --c_out 862 \ 56 | --des 'Exp' \ 57 | --itr 1 \ 58 | --train_epochs 3 59 | 60 | python -u run.py \ 61 | --is_training 1 \ 62 | --root_path ./dataset/traffic/ \ 63 | --data_path traffic.csv \ 64 | --model_id traffic_96_720 \ 65 | 
--model $model_name \ 66 | --data custom \ 67 | --features M \ 68 | --seq_len 96 \ 69 | --pred_len 720 \ 70 | --e_layers 2 \ 71 | --enc_in 862 \ 72 | --dec_in 862 \ 73 | --c_out 862 \ 74 | --des 'Exp' \ 75 | --itr 1 \ 76 | --train_epochs 3 77 | -------------------------------------------------------------------------------- /scripts/boost_performance/Traffic/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iReformer 4 | # model_name=Reformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --train_epochs 3 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/traffic/ \ 27 | --data_path traffic.csv \ 28 | --model_id traffic_96_192 \ 29 | --model $model_name \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 862 \ 36 | --dec_in 862 \ 37 | --c_out 862 \ 38 | --des 'Exp' \ 39 | --itr 1 \ 40 | --train_epochs 3 41 | 42 | python -u run.py \ 43 | --is_training 1 \ 44 | --root_path ./dataset/traffic/ \ 45 | --data_path traffic.csv \ 46 | --model_id traffic_96_336 \ 47 | --model $model_name \ 48 | --data custom \ 49 | --features M \ 50 | --seq_len 96 \ 51 | --pred_len 336 \ 52 | --e_layers 2 \ 53 | --enc_in 862 \ 54 | --dec_in 862 \ 55 | --c_out 862 \ 56 | --des 'Exp' \ 57 | --itr 1 \ 58 | --train_epochs 3 59 | 60 | python -u run.py \ 61 | --is_training 1 \ 62 | --root_path ./dataset/traffic/ \ 63 | --data_path traffic.csv \ 64 | --model_id traffic_96_720 \ 65 | --model $model_name \ 66 | --data custom \ 67 | --features M \ 68 | --seq_len 96 \ 69 | --pred_len 720 \ 70 | --e_layers 2 \ 71 | --enc_in 862 \ 72 | --dec_in 862 \ 73 | --c_out 862 \ 74 | --des 'Exp' \ 75 | --itr 1 \ 76 | --train_epochs 3 77 | -------------------------------------------------------------------------------- /scripts/boost_performance/Traffic/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iTransformer 4 | #model_name=Transformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --train_epochs 3 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/traffic/ \ 27 | --data_path traffic.csv \ 28 | --model_id traffic_96_192 \ 29 | --model $model_name \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 862 \ 36 | --dec_in 862 \ 37 | --c_out 862 \ 38 | --des 'Exp' \ 39 | --itr 1 \ 40 | --train_epochs 3 41 | 42 | python -u run.py \ 43 | --is_training 1 \ 44 | --root_path ./dataset/traffic/ \ 45 | --data_path traffic.csv \ 46 | --model_id traffic_96_336 \ 47 | --model $model_name \ 48 | --data custom \ 49 | --features M \ 50 | --seq_len 96 \ 51 | --pred_len 336 \ 52 | --e_layers 2 \ 53 | --enc_in 862 \ 54 | 
--dec_in 862 \ 55 | --c_out 862 \ 56 | --des 'Exp' \ 57 | --itr 1 \ 58 | --train_epochs 3 59 | 60 | python -u run.py \ 61 | --is_training 1 \ 62 | --root_path ./dataset/traffic/ \ 63 | --data_path traffic.csv \ 64 | --model_id traffic_96_720 \ 65 | --model $model_name \ 66 | --data custom \ 67 | --features M \ 68 | --seq_len 96 \ 69 | --pred_len 720 \ 70 | --e_layers 2 \ 71 | --enc_in 862 \ 72 | --dec_in 862 \ 73 | --c_out 862 \ 74 | --des 'Exp' \ 75 | --itr 1 \ 76 | --train_epochs 3 77 | -------------------------------------------------------------------------------- /scripts/boost_performance/Weather/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iFlowformer 4 | # model_name=Flowformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/weather/ \ 9 | --data_path weather.csv \ 10 | --model_id weather_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 21 \ 18 | --dec_in 21 \ 19 | --c_out 21 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --batch_size 128 \ 23 | --train_epochs 3 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/weather/ \ 28 | --data_path weather.csv \ 29 | --model_id weather_96_192 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --batch_size 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/weather/ \ 46 | --data_path weather.csv \ 47 | --model_id weather_96_336 \ 48 | --model $model_name \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 21 \ 55 | --dec_in 21 \ 56 | --c_out 21 \ 57 | --des 'Exp' \ 58 | --batch_size 128 \ 59 | --itr 1 60 | 61 | python -u run.py \ 62 | --is_training 1 \ 63 | --root_path ./dataset/weather/ \ 64 | --data_path weather.csv \ 65 | --model_id weather_96_720 \ 66 | --model $model_name \ 67 | --data custom \ 68 | --features M \ 69 | --seq_len 96 \ 70 | --pred_len 720 \ 71 | --e_layers 2 \ 72 | --enc_in 21 \ 73 | --dec_in 21 \ 74 | --c_out 21 \ 75 | --des 'Exp' \ 76 | --batch_size 128 \ 77 | --itr 1 78 | -------------------------------------------------------------------------------- /scripts/boost_performance/Weather/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iInformer 4 | #model_name=Informer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/weather/ \ 9 | --data_path weather.csv \ 10 | --model_id weather_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 21 \ 18 | --dec_in 21 \ 19 | --c_out 21 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --batch_size 128 \ 23 | --train_epochs 3 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/weather/ \ 28 | --data_path weather.csv \ 29 | --model_id weather_96_192 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --batch_size 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | 
--is_training 1 \ 45 | --root_path ./dataset/weather/ \ 46 | --data_path weather.csv \ 47 | --model_id weather_96_336 \ 48 | --model $model_name \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 21 \ 55 | --dec_in 21 \ 56 | --c_out 21 \ 57 | --des 'Exp' \ 58 | --batch_size 128 \ 59 | --itr 1 60 | 61 | python -u run.py \ 62 | --is_training 1 \ 63 | --root_path ./dataset/weather/ \ 64 | --data_path weather.csv \ 65 | --model_id weather_96_720 \ 66 | --model $model_name \ 67 | --data custom \ 68 | --features M \ 69 | --seq_len 96 \ 70 | --pred_len 720 \ 71 | --e_layers 2 \ 72 | --enc_in 21 \ 73 | --dec_in 21 \ 74 | --c_out 21 \ 75 | --des 'Exp' \ 76 | --batch_size 128 \ 77 | --itr 1 78 | -------------------------------------------------------------------------------- /scripts/boost_performance/Weather/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iReformer 4 | #model_name=Reformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/weather/ \ 9 | --data_path weather.csv \ 10 | --model_id weather_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 21 \ 18 | --dec_in 21 \ 19 | --c_out 21 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --batch_size 128 \ 23 | --train_epochs 3 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/weather/ \ 28 | --data_path weather.csv \ 29 | --model_id weather_96_192 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --batch_size 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/weather/ \ 46 | --data_path weather.csv \ 47 | --model_id weather_96_336 \ 48 | --model $model_name \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 21 \ 55 | --dec_in 21 \ 56 | --c_out 21 \ 57 | --des 'Exp' \ 58 | --batch_size 128 \ 59 | --itr 1 60 | 61 | python -u run.py \ 62 | --is_training 1 \ 63 | --root_path ./dataset/weather/ \ 64 | --data_path weather.csv \ 65 | --model_id weather_96_720 \ 66 | --model $model_name \ 67 | --data custom \ 68 | --features M \ 69 | --seq_len 96 \ 70 | --pred_len 720 \ 71 | --e_layers 2 \ 72 | --enc_in 21 \ 73 | --dec_in 21 \ 74 | --c_out 21 \ 75 | --des 'Exp' \ 76 | --batch_size 128 \ 77 | --itr 1 78 | -------------------------------------------------------------------------------- /scripts/boost_performance/Weather/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iTransformer 4 | #model_name=Transformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/weather/ \ 9 | --data_path weather.csv \ 10 | --model_id weather_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --enc_in 21 \ 18 | --dec_in 21 \ 19 | --c_out 21 \ 20 | --des 'Exp' \ 21 | --itr 1 \ 22 | --batch_size 128 \ 23 | --train_epochs 3 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/weather/ \ 28 | --data_path weather.csv \ 29 | --model_id weather_96_192 \ 30 | 
--model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --batch_size 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/weather/ \ 46 | --data_path weather.csv \ 47 | --model_id weather_96_336 \ 48 | --model $model_name \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 21 \ 55 | --dec_in 21 \ 56 | --c_out 21 \ 57 | --des 'Exp' \ 58 | --batch_size 128 \ 59 | --itr 1 60 | 61 | python -u run.py \ 62 | --is_training 1 \ 63 | --root_path ./dataset/weather/ \ 64 | --data_path weather.csv \ 65 | --model_id weather_96_720 \ 66 | --model $model_name \ 67 | --data custom \ 68 | --features M \ 69 | --seq_len 96 \ 70 | --pred_len 720 \ 71 | --e_layers 2 \ 72 | --enc_in 21 \ 73 | --dec_in 21 \ 74 | --c_out 21 \ 75 | --des 'Exp' \ 76 | --batch_size 128 \ 77 | --itr 1 78 | -------------------------------------------------------------------------------- /scripts/increasing_lookback/ECL/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # model_name=Flowformer 4 | model_name=iFlowformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 3 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 0.0005\ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/electricity/ \ 30 | --data_path electricity.csv \ 31 | --model_id ECL_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --e_layers 3 \ 38 | --enc_in 321 \ 39 | --dec_in 321 \ 40 | --c_out 321 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.0005\ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/electricity/ \ 51 | --data_path electricity.csv \ 52 | --model_id ECL_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 3 \ 59 | --enc_in 321 \ 60 | --dec_in 321 \ 61 | --c_out 321 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.0005 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/electricity/ \ 72 | --data_path electricity.csv \ 73 | --model_id ECL_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 3 \ 80 | --enc_in 321 \ 81 | --dec_in 321 \ 82 | --c_out 321 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.0005 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/electricity/ \ 93 | --data_path electricity.csv \ 94 | --model_id ECL_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 3 \ 101 | --enc_in 321 \ 102 | --dec_in 321 \ 103 | 
--c_out 321 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.0005 \ 109 | --itr 1 110 | -------------------------------------------------------------------------------- /scripts/increasing_lookback/ECL/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # model_name=Informer 4 | model_name=iInformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 3 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 0.0005 \ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/electricity/ \ 30 | --data_path electricity.csv \ 31 | --model_id ECL_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --e_layers 3 \ 38 | --enc_in 321 \ 39 | --dec_in 321 \ 40 | --c_out 321 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.0005 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/electricity/ \ 51 | --data_path electricity.csv \ 52 | --model_id ECL_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 3 \ 59 | --enc_in 321 \ 60 | --dec_in 321 \ 61 | --c_out 321 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.0005 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/electricity/ \ 72 | --data_path electricity.csv \ 73 | --model_id ECL_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 3 \ 80 | --enc_in 321 \ 81 | --dec_in 321 \ 82 | --c_out 321 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.0005 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/electricity/ \ 93 | --data_path electricity.csv \ 94 | --model_id ECL_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 3 \ 101 | --enc_in 321 \ 102 | --dec_in 321 \ 103 | --c_out 321 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.0005 \ 109 | --itr 1 110 | -------------------------------------------------------------------------------- /scripts/increasing_lookback/ECL/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # model_name=Reformer 4 | model_name=iReformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 3 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 
0.0005 \ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/electricity/ \ 30 | --data_path electricity.csv \ 31 | --model_id ECL_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --e_layers 3 \ 38 | --enc_in 321 \ 39 | --dec_in 321 \ 40 | --c_out 321 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.0005 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/electricity/ \ 51 | --data_path electricity.csv \ 52 | --model_id ECL_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 3 \ 59 | --enc_in 321 \ 60 | --dec_in 321 \ 61 | --c_out 321 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.0005 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/electricity/ \ 72 | --data_path electricity.csv \ 73 | --model_id ECL_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 3 \ 80 | --enc_in 321 \ 81 | --dec_in 321 \ 82 | --c_out 321 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.0005 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/electricity/ \ 93 | --data_path electricity.csv \ 94 | --model_id ECL_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 3 \ 101 | --enc_in 321 \ 102 | --dec_in 321 \ 103 | --c_out 321 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.0005 \ 109 | --itr 1 110 | -------------------------------------------------------------------------------- /scripts/increasing_lookback/ECL/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # model_name=Transformer 4 | model_name=iTransformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 3 \ 17 | --enc_in 321 \ 18 | --dec_in 321 \ 19 | --c_out 321 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 0.0005 \ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/electricity/ \ 30 | --data_path electricity.csv \ 31 | --model_id ECL_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --e_layers 3 \ 38 | --enc_in 321 \ 39 | --dec_in 321 \ 40 | --c_out 321 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.0005 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/electricity/ \ 51 | --data_path electricity.csv \ 52 | --model_id ECL_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 3 \ 59 | --enc_in 321 \ 60 | --dec_in 321 \ 61 | --c_out 321 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | 
--d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.0005 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/electricity/ \ 72 | --data_path electricity.csv \ 73 | --model_id ECL_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 3 \ 80 | --enc_in 321 \ 81 | --dec_in 321 \ 82 | --c_out 321 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.0005 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/electricity/ \ 93 | --data_path electricity.csv \ 94 | --model_id ECL_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 3 \ 101 | --enc_in 321 \ 102 | --dec_in 321 \ 103 | --c_out 321 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.0005 \ 109 | --itr 1 110 | -------------------------------------------------------------------------------- /scripts/increasing_lookback/README.md: -------------------------------------------------------------------------------- 1 | # iTransformer for Enlarged Lookback Window 2 | 3 | This folder contains the scripts for evaluating iTransformer with enlarged lookback windows. If you are new to this repo, we recommend you read this [README](../multivariate_forecasting/README.md) first. 4 | 5 | ## Scripts 6 | 7 | In each folder named after the dataset, we provide experiments for the iTransformers and the vanilla Transformers under five increasing lookback lengths (the prediction length is kept fixed at 96). 8 | 9 | ``` 10 | # iTransformer on the Traffic Dataset with gradually enlarged lookback windows. 11 | 12 | bash ./scripts/increasing_lookback/Traffic/iTransformer.sh 13 | ``` 14 | 15 | You can change the ```model_name``` in the script to switch between the vanilla Transformer and the inverted version. 16 | 17 | ## Results 18 | 19 |
*(Figure: forecasting performance with gradually enlarged lookback windows.)*
22 | 23 | The inverted framework empowers Transformers with improved performance on the enlarged lookback window. -------------------------------------------------------------------------------- /scripts/increasing_lookback/Traffic/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # model_name=Flowformer 4 | model_name=iFlowformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 4 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 0.001 \ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/traffic/ \ 30 | --data_path traffic.csv \ 31 | --model_id traffic_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --e_layers 4 \ 38 | --enc_in 862 \ 39 | --dec_in 862 \ 40 | --c_out 862 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.001 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/traffic/ \ 51 | --data_path traffic.csv \ 52 | --model_id traffic_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 4 \ 59 | --enc_in 862 \ 60 | --dec_in 862 \ 61 | --c_out 862 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.001 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/traffic/ \ 72 | --data_path traffic.csv \ 73 | --model_id traffic_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 4 \ 80 | --enc_in 862 \ 81 | --dec_in 862 \ 82 | --c_out 862 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.001 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/traffic/ \ 93 | --data_path traffic.csv \ 94 | --model_id traffic_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 4 \ 101 | --enc_in 862 \ 102 | --dec_in 862 \ 103 | --c_out 862 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.001 \ 109 | --itr 1 110 | -------------------------------------------------------------------------------- /scripts/increasing_lookback/Traffic/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # model_name=Informer 4 | model_name=iInformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 4 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 0.001 \ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | 
--is_training 1 \ 29 | --root_path ./dataset/traffic/ \ 30 | --data_path traffic.csv \ 31 | --model_id traffic_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --e_layers 4 \ 38 | --enc_in 862 \ 39 | --dec_in 862 \ 40 | --c_out 862 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.001 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/traffic/ \ 51 | --data_path traffic.csv \ 52 | --model_id traffic_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 4 \ 59 | --enc_in 862 \ 60 | --dec_in 862 \ 61 | --c_out 862 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.001 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/traffic/ \ 72 | --data_path traffic.csv \ 73 | --model_id traffic_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 4 \ 80 | --enc_in 862 \ 81 | --dec_in 862 \ 82 | --c_out 862 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.001 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/traffic/ \ 93 | --data_path traffic.csv \ 94 | --model_id traffic_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 4 \ 101 | --enc_in 862 \ 102 | --dec_in 862 \ 103 | --c_out 862 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.001 \ 109 | --itr 1 110 | -------------------------------------------------------------------------------- /scripts/increasing_lookback/Traffic/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | # model_name=Reformer 4 | model_name=iReformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 4 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 0.001 \ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/traffic/ \ 30 | --data_path traffic.csv \ 31 | --model_id traffic_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --e_layers 4 \ 38 | --enc_in 862 \ 39 | --dec_in 862 \ 40 | --c_out 862 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.001 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/traffic/ \ 51 | --data_path traffic.csv \ 52 | --model_id traffic_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 4 \ 59 | --enc_in 862 \ 60 | --dec_in 862 \ 61 | --c_out 862 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.001 \ 67 | --itr 1 68 | 69 | python -u 
run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/traffic/ \ 72 | --data_path traffic.csv \ 73 | --model_id traffic_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 4 \ 80 | --enc_in 862 \ 81 | --dec_in 862 \ 82 | --c_out 862 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.001 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/traffic/ \ 93 | --data_path traffic.csv \ 94 | --model_id traffic_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 4 \ 101 | --enc_in 862 \ 102 | --dec_in 862 \ 103 | --c_out 862 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.001 \ 109 | --itr 1 -------------------------------------------------------------------------------- /scripts/increasing_lookback/Traffic/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | # model_name=Transformer 4 | model_name=iTransformer 5 | 6 | python -u run.py \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_48_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 4 \ 17 | --enc_in 862 \ 18 | --dec_in 862 \ 19 | --c_out 862 \ 20 | --des 'Exp' \ 21 | --d_model 512 \ 22 | --d_ff 512 \ 23 | --batch_size 16 \ 24 | --learning_rate 0.001 \ 25 | --itr 1 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/traffic/ \ 30 | --data_path traffic.csv \ 31 | --model_id traffic_96_96 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --pred_len 96 \ 37 | --factor 3 \ 38 | --enc_in 862 \ 39 | --dec_in 862 \ 40 | --c_out 862 \ 41 | --des 'Exp' \ 42 | --d_model 512 \ 43 | --d_ff 512 \ 44 | --batch_size 16 \ 45 | --learning_rate 0.001 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/traffic/ \ 51 | --data_path traffic.csv \ 52 | --model_id traffic_192_96 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 192 \ 57 | --pred_len 96 \ 58 | --e_layers 4 \ 59 | --enc_in 862 \ 60 | --dec_in 862 \ 61 | --c_out 862 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.001 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/traffic/ \ 72 | --data_path traffic.csv \ 73 | --model_id traffic_336_96 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 336 \ 78 | --pred_len 96 \ 79 | --e_layers 4 \ 80 | --enc_in 862 \ 81 | --dec_in 862 \ 82 | --c_out 862 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16 \ 87 | --learning_rate 0.001 \ 88 | --itr 1 89 | 90 | python -u run.py \ 91 | --is_training 1 \ 92 | --root_path ./dataset/traffic/ \ 93 | --data_path traffic.csv \ 94 | --model_id traffic_720_96 \ 95 | --model $model_name \ 96 | --data custom \ 97 | --features M \ 98 | --seq_len 720 \ 99 | --pred_len 96 \ 100 | --e_layers 4 \ 101 | --enc_in 862 \ 102 | --dec_in 862 \ 103 | --c_out 862 \ 104 | --des 'Exp' \ 105 | --d_model 512 \ 106 | --d_ff 512 \ 107 | --batch_size 16 \ 108 | --learning_rate 0.001 \ 109 | 
--itr 1 110 | -------------------------------------------------------------------------------- /scripts/model_efficiency/ECL/iFlashTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=Flashformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/electricity/ \ 8 | --data_path electricity.csv \ 9 | --model_id ECL_96_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --label_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --d_layers 1 \ 18 | --factor 3 \ 19 | --enc_in 321 \ 20 | --dec_in 321 \ 21 | --c_out 321 \ 22 | --des 'Exp' \ 23 | --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/electricity/ \ 28 | --data_path electricity.csv \ 29 | --model_id ECL_96_192 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 192 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 321 \ 40 | --dec_in 321 \ 41 | --c_out 321 \ 42 | --des 'Exp' \ 43 | --itr 1 44 | 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path ./dataset/electricity/ \ 48 | --data_path electricity.csv \ 49 | --model_id ECL_96_336 \ 50 | --model $model_name \ 51 | --data custom \ 52 | --features M \ 53 | --seq_len 96 \ 54 | --label_len 48 \ 55 | --pred_len 336 \ 56 | --e_layers 2 \ 57 | --d_layers 1 \ 58 | --factor 3 \ 59 | --enc_in 321 \ 60 | --dec_in 321 \ 61 | --c_out 321 \ 62 | --des 'Exp' \ 63 | --itr 1 64 | 65 | python -u run.py \ 66 | --is_training 1 \ 67 | --root_path ./dataset/electricity/ \ 68 | --data_path electricity.csv \ 69 | --model_id ECL_96_720 \ 70 | --model $model_name \ 71 | --data custom \ 72 | --features M \ 73 | --seq_len 96 \ 74 | --label_len 48 \ 75 | --pred_len 720 \ 76 | --e_layers 2 \ 77 | --d_layers 1 \ 78 | --factor 3 \ 79 | --enc_in 321 \ 80 | --dec_in 321 \ 81 | --c_out 321 \ 82 | --des 'Exp' \ 83 | --itr 1 84 | 85 | model_name=iFlashformer 86 | 87 | python -u run.py \ 88 | --is_training 1 \ 89 | --root_path ./dataset/electricity/ \ 90 | --data_path electricity.csv \ 91 | --model_id ECL_96_96 \ 92 | --model $model_name \ 93 | --data custom \ 94 | --features M \ 95 | --seq_len 96 \ 96 | --label_len 48 \ 97 | --pred_len 96 \ 98 | --e_layers 2 \ 99 | --d_layers 1 \ 100 | --factor 3 \ 101 | --enc_in 321 \ 102 | --dec_in 321 \ 103 | --c_out 321 \ 104 | --des 'Exp' \ 105 | --itr 1 106 | 107 | python -u run.py \ 108 | --is_training 1 \ 109 | --root_path ./dataset/electricity/ \ 110 | --data_path electricity.csv \ 111 | --model_id ECL_96_192 \ 112 | --model $model_name \ 113 | --data custom \ 114 | --features M \ 115 | --seq_len 96 \ 116 | --label_len 48 \ 117 | --pred_len 192 \ 118 | --e_layers 2 \ 119 | --d_layers 1 \ 120 | --factor 3 \ 121 | --enc_in 321 \ 122 | --dec_in 321 \ 123 | --c_out 321 \ 124 | --des 'Exp' \ 125 | --itr 1 126 | 127 | python -u run.py \ 128 | --is_training 1 \ 129 | --root_path ./dataset/electricity/ \ 130 | --data_path electricity.csv \ 131 | --model_id ECL_96_336 \ 132 | --model $model_name \ 133 | --data custom \ 134 | --features M \ 135 | --seq_len 96 \ 136 | --label_len 48 \ 137 | --pred_len 336 \ 138 | --e_layers 2 \ 139 | --d_layers 1 \ 140 | --factor 3 \ 141 | --enc_in 321 \ 142 | --dec_in 321 \ 143 | --c_out 321 \ 144 | --des 'Exp' \ 145 | --itr 1 146 | 147 | python -u run.py \ 148 | --is_training 1 \ 149 | --root_path ./dataset/electricity/ \ 150 | --data_path 
electricity.csv \ 151 | --model_id ECL_96_720 \ 152 | --model $model_name \ 153 | --data custom \ 154 | --features M \ 155 | --seq_len 96 \ 156 | --label_len 48 \ 157 | --pred_len 720 \ 158 | --e_layers 2 \ 159 | --d_layers 1 \ 160 | --factor 3 \ 161 | --enc_in 321 \ 162 | --dec_in 321 \ 163 | --c_out 321 \ 164 | --des 'Exp' \ 165 | --itr 1 166 | -------------------------------------------------------------------------------- /scripts/model_efficiency/README.md: -------------------------------------------------------------------------------- 1 | # Efficiency Improvement of iTransformer 2 | 3 | Suppose the input multivariate time series has a shape of $T \times N$, where $T$ is the number of time steps and $N$ is the number of variates. The vanilla attention module has a complexity of $\mathcal{O}(L^2)$, where $L$ is the number of tokens. 4 | 5 | * In the vanilla Transformer, $L=T$, since time points are treated as tokens. 6 | * In iTransformer, $L=N$, since variates are treated as tokens. 7 | 8 | ## Benefit from Efficient Attention 9 | 10 | Since the attention mechanism is applied along the variate dimension in the inverted structure, efficient attention with reduced complexity essentially addresses the problem of numerous variates, which is ubiquitous in real-world applications. 11 | 12 | We currently try out the linear-complexity attention from [Flowformer](https://github.com/thuml/Flowformer) and the hardware-accelerated attention mechanism from [FlashAttention](https://github.com/shreyansh26/FlashAttention-PyTorch). iTransformer demonstrates a clear efficiency improvement when adopting these novel attention mechanisms. 13 | 14 | ### Scripts 15 | We provide the iTransformers with the FlashAttention module: 16 | 17 | ``` 18 | # iTransformer on the Traffic Dataset with hardware-friendly FlashAttention for acceleration 19 | 20 | bash ./scripts/model_efficiency/Traffic/iFlashTransformer.sh 21 | ``` 22 | 23 | 24 | ## Efficient Training Strategy 25 | Thanks to the input flexibility of attention, the number of tokens can vary between training and inference; **our model is the first one capable of being trained on an arbitrary number of series**. We propose a novel training strategy for high-dimensional multivariate series by taking advantage of the [variate generalization capability](../variate_generalization/README.md). 26 | 27 | Concretely, we randomly choose a subset of the variates in each batch and train the model only on the selected variates. Since the inverted structure makes the number of variate channels flexible, the model trained this way can still predict all the variates at inference time. A minimal sketch of this procedure is given below. 28 | 29 | 30 | ## Results 31 | 32 | **Environments**: The training batch size is fixed at 16 with comparable model hyperparameters. The experiments run on a P100 (16G). We comprehensively compare the training speed, memory footprint, and performance of the following. 33 | 34 |
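As referenced above, here is a minimal sketch of the partial-variate training step. The names, tensor shapes, and the bare `model(x)` call are illustrative placeholders rather than this repository's actual API; see `experiments/exp_long_term_forecasting_partial.py` and the scripts that pass `--exp_name partial_train` for the real implementation.

```python
import torch

def partial_variate_step(model, x, y, criterion, n_keep):
    # x: [batch, seq_len, n_variates], y: [batch, pred_len, n_variates] (illustrative shapes).
    # Randomly keep a subset of variate channels for this batch only.
    keep = torch.randperm(x.shape[-1])[:n_keep]
    # Variates are tokens in the inverted model, so any channel count is accepted.
    pred = model(x[..., keep])
    return criterion(pred, y[..., keep])

# At inference time the same model is fed all variates and predicts all of them.
```

Only the variate dimension changes from batch to batch; the lookback length and the model weights stay the same, which is what makes training on partial variates and predicting all of them possible.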
*(Figures: comparison of training speed, memory footprint, and forecasting performance.)*
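To put the $\mathcal{O}(L^2)$ comparison from the top of this README into numbers for the two datasets discussed below (lookback $T=96$; variate counts taken from the scripts), a quick back-of-the-envelope check:

```python
# L = T for the vanilla Transformer (time points as tokens), L = N for iTransformer (variates as tokens).
T = 96
for name, N in {"Weather": 21, "Traffic": 862}.items():
    print(f"{name}: vanilla attention over {T} tokens (~{T * T} pairs), "
          f"inverted attention over {N} tokens (~{N * N} pairs)")
```

On Weather the inverted attention map is far smaller (441 versus 9216 pairs), while on Traffic it is much larger (about 743,000 pairs), which is exactly where the efficient attention modules and the partial-variate training above pay off.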
37 | 38 | * The efficiency of iTransformer exceeds other Transformers in Weather with 21 variates. In Traffic with 862 variates, the memory footprints are basically the same, but iTransformer can be trained faster. 39 | * iTransformer achieves particularly better performance on the dataset with numerous variates, since the multivariate correlations can be explicitly utilized. 40 | * By adopting an efficient attention module or our proposed efficient training strategy on partial variates, iTransformer can enjoy the same level of speed and memory footprint as linear forecasters. -------------------------------------------------------------------------------- /scripts/model_efficiency/Traffic/iFlashTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=Flashformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/traffic/ \ 8 | --data_path traffic.csv \ 9 | --model_id traffic_96_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --label_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --d_layers 1 \ 18 | --factor 3 \ 19 | --enc_in 862 \ 20 | --dec_in 862 \ 21 | --c_out 862 \ 22 | --des 'Exp' \ 23 | --itr 1 \ 24 | --train_epochs 3 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/traffic/ \ 29 | --data_path traffic.csv \ 30 | --model_id traffic_96_192 \ 31 | --model $model_name \ 32 | --data custom \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --label_len 48 \ 36 | --pred_len 192 \ 37 | --e_layers 2 \ 38 | --d_layers 1 \ 39 | --factor 3 \ 40 | --enc_in 862 \ 41 | --dec_in 862 \ 42 | --c_out 862 \ 43 | --des 'Exp' \ 44 | --itr 1 \ 45 | --train_epochs 3 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/traffic/ \ 50 | --data_path traffic.csv \ 51 | --model_id traffic_96_336 \ 52 | --model $model_name \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --label_len 48 \ 57 | --pred_len 336 \ 58 | --e_layers 2 \ 59 | --d_layers 1 \ 60 | --factor 3 \ 61 | --enc_in 862 \ 62 | --dec_in 862 \ 63 | --c_out 862 \ 64 | --des 'Exp' \ 65 | --itr 1 \ 66 | --train_epochs 3 67 | 68 | python -u run.py \ 69 | --is_training 1 \ 70 | --root_path ./dataset/traffic/ \ 71 | --data_path traffic.csv \ 72 | --model_id traffic_96_720 \ 73 | --model $model_name \ 74 | --data custom \ 75 | --features M \ 76 | --seq_len 96 \ 77 | --label_len 48 \ 78 | --pred_len 720 \ 79 | --e_layers 2 \ 80 | --d_layers 1 \ 81 | --factor 3 \ 82 | --enc_in 862 \ 83 | --dec_in 862 \ 84 | --c_out 862 \ 85 | --des 'Exp' \ 86 | --itr 1 \ 87 | --train_epochs 3 88 | 89 | model_name=iFlashformer 90 | 91 | python -u run.py \ 92 | --is_training 1 \ 93 | --root_path ./dataset/traffic/ \ 94 | --data_path traffic.csv \ 95 | --model_id traffic_96_96 \ 96 | --model $model_name \ 97 | --data custom \ 98 | --features M \ 99 | --seq_len 96 \ 100 | --label_len 48 \ 101 | --pred_len 96 \ 102 | --e_layers 2 \ 103 | --d_layers 1 \ 104 | --factor 3 \ 105 | --enc_in 862 \ 106 | --dec_in 862 \ 107 | --c_out 862 \ 108 | --des 'Exp' \ 109 | --itr 1 \ 110 | --train_epochs 3 111 | 112 | python -u run.py \ 113 | --is_training 1 \ 114 | --root_path ./dataset/traffic/ \ 115 | --data_path traffic.csv \ 116 | --model_id traffic_96_192 \ 117 | --model $model_name \ 118 | --data custom \ 119 | --features M \ 120 | --seq_len 96 \ 121 | --label_len 48 \ 122 | --pred_len 192 \ 123 | --e_layers 2 \ 124 | --d_layers 1 \ 125 | --factor 3 
\ 126 | --enc_in 862 \ 127 | --dec_in 862 \ 128 | --c_out 862 \ 129 | --des 'Exp' \ 130 | --itr 1 \ 131 | --train_epochs 3 132 | 133 | python -u run.py \ 134 | --is_training 1 \ 135 | --root_path ./dataset/traffic/ \ 136 | --data_path traffic.csv \ 137 | --model_id traffic_96_336 \ 138 | --model $model_name \ 139 | --data custom \ 140 | --features M \ 141 | --seq_len 96 \ 142 | --label_len 48 \ 143 | --pred_len 336 \ 144 | --e_layers 2 \ 145 | --d_layers 1 \ 146 | --factor 3 \ 147 | --enc_in 862 \ 148 | --dec_in 862 \ 149 | --c_out 862 \ 150 | --des 'Exp' \ 151 | --itr 1 \ 152 | --train_epochs 3 153 | 154 | python -u run.py \ 155 | --is_training 1 \ 156 | --root_path ./dataset/traffic/ \ 157 | --data_path traffic.csv \ 158 | --model_id traffic_96_720 \ 159 | --model $model_name \ 160 | --data custom \ 161 | --features M \ 162 | --seq_len 96 \ 163 | --label_len 48 \ 164 | --pred_len 720 \ 165 | --e_layers 2 \ 166 | --d_layers 1 \ 167 | --factor 3 \ 168 | --enc_in 862 \ 169 | --dec_in 862 \ 170 | --c_out 862 \ 171 | --des 'Exp' \ 172 | --itr 1 \ 173 | --train_epochs 3 174 | -------------------------------------------------------------------------------- /scripts/model_efficiency/Weather/iFlashTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=Flashformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/weather/ \ 8 | --data_path weather.csv \ 9 | --model_id weather_96_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --label_len 48 \ 15 | --pred_len 96 \ 16 | --e_layers 2 \ 17 | --d_layers 1 \ 18 | --factor 3 \ 19 | --enc_in 21 \ 20 | --dec_in 21 \ 21 | --c_out 21 \ 22 | --des 'Exp' \ 23 | --itr 1 \ 24 | --batch_size 128 \ 25 | --train_epochs 3 26 | 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/weather/ \ 30 | --data_path weather.csv \ 31 | --model_id weather_96_192 \ 32 | --model $model_name \ 33 | --data custom \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --label_len 48 \ 37 | --pred_len 192 \ 38 | --e_layers 2 \ 39 | --d_layers 1 \ 40 | --factor 3 \ 41 | --enc_in 21 \ 42 | --dec_in 21 \ 43 | --c_out 21 \ 44 | --des 'Exp' \ 45 | --batch_size 128 \ 46 | --itr 1 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/weather/ \ 51 | --data_path weather.csv \ 52 | --model_id weather_96_336 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 96 \ 57 | --label_len 48 \ 58 | --pred_len 336 \ 59 | --e_layers 2 \ 60 | --d_layers 1 \ 61 | --factor 3 \ 62 | --enc_in 21 \ 63 | --dec_in 21 \ 64 | --c_out 21 \ 65 | --des 'Exp' \ 66 | --batch_size 128 \ 67 | --itr 1 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/weather/ \ 72 | --data_path weather.csv \ 73 | --model_id weather_96_720 \ 74 | --model $model_name \ 75 | --data custom \ 76 | --features M \ 77 | --seq_len 96 \ 78 | --label_len 48 \ 79 | --pred_len 720 \ 80 | --e_layers 2 \ 81 | --d_layers 1 \ 82 | --factor 3 \ 83 | --enc_in 21 \ 84 | --dec_in 21 \ 85 | --c_out 21 \ 86 | --des 'Exp' \ 87 | --batch_size 128 \ 88 | --itr 1 89 | 90 | model_name=iFlashformer 91 | 92 | python -u run.py \ 93 | --is_training 1 \ 94 | --root_path ./dataset/weather/ \ 95 | --data_path weather.csv \ 96 | --model_id weather_96_96 \ 97 | --model $model_name \ 98 | --data custom \ 99 | --features M \ 100 | --seq_len 96 \ 101 | --label_len 48 \ 102 | --pred_len 96 \ 103 | --e_layers 2 \ 
104 | --d_layers 1 \ 105 | --factor 3 \ 106 | --enc_in 21 \ 107 | --dec_in 21 \ 108 | --c_out 21 \ 109 | --des 'Exp' \ 110 | --batch_size 128 \ 111 | --itr 1 \ 112 | --train_epochs 3 113 | 114 | python -u run.py \ 115 | --is_training 1 \ 116 | --root_path ./dataset/weather/ \ 117 | --data_path weather.csv \ 118 | --model_id weather_96_192 \ 119 | --model $model_name \ 120 | --data custom \ 121 | --features M \ 122 | --seq_len 96 \ 123 | --label_len 48 \ 124 | --pred_len 192 \ 125 | --e_layers 2 \ 126 | --d_layers 1 \ 127 | --factor 3 \ 128 | --enc_in 21 \ 129 | --dec_in 21 \ 130 | --c_out 21 \ 131 | --des 'Exp' \ 132 | --batch_size 128 \ 133 | --itr 1 134 | 135 | python -u run.py \ 136 | --is_training 1 \ 137 | --root_path ./dataset/weather/ \ 138 | --data_path weather.csv \ 139 | --model_id weather_96_336 \ 140 | --model $model_name \ 141 | --data custom \ 142 | --features M \ 143 | --seq_len 96 \ 144 | --label_len 48 \ 145 | --pred_len 336 \ 146 | --e_layers 2 \ 147 | --d_layers 1 \ 148 | --factor 3 \ 149 | --enc_in 21 \ 150 | --dec_in 21 \ 151 | --c_out 21 \ 152 | --des 'Exp' \ 153 | --batch_size 128 \ 154 | --itr 1 155 | 156 | python -u run.py \ 157 | --is_training 1 \ 158 | --root_path ./dataset/weather/ \ 159 | --data_path weather.csv \ 160 | --model_id weather_96_720 \ 161 | --model $model_name \ 162 | --data custom \ 163 | --features M \ 164 | --seq_len 96 \ 165 | --label_len 48 \ 166 | --pred_len 720 \ 167 | --e_layers 2 \ 168 | --d_layers 1 \ 169 | --factor 3 \ 170 | --enc_in 21 \ 171 | --dec_in 21 \ 172 | --c_out 21 \ 173 | --des 'Exp' \ 174 | --batch_size 128 \ 175 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ECL/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/electricity/ \ 8 | --data_path electricity.csv \ 9 | --model_id ECL_96_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 3 \ 16 | --enc_in 321 \ 17 | --dec_in 321 \ 18 | --c_out 321 \ 19 | --des 'Exp' \ 20 | --d_model 512 \ 21 | --d_ff 512 \ 22 | --batch_size 16 \ 23 | --learning_rate 0.0005 \ 24 | --itr 1 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/electricity/ \ 29 | --data_path electricity.csv \ 30 | --model_id ECL_96_192 \ 31 | --model $model_name \ 32 | --data custom \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --pred_len 192 \ 36 | --e_layers 3 \ 37 | --enc_in 321 \ 38 | --dec_in 321 \ 39 | --c_out 321 \ 40 | --des 'Exp' \ 41 | --d_model 512 \ 42 | --d_ff 512 \ 43 | --batch_size 16 \ 44 | --learning_rate 0.0005 \ 45 | --itr 1 46 | 47 | 48 | python -u run.py \ 49 | --is_training 1 \ 50 | --root_path ./dataset/electricity/ \ 51 | --data_path electricity.csv \ 52 | --model_id ECL_96_336 \ 53 | --model $model_name \ 54 | --data custom \ 55 | --features M \ 56 | --seq_len 96 \ 57 | --pred_len 336 \ 58 | --e_layers 3 \ 59 | --enc_in 321 \ 60 | --dec_in 321 \ 61 | --c_out 321 \ 62 | --des 'Exp' \ 63 | --d_model 512 \ 64 | --d_ff 512 \ 65 | --batch_size 16 \ 66 | --learning_rate 0.0005 \ 67 | --itr 1 68 | 69 | 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path ./dataset/electricity/ \ 73 | --data_path electricity.csv \ 74 | --model_id ECL_96_720 \ 75 | --model $model_name \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 
96 \ 79 | --pred_len 720 \ 80 | --e_layers 3 \ 81 | --enc_in 321 \ 82 | --dec_in 321 \ 83 | --c_out 321 \ 84 | --des 'Exp' \ 85 | --d_model 512 \ 86 | --d_ff 512 \ 87 | --batch_size 16 \ 88 | --learning_rate 0.0005 \ 89 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/iTransformer_ETTh1.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/ETT-small/ \ 8 | --data_path ETTh1.csv \ 9 | --model_id ETTh1_96_96 \ 10 | --model $model_name \ 11 | --data ETTh1 \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 2 \ 16 | --enc_in 7 \ 17 | --dec_in 7 \ 18 | --c_out 7 \ 19 | --des 'Exp' \ 20 | --d_model 256 \ 21 | --d_ff 256 \ 22 | --itr 1 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/ETT-small/ \ 27 | --data_path ETTh1.csv \ 28 | --model_id ETTh1_96_192 \ 29 | --model $model_name \ 30 | --data ETTh1 \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 7 \ 36 | --dec_in 7 \ 37 | --c_out 7 \ 38 | --des 'Exp' \ 39 | --d_model 256 \ 40 | --d_ff 256 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/ETT-small/ \ 46 | --data_path ETTh1.csv \ 47 | --model_id ETTh1_96_336 \ 48 | --model $model_name \ 49 | --data ETTh1 \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --des 'Exp' \ 58 | --d_model 512 \ 59 | --d_ff 512 \ 60 | --itr 1 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --root_path ./dataset/ETT-small/ \ 65 | --data_path ETTh1.csv \ 66 | --model_id ETTh1_96_720 \ 67 | --model $model_name \ 68 | --data ETTh1 \ 69 | --features M \ 70 | --seq_len 96 \ 71 | --pred_len 720 \ 72 | --e_layers 2 \ 73 | --enc_in 7 \ 74 | --dec_in 7 \ 75 | --c_out 7 \ 76 | --des 'Exp' \ 77 | --d_model 512 \ 78 | --d_ff 512 \ 79 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/iTransformer_ETTh2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/ETT-small/ \ 8 | --data_path ETTh2.csv \ 9 | --model_id ETTh2_96_96 \ 10 | --model $model_name \ 11 | --data ETTh2 \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 2 \ 16 | --enc_in 7 \ 17 | --dec_in 7 \ 18 | --c_out 7 \ 19 | --des 'Exp' \ 20 | --d_model 128 \ 21 | --d_ff 128 \ 22 | --itr 1 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/ETT-small/ \ 27 | --data_path ETTh2.csv \ 28 | --model_id ETTh2_96_192 \ 29 | --model $model_name \ 30 | --data ETTh2 \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 7 \ 36 | --dec_in 7 \ 37 | --c_out 7 \ 38 | --des 'Exp' \ 39 | --d_model 128 \ 40 | --d_ff 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/ETT-small/ \ 46 | --data_path ETTh2.csv \ 47 | --model_id ETTh2_96_336 \ 48 | --model $model_name \ 49 | --data ETTh2 \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 
\ 57 | --des 'Exp' \ 58 | --d_model 128 \ 59 | --d_ff 128 \ 60 | --itr 1 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --root_path ./dataset/ETT-small/ \ 65 | --data_path ETTh2.csv \ 66 | --model_id ETTh2_96_720 \ 67 | --model $model_name \ 68 | --data ETTh2 \ 69 | --features M \ 70 | --seq_len 96 \ 71 | --pred_len 720 \ 72 | --e_layers 2 \ 73 | --enc_in 7 \ 74 | --dec_in 7 \ 75 | --c_out 7 \ 76 | --des 'Exp' \ 77 | --d_model 128 \ 78 | --d_ff 128 \ 79 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/iTransformer_ETTm1.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/ETT-small/ \ 8 | --data_path ETTm1.csv \ 9 | --model_id ETTm1_96_96 \ 10 | --model $model_name \ 11 | --data ETTm1 \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 2 \ 16 | --enc_in 7 \ 17 | --dec_in 7 \ 18 | --c_out 7 \ 19 | --des 'Exp' \ 20 | --d_model 128 \ 21 | --d_ff 128 \ 22 | --itr 1 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/ETT-small/ \ 27 | --data_path ETTm1.csv \ 28 | --model_id ETTm1_96_192 \ 29 | --model $model_name \ 30 | --data ETTm1 \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 7 \ 36 | --dec_in 7 \ 37 | --c_out 7 \ 38 | --des 'Exp' \ 39 | --d_model 128 \ 40 | --d_ff 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/ETT-small/ \ 46 | --data_path ETTm1.csv \ 47 | --model_id ETTm1_96_336 \ 48 | --model $model_name \ 49 | --data ETTm1 \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --des 'Exp' \ 58 | --d_model 128 \ 59 | --d_ff 128 \ 60 | --itr 1 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --root_path ./dataset/ETT-small/ \ 65 | --data_path ETTm1.csv \ 66 | --model_id ETTm1_96_720 \ 67 | --model $model_name \ 68 | --data ETTm1 \ 69 | --features M \ 70 | --seq_len 96 \ 71 | --pred_len 720 \ 72 | --e_layers 2 \ 73 | --enc_in 7 \ 74 | --dec_in 7 \ 75 | --c_out 7 \ 76 | --des 'Exp' \ 77 | --d_model 128 \ 78 | --d_ff 128 \ 79 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/ETT/iTransformer_ETTm2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/ETT-small/ \ 8 | --data_path ETTm2.csv \ 9 | --model_id ETTm2_96_96 \ 10 | --model $model_name \ 11 | --data ETTm2 \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 2 \ 16 | --enc_in 7 \ 17 | --dec_in 7 \ 18 | --c_out 7 \ 19 | --des 'Exp' \ 20 | --d_model 128 \ 21 | --d_ff 128 \ 22 | --itr 1 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/ETT-small/ \ 27 | --data_path ETTm2.csv \ 28 | --model_id ETTm2_96_192 \ 29 | --model $model_name \ 30 | --data ETTm2 \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 7 \ 36 | --dec_in 7 \ 37 | --c_out 7 \ 38 | --des 'Exp' \ 39 | --d_model 128 \ 40 | --d_ff 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path 
./dataset/ETT-small/ \ 46 | --data_path ETTm2.csv \ 47 | --model_id ETTm2_96_336 \ 48 | --model $model_name \ 49 | --data ETTm2 \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 7 \ 55 | --dec_in 7 \ 56 | --c_out 7 \ 57 | --des 'Exp' \ 58 | --d_model 128 \ 59 | --d_ff 128 \ 60 | --itr 1 61 | 62 | python -u run.py \ 63 | --is_training 1 \ 64 | --root_path ./dataset/ETT-small/ \ 65 | --data_path ETTm2.csv \ 66 | --model_id ETTm2_96_720 \ 67 | --model $model_name \ 68 | --data ETTm2 \ 69 | --features M \ 70 | --seq_len 96 \ 71 | --pred_len 720 \ 72 | --e_layers 2 \ 73 | --enc_in 7 \ 74 | --dec_in 7 \ 75 | --c_out 7 \ 76 | --des 'Exp' \ 77 | --d_model 128 \ 78 | --d_ff 128 \ 79 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/Exchange/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/exchange_rate/ \ 8 | --data_path exchange_rate.csv \ 9 | --model_id Exchange_96_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 2 \ 16 | --enc_in 8 \ 17 | --dec_in 8 \ 18 | --c_out 8 \ 19 | --des 'Exp' \ 20 | --d_model 128 \ 21 | --d_ff 128 \ 22 | --itr 1 23 | 24 | python -u run.py \ 25 | --is_training 1 \ 26 | --root_path ./dataset/exchange_rate/ \ 27 | --data_path exchange_rate.csv \ 28 | --model_id Exchange_96_192 \ 29 | --model $model_name \ 30 | --data custom \ 31 | --features M \ 32 | --seq_len 96 \ 33 | --pred_len 192 \ 34 | --e_layers 2 \ 35 | --enc_in 8 \ 36 | --dec_in 8 \ 37 | --c_out 8 \ 38 | --des 'Exp' \ 39 | --d_model 128 \ 40 | --d_ff 128 \ 41 | --itr 1 42 | 43 | python -u run.py \ 44 | --is_training 1 \ 45 | --root_path ./dataset/exchange_rate/ \ 46 | --data_path exchange_rate.csv \ 47 | --model_id Exchange_96_336 \ 48 | --model $model_name \ 49 | --data custom \ 50 | --features M \ 51 | --seq_len 96 \ 52 | --pred_len 336 \ 53 | --e_layers 2 \ 54 | --enc_in 8 \ 55 | --dec_in 8 \ 56 | --c_out 8 \ 57 | --des 'Exp' \ 58 | --itr 1 \ 59 | --d_model 128 \ 60 | --d_ff 128 \ 61 | --train_epochs 1 62 | 63 | python -u run.py \ 64 | --is_training 1 \ 65 | --root_path ./dataset/exchange_rate/ \ 66 | --data_path exchange_rate.csv \ 67 | --model_id Exchange_96_720 \ 68 | --model $model_name \ 69 | --data custom \ 70 | --features M \ 71 | --seq_len 96 \ 72 | --pred_len 720 \ 73 | --e_layers 2 \ 74 | --enc_in 8 \ 75 | --dec_in 8 \ 76 | --c_out 8 \ 77 | --des 'Exp' \ 78 | --d_model 128 \ 79 | --d_ff 128 \ 80 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/iTransformer_03.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/PEMS/ \ 8 | --data_path PEMS03.npz \ 9 | --model_id PEMS03_96_12 \ 10 | --model $model_name \ 11 | --data PEMS \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 12 \ 15 | --e_layers 4 \ 16 | --enc_in 358 \ 17 | --dec_in 358 \ 18 | --c_out 358 \ 19 | --des 'Exp' \ 20 | --d_model 512 \ 21 | --d_ff 512 \ 22 | --learning_rate 0.001 \ 23 | --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/PEMS/ \ 28 | --data_path 
PEMS03.npz \ 29 | --model_id PEMS03_96_24 \ 30 | --model $model_name \ 31 | --data PEMS \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 24 \ 35 | --e_layers 4 \ 36 | --enc_in 358 \ 37 | --dec_in 358 \ 38 | --c_out 358 \ 39 | --des 'Exp' \ 40 | --d_model 512 \ 41 | --d_ff 512 \ 42 | --learning_rate 0.001 \ 43 | --itr 1 44 | 45 | 46 | python -u run.py \ 47 | --is_training 1 \ 48 | --root_path ./dataset/PEMS/ \ 49 | --data_path PEMS03.npz \ 50 | --model_id PEMS03_96_48 \ 51 | --model $model_name \ 52 | --data PEMS \ 53 | --features M \ 54 | --seq_len 96 \ 55 | --pred_len 48 \ 56 | --e_layers 4 \ 57 | --enc_in 358 \ 58 | --dec_in 358 \ 59 | --c_out 358 \ 60 | --des 'Exp' \ 61 | --d_model 512 \ 62 | --d_ff 512 \ 63 | --learning_rate 0.001 \ 64 | --itr 1 65 | 66 | 67 | python -u run.py \ 68 | --is_training 1 \ 69 | --root_path ./dataset/PEMS/ \ 70 | --data_path PEMS03.npz \ 71 | --model_id PEMS03_96_96 \ 72 | --model $model_name \ 73 | --data PEMS \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --pred_len 96 \ 77 | --e_layers 4 \ 78 | --enc_in 358 \ 79 | --dec_in 358 \ 80 | --c_out 358 \ 81 | --des 'Exp' \ 82 | --d_model 512 \ 83 | --d_ff 512 \ 84 | --learning_rate 0.001 \ 85 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/iTransformer_04.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/PEMS/ \ 8 | --data_path PEMS04.npz \ 9 | --model_id PEMS04_96_12 \ 10 | --model $model_name \ 11 | --data PEMS \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 12 \ 15 | --e_layers 4 \ 16 | --enc_in 307 \ 17 | --dec_in 307 \ 18 | --c_out 307 \ 19 | --des 'Exp' \ 20 | --d_model 1024 \ 21 | --d_ff 1024 \ 22 | --learning_rate 0.0005 \ 23 | --itr 1 \ 24 | --use_norm 0 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/PEMS/ \ 29 | --data_path PEMS04.npz \ 30 | --model_id PEMS04_96_24 \ 31 | --model $model_name \ 32 | --data PEMS \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --pred_len 24 \ 36 | --e_layers 4 \ 37 | --enc_in 307 \ 38 | --dec_in 307 \ 39 | --c_out 307 \ 40 | --des 'Exp' \ 41 | --d_model 1024 \ 42 | --d_ff 1024 \ 43 | --learning_rate 0.0005 \ 44 | --itr 1 \ 45 | --use_norm 0 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/PEMS/ \ 50 | --data_path PEMS04.npz \ 51 | --model_id PEMS04_96_48 \ 52 | --model $model_name \ 53 | --data PEMS \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --pred_len 48 \ 57 | --e_layers 4 \ 58 | --enc_in 307 \ 59 | --dec_in 307 \ 60 | --c_out 307 \ 61 | --des 'Exp' \ 62 | --d_model 1024 \ 63 | --d_ff 1024 \ 64 | --learning_rate 0.0005 \ 65 | --itr 1 \ 66 | --use_norm 0 67 | 68 | python -u run.py \ 69 | --is_training 1 \ 70 | --root_path ./dataset/PEMS/ \ 71 | --data_path PEMS04.npz \ 72 | --model_id PEMS04_96_96 \ 73 | --model $model_name \ 74 | --data PEMS \ 75 | --features M \ 76 | --seq_len 96 \ 77 | --pred_len 96 \ 78 | --e_layers 4 \ 79 | --enc_in 307 \ 80 | --dec_in 307 \ 81 | --c_out 307 \ 82 | --des 'Exp' \ 83 | --d_model 1024 \ 84 | --d_ff 1024 \ 85 | --learning_rate 0.0005 \ 86 | --itr 1 \ 87 | --use_norm 0 88 | -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/iTransformer_07.sh: -------------------------------------------------------------------------------- 1 | export 
CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/PEMS/ \ 8 | --data_path PEMS07.npz \ 9 | --model_id PEMS07_96_12 \ 10 | --model $model_name \ 11 | --data PEMS \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 12 \ 15 | --e_layers 2 \ 16 | --enc_in 883 \ 17 | --dec_in 883 \ 18 | --c_out 883 \ 19 | --des 'Exp' \ 20 | --d_model 512 \ 21 | --d_ff 512 \ 22 | --learning_rate 0.001 \ 23 | --itr 1 \ 24 | --use_norm 0 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/PEMS/ \ 29 | --data_path PEMS07.npz \ 30 | --model_id PEMS07_96_24 \ 31 | --model $model_name \ 32 | --data PEMS \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --pred_len 24 \ 36 | --e_layers 2 \ 37 | --enc_in 883 \ 38 | --dec_in 883 \ 39 | --c_out 883 \ 40 | --des 'Exp' \ 41 | --d_model 512 \ 42 | --d_ff 512 \ 43 | --learning_rate 0.001 \ 44 | --itr 1 \ 45 | --use_norm 0 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/PEMS/ \ 50 | --data_path PEMS07.npz \ 51 | --model_id PEMS07_96_48 \ 52 | --model $model_name \ 53 | --data PEMS \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --pred_len 48 \ 57 | --e_layers 4 \ 58 | --enc_in 883 \ 59 | --dec_in 883 \ 60 | --c_out 883 \ 61 | --des 'Exp' \ 62 | --d_model 512 \ 63 | --d_ff 512 \ 64 | --batch_size 16\ 65 | --learning_rate 0.001 \ 66 | --itr 1 \ 67 | --use_norm 0 68 | 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path ./dataset/PEMS/ \ 72 | --data_path PEMS07.npz \ 73 | --model_id PEMS07_96_96 \ 74 | --model $model_name \ 75 | --data PEMS \ 76 | --features M \ 77 | --seq_len 96 \ 78 | --pred_len 96 \ 79 | --e_layers 4 \ 80 | --enc_in 883 \ 81 | --dec_in 883 \ 82 | --c_out 883 \ 83 | --des 'Exp' \ 84 | --d_model 512 \ 85 | --d_ff 512 \ 86 | --batch_size 16\ 87 | --learning_rate 0.001 \ 88 | --itr 1 \ 89 | --use_norm 0 90 | -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/PEMS/iTransformer_08.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/PEMS/ \ 8 | --data_path PEMS08.npz \ 9 | --model_id PEMS08_96_12 \ 10 | --model $model_name \ 11 | --data PEMS \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 12 \ 15 | --e_layers 2 \ 16 | --enc_in 170 \ 17 | --dec_in 170 \ 18 | --c_out 170 \ 19 | --des 'Exp' \ 20 | --d_model 512 \ 21 | --d_ff 512 \ 22 | --itr 1 \ 23 | --use_norm 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/PEMS/ \ 28 | --data_path PEMS08.npz \ 29 | --model_id PEMS08_96_24 \ 30 | --model $model_name \ 31 | --data PEMS \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 24 \ 35 | --e_layers 2 \ 36 | --enc_in 170 \ 37 | --dec_in 170 \ 38 | --c_out 170 \ 39 | --des 'Exp' \ 40 | --d_model 512 \ 41 | --d_ff 512 \ 42 | --itr 1 \ 43 | --use_norm 1 44 | 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path ./dataset/PEMS/ \ 48 | --data_path PEMS08.npz \ 49 | --model_id PEMS08_96_48 \ 50 | --model $model_name \ 51 | --data PEMS \ 52 | --features M \ 53 | --seq_len 96 \ 54 | --pred_len 48 \ 55 | --e_layers 4 \ 56 | --enc_in 170 \ 57 | --dec_in 170 \ 58 | --c_out 170 \ 59 | --des 'Exp' \ 60 | --d_model 512 \ 61 | --d_ff 512 \ 62 | --batch_size 16\ 63 | --learning_rate 0.001 \ 64 | --itr 1 \ 65 | --use_norm 0 66 | 67 | python -u run.py \ 68 | 
--is_training 1 \ 69 | --root_path ./dataset/PEMS/ \ 70 | --data_path PEMS08.npz \ 71 | --model_id PEMS08_96_96 \ 72 | --model $model_name \ 73 | --data PEMS \ 74 | --features M \ 75 | --seq_len 96 \ 76 | --pred_len 96 \ 77 | --e_layers 4 \ 78 | --enc_in 170 \ 79 | --dec_in 170 \ 80 | --c_out 170 \ 81 | --des 'Exp' \ 82 | --d_model 512 \ 83 | --d_ff 512 \ 84 | --batch_size 16\ 85 | --learning_rate 0.001 \ 86 | --itr 1 \ 87 | --use_norm 0 88 | -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/README.md: -------------------------------------------------------------------------------- 1 | # iTransformer for Multivariate Time Series Forecasting 2 | 3 | This folder contains the scripts to reproduce the iTransformer results for Multivariate Time Series Forecasting (MTSF). 4 | 5 | ## Dataset 6 | 7 | An extensive set of challenging multivariate forecasting tasks is used as the benchmark. We provide the download links: [Google Drive](https://drive.google.com/file/d/1l51QsKvQPcqILT3DwfjCgx8Dsg2rpjot/view?usp=drive_link) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/2ea5ca3d621e4e5ba36a/). 8 | 9 |
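The scripts in this folder (and the rest of the repository) expect the downloaded archive to be unpacked into a local `./dataset/` folder. The layout below is inferred from the `--root_path` and `--data_path` arguments used in the scripts, not from an official listing:

```
dataset/
├── ETT-small/        # ETTh1.csv, ETTh2.csv, ETTm1.csv, ETTm2.csv
├── PEMS/             # PEMS03.npz, PEMS04.npz, PEMS07.npz, PEMS08.npz
├── Solar/            # solar_AL.txt
├── electricity/      # electricity.csv
├── exchange_rate/    # exchange_rate.csv
├── traffic/          # traffic.csv
└── weather/          # weather.csv
```

Each script's `--root_path` points at one of these subfolders and `--data_path` at the file inside it.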
*(Table: overview of the benchmark datasets.)*
12 | 13 | ## Scripts 14 | 15 | In each folder named after the dataset, we provide the iTransformer experiments under four different prediction lengths as shown in the table above. 16 | 17 | ``` 18 | # iTransformer on the Traffic Dataset 19 | 20 | bash ./scripts/multivariate_forecasting/Traffic/iTransformer.sh 21 | ``` 22 | 23 | To evaluate the model under other input/prediction lengths, feel free to change the ```seq_len``` and ```pred_len``` arguments: 24 | 25 | ``` 26 | # iTransformer on the Electricity Dataset, where 180 time steps are inputted as the observations, and the task is to predict the future 60 steps 27 | 28 | python -u run.py \ 29 | --is_training 1 \ 30 | --root_path ./dataset/electricity/ \ 31 | --data_path electricity.csv \ 32 | --model_id ECL_180_60 \ 33 | --model $model_name \ 34 | --data custom \ 35 | --features M \ 36 | --seq_len 180 \ 37 | --pred_len 60 \ 38 | --e_layers 3 \ 39 | --enc_in 321 \ 40 | --dec_in 321 \ 41 | --c_out 321 \ 42 | --des 'Exp' \ 43 | --d_model 512 \ 44 | --d_ff 512 \ 45 | --batch_size 16 \ 46 | --learning_rate 0.0005 \ 47 | --itr 1 48 | ``` 49 | 50 | 51 | ## Training on Custom Dataset 52 | 53 | To train with your own time series dataset, you can try out the following steps: 54 | 55 | 1. Read through the ```Dataset_Custom``` class under the ```data_provider/data_loader``` folder, which provides the functionality to load and process time series files. 56 | 2. The file should be ```csv``` format with the first column containing the timestamp and the following columns containing the variates of time series. 57 | 3. Set ```data=custom``` and modify the ```enc_in```, ```dec_in```, ```c_out``` arguments according to your number of variates in the training script. 58 | -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/SolarEnergy/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/Solar/ \ 8 | --data_path solar_AL.txt \ 9 | --model_id solar_96_96 \ 10 | --model $model_name \ 11 | --data Solar \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 2 \ 16 | --enc_in 137 \ 17 | --dec_in 137 \ 18 | --c_out 137 \ 19 | --des 'Exp' \ 20 | --d_model 512 \ 21 | --d_ff 512 \ 22 | --learning_rate 0.0005 \ 23 | --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/Solar/ \ 28 | --data_path solar_AL.txt \ 29 | --model_id solar_96_192 \ 30 | --model $model_name \ 31 | --data Solar \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 192 \ 35 | --e_layers 2 \ 36 | --enc_in 137 \ 37 | --dec_in 137 \ 38 | --c_out 137 \ 39 | --des 'Exp' \ 40 | --d_model 512 \ 41 | --d_ff 512 \ 42 | --learning_rate 0.0005 \ 43 | --itr 1 44 | 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path ./dataset/Solar/ \ 48 | --data_path solar_AL.txt \ 49 | --model_id solar_96_336 \ 50 | --model $model_name \ 51 | --data Solar \ 52 | --features M \ 53 | --seq_len 96 \ 54 | --pred_len 336 \ 55 | --e_layers 2 \ 56 | --enc_in 137 \ 57 | --dec_in 137 \ 58 | --c_out 137 \ 59 | --des 'Exp' \ 60 | --d_model 512 \ 61 | --d_ff 512 \ 62 | --learning_rate 0.0005 \ 63 | --itr 1 64 | 65 | python -u run.py \ 66 | --is_training 1 \ 67 | --root_path ./dataset/Solar/ \ 68 | --data_path solar_AL.txt \ 69 | --model_id solar_96_720 \ 70 | --model $model_name \ 71 | --data Solar \ 72 | --features M 
\ 73 | --seq_len 96 \ 74 | --pred_len 720 \ 75 | --e_layers 2 \ 76 | --enc_in 137 \ 77 | --dec_in 137 \ 78 | --c_out 137 \ 79 | --des 'Exp' \ 80 | --d_model 512 \ 81 | --d_ff 512 \ 82 | --learning_rate 0.0005 \ 83 | --itr 1 84 | -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/Traffic/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/traffic/ \ 8 | --data_path traffic.csv \ 9 | --model_id traffic_96_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 4 \ 16 | --enc_in 862 \ 17 | --dec_in 862 \ 18 | --c_out 862 \ 19 | --des 'Exp' \ 20 | --d_model 512\ 21 | --d_ff 512 \ 22 | --batch_size 16 \ 23 | --learning_rate 0.001 \ 24 | --itr 1 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/traffic/ \ 29 | --data_path traffic.csv \ 30 | --model_id traffic_96_192 \ 31 | --model $model_name \ 32 | --data custom \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --pred_len 192 \ 36 | --e_layers 4 \ 37 | --enc_in 862 \ 38 | --dec_in 862 \ 39 | --c_out 862 \ 40 | --des 'Exp' \ 41 | --d_model 512 \ 42 | --d_ff 512 \ 43 | --batch_size 16 \ 44 | --learning_rate 0.001 \ 45 | --itr 1 46 | 47 | python -u run.py \ 48 | --is_training 1 \ 49 | --root_path ./dataset/traffic/ \ 50 | --data_path traffic.csv \ 51 | --model_id traffic_96_336 \ 52 | --model $model_name \ 53 | --data custom \ 54 | --features M \ 55 | --seq_len 96 \ 56 | --pred_len 336 \ 57 | --e_layers 4 \ 58 | --enc_in 862 \ 59 | --dec_in 862 \ 60 | --c_out 862 \ 61 | --des 'Exp' \ 62 | --d_model 512\ 63 | --d_ff 512 \ 64 | --batch_size 16 \ 65 | --learning_rate 0.001 \ 66 | --itr 1 67 | 68 | python -u run.py \ 69 | --is_training 1 \ 70 | --root_path ./dataset/traffic/ \ 71 | --data_path traffic.csv \ 72 | --model_id traffic_96_720 \ 73 | --model $model_name \ 74 | --data custom \ 75 | --features M \ 76 | --seq_len 96 \ 77 | --pred_len 720 \ 78 | --e_layers 4 \ 79 | --enc_in 862 \ 80 | --dec_in 862 \ 81 | --c_out 862 \ 82 | --des 'Exp' \ 83 | --d_model 512 \ 84 | --d_ff 512 \ 85 | --batch_size 16 \ 86 | --learning_rate 0.001\ 87 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate_forecasting/Weather/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=iTransformer 4 | 5 | python -u run.py \ 6 | --is_training 1 \ 7 | --root_path ./dataset/weather/ \ 8 | --data_path weather.csv \ 9 | --model_id weather_96_96 \ 10 | --model $model_name \ 11 | --data custom \ 12 | --features M \ 13 | --seq_len 96 \ 14 | --pred_len 96 \ 15 | --e_layers 3 \ 16 | --enc_in 21 \ 17 | --dec_in 21 \ 18 | --c_out 21 \ 19 | --des 'Exp' \ 20 | --d_model 512\ 21 | --d_ff 512\ 22 | --itr 1 23 | 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/weather/ \ 28 | --data_path weather.csv \ 29 | --model_id weather_96_192 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --pred_len 192 \ 35 | --e_layers 3 \ 36 | --enc_in 21 \ 37 | --dec_in 21 \ 38 | --c_out 21 \ 39 | --des 'Exp' \ 40 | --d_model 512\ 41 | --d_ff 512\ 42 | --itr 1 43 | 44 | 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path 
./dataset/weather/ \ 48 | --data_path weather.csv \ 49 | --model_id weather_96_336 \ 50 | --model $model_name \ 51 | --data custom \ 52 | --features M \ 53 | --seq_len 96 \ 54 | --pred_len 336 \ 55 | --e_layers 3 \ 56 | --enc_in 21 \ 57 | --dec_in 21 \ 58 | --c_out 21 \ 59 | --des 'Exp' \ 60 | --d_model 512\ 61 | --d_ff 512\ 62 | --itr 1 63 | 64 | 65 | python -u run.py \ 66 | --is_training 1 \ 67 | --root_path ./dataset/weather/ \ 68 | --data_path weather.csv \ 69 | --model_id weather_96_720 \ 70 | --model $model_name \ 71 | --data custom \ 72 | --features M \ 73 | --seq_len 96 \ 74 | --pred_len 720 \ 75 | --e_layers 3 \ 76 | --enc_in 21 \ 77 | --dec_in 21 \ 78 | --c_out 21 \ 79 | --des 'Exp' \ 80 | --d_model 512\ 81 | --d_ff 512\ 82 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/ECL/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=Flowformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/electricity/ \ 8 | # --data_path electricity.csv \ 9 | # --model_id ECL_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 321 \ 20 | # --dec_in 321 \ 21 | # --c_out 321 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/electricity/ \ 28 | --data_path electricity.csv \ 29 | --model_id ECL_96_96 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 96 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 64 \ 40 | --dec_in 64 \ 41 | --c_out 64 \ 42 | --des 'Exp' \ 43 | --channel_independence true \ 44 | --exp_name partial_train \ 45 | --batch_size 8 \ 46 | --d_model 32 \ 47 | --d_ff 64 \ 48 | --itr 1 49 | 50 | model_name=iFlowformer 51 | 52 | #python -u run.py \ 53 | ## --is_training 1 \ 54 | # --root_path ./dataset/electricity/ \ 55 | # --data_path electricity.csv \ 56 | # --model_id ECL_96_96 \ 57 | # --model $model_name \ 58 | # --data custom \ 59 | # --features M \ 60 | # --seq_len 96 \ 61 | # --label_len 48 \ 62 | # --pred_len 96 \ 63 | # --e_layers 2 \ 64 | # --d_layers 1 \ 65 | # --factor 3 \ 66 | # --enc_in 321 \ 67 | # --dec_in 321 \ 68 | # --c_out 321 \ 69 | # --des 'Exp' \ 70 | # --itr 1 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/electricity/ \ 75 | --data_path electricity.csv \ 76 | --model_id ECL_96_96 \ 77 | --model $model_name \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 96 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 64 \ 87 | --dec_in 64 \ 88 | --c_out 64 \ 89 | --des 'Exp' \ 90 | --exp_name partial_train \ 91 | --itr 1 92 | -------------------------------------------------------------------------------- /scripts/variate_generalization/ECL/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=Informer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/electricity/ \ 8 | # --data_path electricity.csv \ 9 | # --model_id ECL_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | 
# --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 321 \ 20 | # --dec_in 321 \ 21 | # --c_out 321 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/electricity/ \ 28 | --data_path electricity.csv \ 29 | --model_id ECL_96_96 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 96 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 64 \ 40 | --dec_in 64 \ 41 | --c_out 64 \ 42 | --des 'Exp' \ 43 | --channel_independence true \ 44 | --exp_name partial_train \ 45 | --batch_size 8 \ 46 | --d_model 32 \ 47 | --d_ff 64 \ 48 | --itr 1 49 | 50 | model_name=iInformer 51 | 52 | #python -u run.py \ 53 | ## --is_training 1 \ 54 | # --root_path ./dataset/electricity/ \ 55 | # --data_path electricity.csv \ 56 | # --model_id ECL_96_96 \ 57 | # --model $model_name \ 58 | # --data custom \ 59 | # --features M \ 60 | # --seq_len 96 \ 61 | # --label_len 48 \ 62 | # --pred_len 96 \ 63 | # --e_layers 2 \ 64 | # --d_layers 1 \ 65 | # --factor 3 \ 66 | # --enc_in 321 \ 67 | # --dec_in 321 \ 68 | # --c_out 321 \ 69 | # --des 'Exp' \ 70 | # --itr 1 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/electricity/ \ 75 | --data_path electricity.csv \ 76 | --model_id ECL_96_96 \ 77 | --model $model_name \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 96 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 64 \ 87 | --dec_in 64 \ 88 | --c_out 64 \ 89 | --des 'Exp' \ 90 | --exp_name partial_train \ 91 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/ECL/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=Reformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/electricity/ \ 8 | # --data_path electricity.csv \ 9 | # --model_id ECL_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 321 \ 20 | # --dec_in 321 \ 21 | # --c_out 321 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/electricity/ \ 28 | --data_path electricity.csv \ 29 | --model_id ECL_96_96 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 96 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 64 \ 40 | --dec_in 64 \ 41 | --c_out 64 \ 42 | --des 'Exp' \ 43 | --channel_independence true \ 44 | --exp_name partial_train \ 45 | --batch_size 8 \ 46 | --d_model 32 \ 47 | --d_ff 64 \ 48 | --itr 1 49 | 50 | model_name=iReformer 51 | 52 | #python -u run.py \ 53 | ## --is_training 1 \ 54 | # --root_path ./dataset/electricity/ \ 55 | # --data_path electricity.csv \ 56 | # --model_id ECL_96_96 \ 57 | # --model $model_name \ 58 | # --data custom \ 59 | # --features M \ 60 | # --seq_len 96 \ 61 | # --label_len 48 \ 62 | # --pred_len 96 \ 63 | # --e_layers 2 \ 64 | # --d_layers 1 \ 65 | # --factor 3 \ 66 | # --enc_in 321 \ 67 | # --dec_in 321 \ 68 | # --c_out 321 \ 69 | # --des 'Exp' \ 
70 | # --itr 1 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/electricity/ \ 75 | --data_path electricity.csv \ 76 | --model_id ECL_96_96 \ 77 | --model $model_name \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 96 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 64 \ 87 | --dec_in 64 \ 88 | --c_out 64 \ 89 | --des 'Exp' \ 90 | --exp_name partial_train \ 91 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/ECL/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | model_name=Transformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/electricity/ \ 8 | # --data_path electricity.csv \ 9 | # --model_id ECL_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 321 \ 20 | # --dec_in 321 \ 21 | # --c_out 321 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | # 20% partial variates, enc_in: 64 = 321 // 5 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/electricity/ \ 29 | --data_path electricity.csv \ 30 | --model_id ECL_96_96 \ 31 | --model $model_name \ 32 | --data custom \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --label_len 48 \ 36 | --pred_len 96 \ 37 | --e_layers 2 \ 38 | --d_layers 1 \ 39 | --factor 3 \ 40 | --enc_in 64 \ 41 | --dec_in 64 \ 42 | --c_out 64 \ 43 | --des 'Exp' \ 44 | --channel_independence true \ 45 | --exp_name partial_train \ 46 | --batch_size 8 \ 47 | --d_model 32 \ 48 | --d_ff 64 \ 49 | --itr 1 50 | 51 | model_name=iTransformer 52 | 53 | #python -u run.py \ 54 | ## --is_training 1 \ 55 | # --root_path ./dataset/electricity/ \ 56 | # --data_path electricity.csv \ 57 | # --model_id ECL_96_96 \ 58 | # --model $model_name \ 59 | # --data custom \ 60 | # --features M \ 61 | # --seq_len 96 \ 62 | # --label_len 48 \ 63 | # --pred_len 96 \ 64 | # --e_layers 2 \ 65 | # --d_layers 1 \ 66 | # --factor 3 \ 67 | # --enc_in 321 \ 68 | # --dec_in 321 \ 69 | # --c_out 321 \ 70 | # --des 'Exp' \ 71 | # --itr 1 72 | 73 | python -u run.py \ 74 | --is_training 1 \ 75 | --root_path ./dataset/electricity/ \ 76 | --data_path electricity.csv \ 77 | --model_id ECL_96_96 \ 78 | --model $model_name \ 79 | --data custom \ 80 | --features M \ 81 | --seq_len 96 \ 82 | --label_len 48 \ 83 | --pred_len 96 \ 84 | --e_layers 2 \ 85 | --d_layers 1 \ 86 | --factor 3 \ 87 | --enc_in 64 \ 88 | --dec_in 64 \ 89 | --c_out 64 \ 90 | --des 'Exp' \ 91 | --exp_name partial_train \ 92 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/README.md: -------------------------------------------------------------------------------- 1 | # iTransformer for Variate Generalization 2 | 3 | This folder contains the implementation of the iTransformer to generalize on unseen variates. If you are new to this repo, we recommend you to read this [README](../multivariate_forecasting/README.md) first. 4 | 5 | By inverting vanilla Transformers, the model is empowered with the generalization capability on unseen variates. 
First, benefiting from the flexibility in the number of input tokens, the number of variate channels is no longer restricted and is free to differ between training and inference (a minimal sketch of this flexibility is given below). Second, the feed-forward network is applied identically to each independent variate token, so it learns transferable representations of time series. 6 | 7 | ## Scripts 8 | 9 | ``` 10 | # Train models with only 20% of variates from Traffic and test the model on all variates without finetuning 11 | 12 | bash ./scripts/variate_generalization/Traffic/iTransformer.sh 13 | ``` 14 | 15 | > During Training 16 |
17 | 18 | 19 | 20 | > During Inference 21 | 22 | 23 |
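To make the partial-training / full-variate-inference setting concrete, here is a minimal, hypothetical sketch (not the code shipped in this repository; the class and parameter names are illustrative) of an inverted-style forecaster whose attention tokens are whole variate series. Because attention and the shared projections accept a variable number of tokens, the same weights can be trained on roughly 20% of the variates and then applied to all of them:

```python
# Hedged sketch: a toy "inverted" forecaster whose tokens are whole variate series,
# so the variate count N may differ freely between training and inference.
import torch
import torch.nn as nn


class ToyInvertedForecaster(nn.Module):
    def __init__(self, seq_len=96, pred_len=96, d_model=64, n_heads=4):
        super().__init__()
        self.embed = nn.Linear(seq_len, d_model)   # embed each variate's lookback series as one token
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.proj = nn.Linear(d_model, pred_len)   # map each variate token to its future series

    def forward(self, x):                              # x: [batch, seq_len, n_variates]
        tokens = self.embed(x.transpose(1, 2))         # [batch, n_variates, d_model]
        tokens, _ = self.attn(tokens, tokens, tokens)  # attention over variate tokens
        return self.proj(tokens).transpose(1, 2)       # [batch, pred_len, n_variates]


model = ToyInvertedForecaster()
partial = torch.randn(8, 96, 172)   # e.g. ~20% of the 862 Traffic variates during training
full = torch.randn(8, 96, 862)      # all variates at inference, with the same weights
print(model(partial).shape, model(full).shape)
```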
24 | 25 | In each folder named after the dataset, we provide two strategies to enable Transformers to generalize to unseen variates. 26 | 27 | * **CI-Transformers**: Channel Independence treats each variate of the time series as an independent channel and uses a shared backbone to forecast all variates. The model can therefore predict the variates one by one, but training and inference can be time-consuming. 28 | 29 | * **iTransformers**: benefiting from the flexibility of attention, whose number of input tokens can change dynamically, the number of variates used as tokens is no longer restricted, which even allows the model to be trained on an arbitrary subset of variables. 30 | 31 | ## Results 32 | 33 |
34 | 35 |
36 | 37 | iTransformers can be naturally trained with 20% variates and accomplish forecast on all variates with the ability to learn transferable representations. -------------------------------------------------------------------------------- /scripts/variate_generalization/SolarEnergy/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=Flowformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/Solar/ \ 8 | # --data_path solar_AL.txt \ 9 | # --model_id solar_96_96 \ 10 | # --model $model_name \ 11 | # --data Solar \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 137 \ 20 | # --dec_in 137 \ 21 | # --c_out 137 \ 22 | # --des 'Exp' \ 23 | # --learning_rate 0.0005 \ 24 | # --itr 1 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/Solar/ \ 29 | --data_path solar_AL.txt \ 30 | --model_id solar_96_96 \ 31 | --model $model_name \ 32 | --data Solar \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --label_len 48 \ 36 | --pred_len 96 \ 37 | --e_layers 2 \ 38 | --d_layers 1 \ 39 | --factor 3 \ 40 | --enc_in 27 \ 41 | --dec_in 27 \ 42 | --c_out 27 \ 43 | --des 'Exp' \ 44 | --d_model 32 \ 45 | --d_ff 64 \ 46 | --learning_rate 0.0005 \ 47 | --channel_independence true \ 48 | --exp_name partial_train \ 49 | --batch_size 8 \ 50 | --itr 1 51 | 52 | model_name=iFlowformer 53 | 54 | #python -u run.py \ 55 | ## --is_training 1 \ 56 | # --root_path ./dataset/Solar/ \ 57 | # --data_path solar_AL.txt \ 58 | # --model_id solar_96_96 \ 59 | # --model $model_name \ 60 | # --data Solar \ 61 | # --features M \ 62 | # --seq_len 96 \ 63 | # --label_len 48 \ 64 | # --pred_len 96 \ 65 | # --e_layers 2 \ 66 | # --d_layers 1 \ 67 | # --factor 3 \ 68 | # --enc_in 137 \ 69 | # --dec_in 137 \ 70 | # --c_out 137 \ 71 | # --des 'Exp' \ 72 | # --learning_rate 0.0005 \ 73 | # --itr 1 74 | 75 | python -u run.py \ 76 | --is_training 1 \ 77 | --root_path ./dataset/Solar/ \ 78 | --data_path solar_AL.txt \ 79 | --model_id solar_96_96 \ 80 | --model $model_name \ 81 | --data Solar \ 82 | --features M \ 83 | --seq_len 96 \ 84 | --label_len 48 \ 85 | --pred_len 96 \ 86 | --e_layers 2 \ 87 | --d_layers 1 \ 88 | --factor 3 \ 89 | --enc_in 27 \ 90 | --dec_in 27 \ 91 | --c_out 27 \ 92 | --des 'Exp' \ 93 | --learning_rate 0.0005 \ 94 | --exp_name partial_train \ 95 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/SolarEnergy/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=Informer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/Solar/ \ 8 | # --data_path solar_AL.txt \ 9 | # --model_id solar_96_96 \ 10 | # --model $model_name \ 11 | # --data Solar \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 137 \ 20 | # --dec_in 137 \ 21 | # --c_out 137 \ 22 | # --des 'Exp' \ 23 | # --learning_rate 0.0005 \ 24 | # --itr 1 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/Solar/ \ 29 | --data_path solar_AL.txt \ 30 | --model_id solar_96_96 \ 31 | --model $model_name \ 32 | --data Solar \ 33 | --features M \ 34 | --seq_len 96 
\ 35 | --label_len 48 \ 36 | --pred_len 96 \ 37 | --e_layers 2 \ 38 | --d_layers 1 \ 39 | --factor 3 \ 40 | --enc_in 27 \ 41 | --dec_in 27 \ 42 | --c_out 27 \ 43 | --des 'Exp' \ 44 | --d_model 32 \ 45 | --d_ff 64 \ 46 | --learning_rate 0.0005 \ 47 | --channel_independence true \ 48 | --exp_name partial_train \ 49 | --batch_size 8 \ 50 | --itr 1 51 | 52 | model_name=iInformer 53 | 54 | #python -u run.py \ 55 | ## --is_training 1 \ 56 | # --root_path ./dataset/Solar/ \ 57 | # --data_path solar_AL.txt \ 58 | # --model_id solar_96_96 \ 59 | # --model $model_name \ 60 | # --data Solar \ 61 | # --features M \ 62 | # --seq_len 96 \ 63 | # --label_len 48 \ 64 | # --pred_len 96 \ 65 | # --e_layers 2 \ 66 | # --d_layers 1 \ 67 | # --factor 3 \ 68 | # --enc_in 137 \ 69 | # --dec_in 137 \ 70 | # --c_out 137 \ 71 | # --des 'Exp' \ 72 | # --learning_rate 0.0005 \ 73 | # --itr 1 74 | 75 | python -u run.py \ 76 | --is_training 1 \ 77 | --root_path ./dataset/Solar/ \ 78 | --data_path solar_AL.txt \ 79 | --model_id solar_96_96 \ 80 | --model $model_name \ 81 | --data Solar \ 82 | --features M \ 83 | --seq_len 96 \ 84 | --label_len 48 \ 85 | --pred_len 96 \ 86 | --e_layers 2 \ 87 | --d_layers 1 \ 88 | --factor 3 \ 89 | --enc_in 27 \ 90 | --dec_in 27 \ 91 | --c_out 27 \ 92 | --des 'Exp' \ 93 | --learning_rate 0.0005 \ 94 | --exp_name partial_train \ 95 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/SolarEnergy/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=Reformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/Solar/ \ 8 | # --data_path solar_AL.txt \ 9 | # --model_id solar_96_96 \ 10 | # --model $model_name \ 11 | # --data Solar \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 137 \ 20 | # --dec_in 137 \ 21 | # --c_out 137 \ 22 | # --des 'Exp' \ 23 | # --learning_rate 0.0005 \ 24 | # --itr 1 25 | 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/Solar/ \ 29 | --data_path solar_AL.txt \ 30 | --model_id solar_96_96 \ 31 | --model $model_name \ 32 | --data Solar \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --label_len 48 \ 36 | --pred_len 96 \ 37 | --e_layers 2 \ 38 | --d_layers 1 \ 39 | --factor 3 \ 40 | --enc_in 27 \ 41 | --dec_in 27 \ 42 | --c_out 27 \ 43 | --des 'Exp' \ 44 | --d_model 32 \ 45 | --d_ff 64 \ 46 | --learning_rate 0.0005 \ 47 | --channel_independence true \ 48 | --exp_name partial_train \ 49 | --batch_size 8 \ 50 | --itr 1 51 | 52 | model_name=iReformer 53 | 54 | #python -u run.py \ 55 | ## --is_training 1 \ 56 | # --root_path ./dataset/Solar/ \ 57 | # --data_path solar_AL.txt \ 58 | # --model_id solar_96_96 \ 59 | # --model $model_name \ 60 | # --data Solar \ 61 | # --features M \ 62 | # --seq_len 96 \ 63 | # --label_len 48 \ 64 | # --pred_len 96 \ 65 | # --e_layers 2 \ 66 | # --d_layers 1 \ 67 | # --factor 3 \ 68 | # --enc_in 137 \ 69 | # --dec_in 137 \ 70 | # --c_out 137 \ 71 | # --des 'Exp' \ 72 | # --learning_rate 0.0005 \ 73 | # --itr 74 | 75 | python -u run.py \ 76 | --is_training 1 \ 77 | --root_path ./dataset/Solar/ \ 78 | --data_path solar_AL.txt \ 79 | --model_id solar_96_96 \ 80 | --model $model_name \ 81 | --data Solar \ 82 | --features M \ 83 | --seq_len 96 \ 84 | --label_len 48 \ 85 | --pred_len 96 \ 86 | --e_layers 2 \ 87 | 
--d_layers 1 \ 88 | --factor 3 \ 89 | --enc_in 27 \ 90 | --dec_in 27 \ 91 | --c_out 27 \ 92 | --des 'Exp' \ 93 | --learning_rate 0.0005 \ 94 | --exp_name partial_train \ 95 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/SolarEnergy/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | model_name=Transformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/Solar/ \ 8 | # --data_path solar_AL.txt \ 9 | # --model_id solar_96_96 \ 10 | # --model $model_name \ 11 | # --data Solar \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 137 \ 20 | # --dec_in 137 \ 21 | # --c_out 137 \ 22 | # --des 'Exp' \ 23 | # --learning_rate 0.0005 \ 24 | # --itr 1 25 | 26 | # 20% partial variates: 27 = 137 // 5 27 | python -u run.py \ 28 | --is_training 1 \ 29 | --root_path ./dataset/Solar/ \ 30 | --data_path solar_AL.txt \ 31 | --model_id solar_96_96 \ 32 | --model $model_name \ 33 | --data Solar \ 34 | --features M \ 35 | --seq_len 96 \ 36 | --label_len 48 \ 37 | --pred_len 96 \ 38 | --e_layers 2 \ 39 | --d_layers 1 \ 40 | --factor 3 \ 41 | --enc_in 27 \ 42 | --dec_in 27 \ 43 | --c_out 27 \ 44 | --des 'Exp' \ 45 | --d_model 32 \ 46 | --d_ff 64 \ 47 | --learning_rate 0.0005 \ 48 | --channel_independence true \ 49 | --exp_name partial_train \ 50 | --batch_size 8 \ 51 | --itr 1 52 | 53 | model_name=iTransformer 54 | 55 | #python -u run.py \ 56 | ## --is_training 1 \ 57 | # --root_path ./dataset/Solar/ \ 58 | # --data_path solar_AL.txt \ 59 | # --model_id solar_96_96 \ 60 | # --model $model_name \ 61 | # --data Solar \ 62 | # --features M \ 63 | # --seq_len 96 \ 64 | # --label_len 48 \ 65 | # --pred_len 96 \ 66 | # --e_layers 2 \ 67 | # --d_layers 1 \ 68 | # --factor 3 \ 69 | # --enc_in 137 \ 70 | # --dec_in 137 \ 71 | # --c_out 137 \ 72 | # --des 'Exp' \ 73 | # --learning_rate 0.0005 \ 74 | # --itr 1 75 | 76 | python -u run.py \ 77 | --is_training 1 \ 78 | --root_path ./dataset/Solar/ \ 79 | --data_path solar_AL.txt \ 80 | --model_id solar_96_96 \ 81 | --model $model_name \ 82 | --data Solar \ 83 | --features M \ 84 | --seq_len 96 \ 85 | --label_len 48 \ 86 | --pred_len 96 \ 87 | --e_layers 2 \ 88 | --d_layers 1 \ 89 | --factor 3 \ 90 | --enc_in 27 \ 91 | --dec_in 27 \ 92 | --c_out 27 \ 93 | --des 'Exp' \ 94 | --learning_rate 0.0005 \ 95 | --exp_name partial_train \ 96 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/Traffic/iFlowformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=Flowformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/traffic/ \ 8 | # --data_path traffic.csv \ 9 | # --model_id traffic_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 862 \ 20 | # --dec_in 862 \ 21 | # --c_out 862 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/traffic/ \ 28 | --data_path traffic.csv \ 29 | --model_id traffic_96_96 \ 30 | --model $model_name \ 31 | --data 
custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 96 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 172 \ 40 | --dec_in 172 \ 41 | --c_out 172 \ 42 | --des 'Exp' \ 43 | --channel_independence true \ 44 | --exp_name partial_train \ 45 | --batch_size 4 \ 46 | --d_model 32 \ 47 | --d_ff 64 \ 48 | --itr 1 49 | 50 | model_name=iFlowformer 51 | 52 | #python -u run.py \ 53 | ## --is_training 1 \ 54 | # --root_path ./dataset/traffic/ \ 55 | # --data_path traffic.csv \ 56 | # --model_id traffic_96_96 \ 57 | # --model $model_name \ 58 | # --data custom \ 59 | # --features M \ 60 | # --seq_len 96 \ 61 | # --label_len 48 \ 62 | # --pred_len 96 \ 63 | # --e_layers 2 \ 64 | # --d_layers 1 \ 65 | # --factor 3 \ 66 | # --enc_in 862 \ 67 | # --dec_in 862 \ 68 | # --c_out 862 \ 69 | # --des 'Exp' \ 70 | # --itr 1 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/traffic/ \ 75 | --data_path traffic.csv \ 76 | --model_id traffic_96_96 \ 77 | --model $model_name \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 96 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 172 \ 87 | --dec_in 172 \ 88 | --c_out 172 \ 89 | --des 'Exp' \ 90 | --exp_name partial_train \ 91 | --itr 1 92 | -------------------------------------------------------------------------------- /scripts/variate_generalization/Traffic/iInformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=Informer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/traffic/ \ 8 | # --data_path traffic.csv \ 9 | # --model_id traffic_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 862 \ 20 | # --dec_in 862 \ 21 | # --c_out 862 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/traffic/ \ 28 | --data_path traffic.csv \ 29 | --model_id traffic_96_96 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 96 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 172 \ 40 | --dec_in 172 \ 41 | --c_out 172 \ 42 | --des 'Exp' \ 43 | --channel_independence true \ 44 | --exp_name partial_train \ 45 | --batch_size 4 \ 46 | --d_model 32 \ 47 | --d_ff 64 \ 48 | --itr 1 49 | 50 | model_name=iInformer 51 | 52 | #python -u run.py \ 53 | ## --is_training 1 \ 54 | # --root_path ./dataset/traffic/ \ 55 | # --data_path traffic.csv \ 56 | # --model_id traffic_96_96 \ 57 | # --model $model_name \ 58 | # --data custom \ 59 | # --features M \ 60 | # --seq_len 96 \ 61 | # --label_len 48 \ 62 | # --pred_len 96 \ 63 | # --e_layers 2 \ 64 | # --d_layers 1 \ 65 | # --factor 3 \ 66 | # --enc_in 862 \ 67 | # --dec_in 862 \ 68 | # --c_out 862 \ 69 | # --des 'Exp' \ 70 | # --itr 1 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/traffic/ \ 75 | --data_path traffic.csv \ 76 | --model_id traffic_96_96 \ 77 | --model $model_name \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 96 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 172 \ 87 | --dec_in 172 \ 88 | --c_out 172 \ 89 | 
--des 'Exp' \ 90 | --exp_name partial_train \ 91 | --itr 1 92 | -------------------------------------------------------------------------------- /scripts/variate_generalization/Traffic/iReformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=Reformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/traffic/ \ 8 | # --data_path traffic.csv \ 9 | # --model_id traffic_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 862 \ 20 | # --dec_in 862 \ 21 | # --c_out 862 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | python -u run.py \ 26 | --is_training 1 \ 27 | --root_path ./dataset/traffic/ \ 28 | --data_path traffic.csv \ 29 | --model_id traffic_96_96 \ 30 | --model $model_name \ 31 | --data custom \ 32 | --features M \ 33 | --seq_len 96 \ 34 | --label_len 48 \ 35 | --pred_len 96 \ 36 | --e_layers 2 \ 37 | --d_layers 1 \ 38 | --factor 3 \ 39 | --enc_in 172 \ 40 | --dec_in 172 \ 41 | --c_out 172 \ 42 | --des 'Exp' \ 43 | --channel_independence true \ 44 | --exp_name partial_train \ 45 | --batch_size 4 \ 46 | --d_model 32 \ 47 | --d_ff 64 \ 48 | --itr 1 49 | 50 | model_name=iReformer 51 | 52 | #python -u run.py \ 53 | ## --is_training 1 \ 54 | # --root_path ./dataset/traffic/ \ 55 | # --data_path traffic.csv \ 56 | # --model_id traffic_96_96 \ 57 | # --model $model_name \ 58 | # --data custom \ 59 | # --features M \ 60 | # --seq_len 96 \ 61 | # --label_len 48 \ 62 | # --pred_len 96 \ 63 | # --e_layers 2 \ 64 | # --d_layers 1 \ 65 | # --factor 3 \ 66 | # --enc_in 862 \ 67 | # --dec_in 862 \ 68 | # --c_out 862 \ 69 | # --des 'Exp' \ 70 | # --itr 1 71 | 72 | python -u run.py \ 73 | --is_training 1 \ 74 | --root_path ./dataset/traffic/ \ 75 | --data_path traffic.csv \ 76 | --model_id traffic_96_96 \ 77 | --model $model_name \ 78 | --data custom \ 79 | --features M \ 80 | --seq_len 96 \ 81 | --label_len 48 \ 82 | --pred_len 96 \ 83 | --e_layers 2 \ 84 | --d_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 172 \ 87 | --dec_in 172 \ 88 | --c_out 172 \ 89 | --des 'Exp' \ 90 | --exp_name partial_train \ 91 | --itr 1 -------------------------------------------------------------------------------- /scripts/variate_generalization/Traffic/iTransformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | model_name=Transformer 4 | 5 | #python -u run.py \ 6 | ## --is_training 1 \ 7 | # --root_path ./dataset/traffic/ \ 8 | # --data_path traffic.csv \ 9 | # --model_id traffic_96_96 \ 10 | # --model $model_name \ 11 | # --data custom \ 12 | # --features M \ 13 | # --seq_len 96 \ 14 | # --label_len 48 \ 15 | # --pred_len 96 \ 16 | # --e_layers 2 \ 17 | # --d_layers 1 \ 18 | # --factor 3 \ 19 | # --enc_in 862 \ 20 | # --dec_in 862 \ 21 | # --c_out 862 \ 22 | # --des 'Exp' \ 23 | # --itr 1 24 | 25 | # 20% partial variates, enc_in: 172 = 862 // 5 26 | python -u run.py \ 27 | --is_training 1 \ 28 | --root_path ./dataset/traffic/ \ 29 | --data_path traffic.csv \ 30 | --model_id traffic_96_96 \ 31 | --model $model_name \ 32 | --data custom \ 33 | --features M \ 34 | --seq_len 96 \ 35 | --label_len 48 \ 36 | --pred_len 96 \ 37 | --e_layers 2 \ 38 | --d_layers 1 \ 39 | --factor 3 \ 40 | --enc_in 172 \ 41 | --dec_in 172 \ 42 | --c_out 172 \ 43 | --des 'Exp' \ 
44 | --channel_independence true \ 45 | --exp_name partial_train \ 46 | --batch_size 8 \ 47 | --d_model 32 \ 48 | --d_ff 64 \ 49 | --itr 1 50 | 51 | model_name=iTransformer 52 | 53 | #python -u run.py \ 54 | ## --is_training 1 \ 55 | # --root_path ./dataset/traffic/ \ 56 | # --data_path traffic.csv \ 57 | # --model_id traffic_96_96 \ 58 | # --model $model_name \ 59 | # --data custom \ 60 | # --features M \ 61 | # --seq_len 96 \ 62 | # --label_len 48 \ 63 | # --pred_len 96 \ 64 | # --e_layers 2 \ 65 | # --d_layers 1 \ 66 | # --factor 3 \ 67 | # --enc_in 862 \ 68 | # --dec_in 862 \ 69 | # --c_out 862 \ 70 | # --des 'Exp' \ 71 | # --itr 1 72 | 73 | python -u run.py \ 74 | --is_training 1 \ 75 | --root_path ./dataset/traffic/ \ 76 | --data_path traffic.csv \ 77 | --model_id traffic_96_96 \ 78 | --model $model_name \ 79 | --data custom \ 80 | --features M \ 81 | --seq_len 96 \ 82 | --label_len 48 \ 83 | --pred_len 96 \ 84 | --e_layers 2 \ 85 | --d_layers 1 \ 86 | --factor 3 \ 87 | --enc_in 172 \ 88 | --dec_in 172 \ 89 | --c_out 172 \ 90 | --des 'Exp' \ 91 | --exp_name partial_train \ 92 | --itr 1 93 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/utils/__init__.py -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(pred - true)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((pred - true) ** 2) 20 | 21 | 22 | def RMSE(pred, true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | 26 | def MAPE(pred, true): 27 | return np.mean(np.abs((pred - true) / true)) 28 | 29 | 30 | def MSPE(pred, true): 31 | return np.mean(np.square((pred - true) / true)) 32 | 33 | 34 | def metric(pred, true): 35 | mae = MAE(pred, true) 36 | mse = MSE(pred, true) 37 | rmse = RMSE(pred, true) 38 | mape = MAPE(pred, true) 39 | mspe = MSPE(pred, true) 40 | 41 | return mae, mse, 
rmse, mape, mspe 42 | -------------------------------------------------------------------------------- /utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from pandas.tseries import offsets 20 | from pandas.tseries.frequencies import to_offset 21 | 22 | 23 | class TimeFeature: 24 | def __init__(self): 25 | pass 26 | 27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 28 | pass 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + "()" 32 | 33 | 34 | class SecondOfMinute(TimeFeature): 35 | """Minute of hour encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.second / 59.0 - 0.5 39 | 40 | 41 | class MinuteOfHour(TimeFeature): 42 | """Minute of hour encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.minute / 59.0 - 0.5 46 | 47 | 48 | class HourOfDay(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.hour / 23.0 - 0.5 53 | 54 | 55 | class DayOfWeek(TimeFeature): 56 | """Hour of day encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return index.dayofweek / 6.0 - 0.5 60 | 61 | 62 | class DayOfMonth(TimeFeature): 63 | """Day of month encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.day - 1) / 30.0 - 0.5 67 | 68 | 69 | class DayOfYear(TimeFeature): 70 | """Day of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.dayofyear - 1) / 365.0 - 0.5 74 | 75 | 76 | class MonthOfYear(TimeFeature): 77 | """Month of year encoded as value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.month - 1) / 11.0 - 0.5 81 | 82 | 83 | class WeekOfYear(TimeFeature): 84 | """Week of year encoded as value between [-0.5, 0.5]""" 85 | 86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 87 | return (index.isocalendar().week - 1) / 52.0 - 0.5 88 | 89 | 90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 91 | """ 92 | Returns a list of time features that will be appropriate for the given frequency string. 93 | Parameters 94 | ---------- 95 | freq_str 96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 
97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | 8 | plt.switch_backend('agg') 9 | 10 | 11 | def adjust_learning_rate(optimizer, epoch, args): 12 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 13 | if args.lradj == 'type1': 14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 15 | elif args.lradj == 'type2': 16 | lr_adjust = { 17 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 18 | 10: 5e-7, 15: 1e-7, 20: 5e-8 19 | } 20 | if epoch in lr_adjust.keys(): 21 | lr = lr_adjust[epoch] 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | print('Updating learning rate to {}'.format(lr)) 25 | 26 | 27 | class EarlyStopping: 28 | def __init__(self, patience=7, verbose=False, delta=0): 29 | self.patience = patience 30 | self.verbose = verbose 31 | self.counter = 0 32 | self.best_score = None 33 | self.early_stop = False 34 | self.val_loss_min = np.Inf 35 | self.delta = delta 36 | 37 | def __call__(self, val_loss, model, path): 38 | score = -val_loss 39 | if self.best_score is None: 40 | self.best_score = score 41 | self.save_checkpoint(val_loss, model, path) 42 | elif score < self.best_score + self.delta: 43 | self.counter += 1 44 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 45 | if self.counter >= self.patience: 46 | self.early_stop = True 47 | else: 48 | self.best_score = score 49 | self.save_checkpoint(val_loss, model, path) 50 | self.counter = 0 51 | 52 | def save_checkpoint(self, val_loss, model, path): 53 | if self.verbose: 54 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 55 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 56 | self.val_loss_min = val_loss 57 | 58 | 59 | class dotdict(dict): 60 | """dot.notation access to dictionary attributes""" 61 | __getattr__ = dict.get 62 | __setattr__ = dict.__setitem__ 63 | __delattr__ = dict.__delitem__ 64 | 65 | 66 | class StandardScaler(): 67 | def __init__(self, mean, std): 68 | self.mean = mean 69 | self.std = std 70 | 71 | def transform(self, data): 72 | return (data - self.mean) / self.std 73 | 74 | def inverse_transform(self, data): 75 | return (data * self.std) + self.mean 76 | 77 | 78 | def visual(true, preds=None, name='./pic/test.pdf'): 79 | """ 80 | Results visualization 81 | """ 82 | plt.figure() 83 | plt.plot(true, label='GroundTruth', linewidth=2) 84 | if preds is not None: 85 | plt.plot(preds, label='Prediction', linewidth=2) 86 | plt.legend() 87 | plt.savefig(name, bbox_inches='tight') 88 | 89 | 90 | def adjustment(gt, pred): 91 | anomaly_state = False 92 | for i in range(len(gt)): 93 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 94 | anomaly_state = True 95 | for j in range(i, 0, -1): 96 | if gt[j] == 0: 97 | break 98 | else: 99 | if pred[j] == 0: 100 | pred[j] = 1 101 | for j in range(i, len(gt)): 102 | if gt[j] == 0: 103 | break 104 | else: 105 | if pred[j] == 0: 106 | pred[j] = 1 107 | elif gt[i] == 0: 108 | anomaly_state = False 109 | if anomaly_state: 110 | pred[i] = 1 111 | return gt, pred 112 | 113 | 114 | def cal_accuracy(y_pred, y_true): 115 | return np.mean(y_pred == y_true) 116 | --------------------------------------------------------------------------------
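As a usage note (not part of the repository), the training helpers in `utils/tools.py` above could be driven from a loop like the following sketch; the import path assumes the repository root is on `PYTHONPATH`, and the `args` object is a minimal assumed subset of the real run configuration:

```python
# Hedged usage sketch: wiring EarlyStopping and adjust_learning_rate into a toy loop.
import os

import torch

from utils.tools import EarlyStopping, adjust_learning_rate, dotdict

args = dotdict(learning_rate=1e-4, lradj='type1')   # assumed minimal config fields
model = torch.nn.Linear(96, 96)                     # stand-in for a forecasting model
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
early_stopping = EarlyStopping(patience=3, verbose=True)

ckpt_dir = './checkpoints/demo'
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(10):
    # ... training and validation passes would go here ...
    val_loss = 1.0 / (epoch + 1)                      # dummy validation loss
    early_stopping(val_loss, model, ckpt_dir)         # saves checkpoint.pth whenever val_loss improves
    if early_stopping.early_stop:
        print('Early stopping')
        break
    adjust_learning_rate(optimizer, epoch + 1, args)  # 'type1' halves the learning rate each epoch
```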