├── .gitignore
├── LICENSE
├── README.md
├── data_provider
│   ├── __init__.py
│   ├── data_factory.py
│   └── data_loader.py
├── experiments
│   ├── exp_basic.py
│   ├── exp_long_term_forecasting.py
│   └── exp_long_term_forecasting_partial.py
├── figures
│   ├── ablations.png
│   ├── algorithm.png
│   ├── analysis.png
│   ├── architecture.png
│   ├── boosting.png
│   ├── boosting_trm.png
│   ├── datasets.png
│   ├── datasets_mtsf.png
│   ├── efficiency.png
│   ├── efficient.png
│   ├── formulations.png
│   ├── generability.png
│   ├── groups.png
│   ├── increase_lookback.png
│   ├── layernorm.png
│   ├── main_results.png
│   ├── main_results_alipay.png
│   ├── motivation.png
│   ├── pi.png
│   ├── pt.png
│   └── radar.png
├── layers
│   ├── Embed.py
│   ├── SelfAttention_Family.py
│   ├── Transformer_EncDec.py
│   └── __init__.py
├── model
│   ├── Flashformer.py
│   ├── Flowformer.py
│   ├── Informer.py
│   ├── Reformer.py
│   ├── Transformer.py
│   ├── __init__.py
│   ├── iFlashformer.py
│   ├── iFlowformer.py
│   ├── iInformer.py
│   ├── iReformer.py
│   └── iTransformer.py
├── requirements.txt
├── run.py
├── scripts
│   ├── boost_performance
│   │   ├── ECL
│   │   │   ├── iFlowformer.sh
│   │   │   ├── iInformer.sh
│   │   │   ├── iReformer.sh
│   │   │   └── iTransformer.sh
│   │   ├── README.md
│   │   ├── Traffic
│   │   │   ├── iFlowformer.sh
│   │   │   ├── iInformer.sh
│   │   │   ├── iReformer.sh
│   │   │   └── iTransformer.sh
│   │   └── Weather
│   │       ├── iFlowformer.sh
│   │       ├── iInformer.sh
│   │       ├── iReformer.sh
│   │       └── iTransformer.sh
│   ├── increasing_lookback
│   │   ├── ECL
│   │   │   ├── iFlowformer.sh
│   │   │   ├── iInformer.sh
│   │   │   ├── iReformer.sh
│   │   │   └── iTransformer.sh
│   │   ├── README.md
│   │   └── Traffic
│   │       ├── iFlowformer.sh
│   │       ├── iInformer.sh
│   │       ├── iReformer.sh
│   │       └── iTransformer.sh
│   ├── model_efficiency
│   │   ├── ECL
│   │   │   └── iFlashTransformer.sh
│   │   ├── README.md
│   │   ├── Traffic
│   │   │   └── iFlashTransformer.sh
│   │   └── Weather
│   │       └── iFlashTransformer.sh
│   ├── multivariate_forecasting
│   │   ├── ECL
│   │   │   └── iTransformer.sh
│   │   ├── ETT
│   │   │   ├── iTransformer_ETTh1.sh
│   │   │   ├── iTransformer_ETTh2.sh
│   │   │   ├── iTransformer_ETTm1.sh
│   │   │   └── iTransformer_ETTm2.sh
│   │   ├── Exchange
│   │   │   └── iTransformer.sh
│   │   ├── PEMS
│   │   │   ├── iTransformer_03.sh
│   │   │   ├── iTransformer_04.sh
│   │   │   ├── iTransformer_07.sh
│   │   │   └── iTransformer_08.sh
│   │   ├── README.md
│   │   ├── SolarEnergy
│   │   │   └── iTransformer.sh
│   │   ├── Traffic
│   │   │   └── iTransformer.sh
│   │   └── Weather
│   │       └── iTransformer.sh
│   └── variate_generalization
│       ├── ECL
│       │   ├── iFlowformer.sh
│       │   ├── iInformer.sh
│       │   ├── iReformer.sh
│       │   └── iTransformer.sh
│       ├── README.md
│       ├── SolarEnergy
│       │   ├── iFlowformer.sh
│       │   ├── iInformer.sh
│       │   ├── iReformer.sh
│       │   └── iTransformer.sh
│       └── Traffic
│           ├── iFlowformer.sh
│           ├── iInformer.sh
│           ├── iReformer.sh
│           └── iTransformer.sh
└── utils
    ├── __init__.py
    ├── masking.py
    ├── metrics.py
    ├── timefeatures.py
    └── tools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | */.DS_Store
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 THUML @ Tsinghua University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # iTransformer
2 |
3 | This repo is the official implementation of the paper: [iTransformer: Inverted Transformers Are Effective for Time Series Forecasting](https://arxiv.org/abs/2310.06625). [[Slides]](https://cloud.tsinghua.edu.cn/f/175ff98f7e2d44fbbe8e/), [[Poster]](https://cloud.tsinghua.edu.cn/f/36a2ae6c132d44c0bd8c/).
4 |
5 |
6 | # Updates
7 |
8 | :triangular_flag_on_post: **News** (2024.10) [TimeXer](https://arxiv.org/abs/2402.19072), a Transformer for predicting with exogenous variables, is released. Code is available [here](https://github.com/thuml/TimeXer).
9 |
10 | :triangular_flag_on_post: **News** (2024.05) Many thanks for the great efforts from [lucidrains](https://github.com/lucidrains/iTransformer). A pip package providing iTransformer variants can be installed simply via ```pip install iTransformer```.
11 |
12 | :triangular_flag_on_post: **News** (2024.04) iTransformer has been included in [NeuralForecast](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/itransformer.py). Special thanks to the contributor @[Marco](https://github.com/marcopeix)!
13 |
14 | :triangular_flag_on_post: **News** (2024.03) An introduction to our work in [Chinese](https://mp.weixin.qq.com/s/-pvBnA1_NSloNxa6TYXTSg) is available.
15 |
16 | :triangular_flag_on_post: **News** (2024.02) iTransformer has been accepted as **ICLR 2024 Spotlight**.
17 |
18 | :triangular_flag_on_post: **News** (2023.12) iTransformer is available in [GluonTS](https://github.com/awslabs/gluonts/pull/3017) with a probabilistic head and support for static covariates. A notebook is available [here](https://github.com/awslabs/gluonts/blob/dev/examples/iTransformer.ipynb).
19 |
20 | :triangular_flag_on_post: **News** (2023.12) We have received many valuable suggestions. A [revised version](https://arxiv.org/pdf/2310.06625v2.pdf) (**24 pages**) is now available.
21 |
22 | :triangular_flag_on_post: **News** (2023.10) iTransformer has been included in [[Time-Series-Library]](https://github.com/thuml/Time-Series-Library) and achieves state-of-the-art in Lookback-$96$ forecasting.
23 |
24 | :triangular_flag_on_post: **News** (2023.10) All the scripts for the experiments in our [paper](https://arxiv.org/pdf/2310.06625.pdf) are available.
25 |
26 |
27 | ## Introduction
28 |
29 | 🌟 Considering the characteristics of multivariate time series, iTransformer breaks the conventional structure without modifying any Transformer modules. **Inverted Transformer is all you need in MTSF**.
30 |
31 |
32 |
33 |
34 |
35 | 🏆 iTransformer achieves comprehensive state-of-the-art performance on challenging multivariate forecasting tasks and addresses several pain points of Transformers on extensive time series data.
36 |
37 |
38 |
39 |
40 |
41 |
42 | ## Overall Architecture
43 |
44 | iTransformer regards **independent time series as variate tokens**, **captures multivariate correlations by attention**, and **utilizes layer normalization and feed-forward networks to learn series representations**.
45 |
46 |
47 |
48 |
49 |
50 | The pseudo-code of iTransformer is as simple as the following:
51 |
52 |
53 |
54 |
55 |
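For readers viewing this without the figure, below is a minimal PyTorch sketch of the same idea. The class and argument names are illustrative only, not the exact ones used in `model/iTransformer.py` (which additionally supports, e.g., taking timestamp covariates as extra variate tokens via `DataEmbedding_inverted`):

```python
import torch
import torch.nn as nn

class InvertedForecaster(nn.Module):
    """Illustrative sketch: each variate's whole lookback series becomes one
    token, attention mixes information across variates, and a linear head
    projects every variate token back to its future series."""
    def __init__(self, seq_len, pred_len, d_model=512, n_heads=8, e_layers=2):
        super().__init__()
        self.embed = nn.Linear(seq_len, d_model)        # lookback series -> variate token
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, e_layers)
        self.project = nn.Linear(d_model, pred_len)     # variate token -> future series

    def forward(self, x):                 # x: [Batch, seq_len, N_variates]
        x = x.permute(0, 2, 1)            # [Batch, N_variates, seq_len]
        tokens = self.embed(x)            # [Batch, N_variates, d_model]
        tokens = self.encoder(tokens)     # attention over the variate dimension
        y = self.project(tokens)          # [Batch, N_variates, pred_len]
        return y.permute(0, 2, 1)         # [Batch, pred_len, N_variates]
```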
56 | ## Usage
57 |
58 | 1. Install PyTorch and the necessary dependencies.
59 |
60 | ```
61 | pip install -r requirements.txt
62 | ```
63 |
64 | 2. The datasets can be obtained from [Google Drive](https://drive.google.com/file/d/1l51QsKvQPcqILT3DwfjCgx8Dsg2rpjot/view?usp=drive_link) or [Baidu Cloud](https://pan.baidu.com/s/11AWXg1Z6UwjHzmto4hesAA?pwd=9qjr).
65 |
66 | 3. Train and evaluate the model. We provide scripts for all the above tasks under the folder ./scripts/. You can reproduce the results with the following examples:
67 |
68 | ```
69 | # Multivariate forecasting with iTransformer
70 | bash ./scripts/multivariate_forecasting/Traffic/iTransformer.sh
71 |
72 | # Compare the performance of Transformer and iTransformer
73 | bash ./scripts/boost_performance/Weather/iTransformer.sh
74 |
75 | # Train the model with partial variates, and generalize to the unseen variates
76 | bash ./scripts/variate_generalization/ECL/iTransformer.sh
77 |
78 | # Test the performance on the enlarged lookback window
79 | bash ./scripts/increasing_lookback/Traffic/iTransformer.sh
80 |
81 | # Utilize FlashAttention for acceleration
82 | bash ./scripts/model_efficiency/ECL/iFlashTransformer.sh
83 | ```
84 |
85 | ## Main Result of Multivariate Forecasting
86 |
87 | We evaluate iTransformer on challenging multivariate forecasting benchmarks (**generally hundreds of variates**), where it achieves **consistently strong performance** (MSE/MAE $\downarrow$).
88 |
89 |
90 |
91 | ### Online Transaction Load Prediction of Alipay Trading Platform (Avg Results)
92 |
93 |
94 |
95 |
96 |
97 | ## General Performance Boosting on Transformers
98 |
99 | By applying the proposed inverted framework, Transformer and its variants achieve **significant performance improvements**, demonstrating the **generality of the iTransformer approach** and the **benefit of efficient attention mechanisms**.
100 |
101 |
102 |
103 |
104 |
105 | ## Zero-Shot Generalization on Variates
106 |
107 | **Technically, iTransformer is able to forecast with an arbitrary number of variates**. We train iTransformer on partial variates, and it generalizes well when forecasting unseen variates (see the sketch below).
108 |
109 |
110 |
111 |
112 |
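Sketched with the illustrative `InvertedForecaster` defined above (an assumption for demonstration, not the repository's actual training pipeline): because the variate-token embedding is a linear map applied to each variate's lookback series independently, the number of variates is a free dimension at inference time.

```python
import torch

# Reuses the illustrative InvertedForecaster from the sketch above.
model = InvertedForecaster(seq_len=96, pred_len=96)

x_partial = torch.randn(4, 96, 20)    # e.g. trained on 20 variates
x_all     = torch.randn(4, 96, 321)   # evaluated on all 321 ECL variates

print(model(x_partial).shape)   # torch.Size([4, 96, 20])
print(model(x_all).shape)       # torch.Size([4, 96, 321])
```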
113 | ## Model Analysis
114 |
115 | Benefiting from inverted Transformer modules:
116 |
117 | - (Left) Inverted Transformers learn **better time series representations** (higher [CKA](https://github.com/jayroxis/CKA-similarity) similarity), which is favored by forecasting; a short CKA sketch is given below.
118 | - (Right) The inverted self-attention module learns **interpretable multivariate correlations**.
119 |
120 |
121 |
122 |
123 |
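For reference, CKA here is centered kernel alignment between layer representations. A minimal sketch of the linear variant (following the linked CKA-similarity repo; tensor names are illustrative) is:

```python
import torch

def linear_cka(x, y):
    """Linear CKA between two representation matrices of shape [n_samples, dim]."""
    x = x - x.mean(dim=0, keepdim=True)   # center each feature dimension
    y = y - y.mean(dim=0, keepdim=True)
    cross = (y.T @ x).norm(p='fro') ** 2  # alignment between the two feature spaces
    return (cross / ((x.T @ x).norm(p='fro') * (y.T @ y).norm(p='fro'))).item()

# e.g. flatten encoder inputs/outputs over batch and variates to [B * N, d_model];
# a higher linear_cka(first_layer_tokens, last_layer_tokens) means more similar representations.
```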
124 | ## Citation
125 |
126 | If you find this repo helpful, please cite our paper.
127 |
128 | ```
129 | @article{liu2023itransformer,
130 | title={iTransformer: Inverted Transformers Are Effective for Time Series Forecasting},
131 | author={Liu, Yong and Hu, Tengge and Zhang, Haoran and Wu, Haixu and Wang, Shiyu and Ma, Lintao and Long, Mingsheng},
132 | journal={arXiv preprint arXiv:2310.06625},
133 | year={2023}
134 | }
135 | ```
136 |
137 | ## Acknowledgement
138 |
139 | We greatly appreciate the following GitHub repos for their valuable code and efforts.
140 | - Reformer (https://github.com/lucidrains/reformer-pytorch)
141 | - Informer (https://github.com/zhouhaoyi/Informer2020)
142 | - FlashAttention (https://github.com/shreyansh26/FlashAttention-PyTorch)
143 | - Autoformer (https://github.com/thuml/Autoformer)
144 | - Stationary (https://github.com/thuml/Nonstationary_Transformers)
145 | - Time-Series-Library (https://github.com/thuml/Time-Series-Library)
146 | - lucidrains (https://github.com/lucidrains/iTransformer)
147 |
148 | This work was supported by Ant Group through the CCF-Ant Research Fund and was recognized among the [Outstanding Projects of the CCF Fund](https://mp.weixin.qq.com/s/PDLNbibZD3kqhcUoNejLfA).
149 |
150 | ## Contact
151 |
152 | If you have any questions or want to use the code, feel free to contact:
153 | * Yong Liu (liuyong21@mails.tsinghua.edu.cn)
154 | * Haoran Zhang (z-hr20@mails.tsinghua.edu.cn)
155 | * Tengge Hu (htg21@mails.tsinghua.edu.cn)
156 |
--------------------------------------------------------------------------------
/data_provider/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/data_provider/__init__.py
--------------------------------------------------------------------------------
/data_provider/data_factory.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Solar, Dataset_PEMS, \
2 | Dataset_Pred
3 | from torch.utils.data import DataLoader
4 |
5 | data_dict = {
6 | 'ETTh1': Dataset_ETT_hour,
7 | 'ETTh2': Dataset_ETT_hour,
8 | 'ETTm1': Dataset_ETT_minute,
9 | 'ETTm2': Dataset_ETT_minute,
10 | 'Solar': Dataset_Solar,
11 | 'PEMS': Dataset_PEMS,
12 | 'custom': Dataset_Custom,
13 | }
14 |
15 |
16 | def data_provider(args, flag):
17 | Data = data_dict[args.data]
18 | timeenc = 0 if args.embed != 'timeF' else 1
19 |
20 | if flag == 'test':
21 | shuffle_flag = False
22 | drop_last = True
23 | batch_size = 1 # bsz=1 for evaluation
24 | freq = args.freq
25 | elif flag == 'pred':
26 | shuffle_flag = False
27 | drop_last = False
28 | batch_size = 1
29 | freq = args.freq
30 | Data = Dataset_Pred
31 | else:
32 | shuffle_flag = True
33 | drop_last = True
34 | batch_size = args.batch_size # bsz for train and valid
35 | freq = args.freq
36 |
37 | data_set = Data(
38 | root_path=args.root_path,
39 | data_path=args.data_path,
40 | flag=flag,
41 | size=[args.seq_len, args.label_len, args.pred_len],
42 | features=args.features,
43 | target=args.target,
44 | timeenc=timeenc,
45 | freq=freq,
46 | )
47 | print(flag, len(data_set))
48 | data_loader = DataLoader(
49 | data_set,
50 | batch_size=batch_size,
51 | shuffle=shuffle_flag,
52 | num_workers=args.num_workers,
53 | drop_last=drop_last)
54 | return data_set, data_loader
55 |
--------------------------------------------------------------------------------
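For reference, a minimal sketch of calling `data_provider` outside of `run.py`. Only the attribute names are taken from the file above; the concrete values (paths, lengths, batch size) are illustrative assumptions:

```python
from types import SimpleNamespace

from data_provider.data_factory import data_provider

# Every attribute read by data_provider() must exist on args.
args = SimpleNamespace(
    data='custom', root_path='./dataset/weather/', data_path='weather.csv',
    features='M', target='OT', embed='timeF', freq='h',
    seq_len=96, label_len=48, pred_len=96,
    batch_size=32, num_workers=4,
)

train_set, train_loader = data_provider(args, flag='train')
# Each batch is expected to yield (seq_x, seq_y, seq_x_mark, seq_y_mark);
# see data_provider/data_loader.py for the exact dataset outputs.
```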
/experiments/exp_basic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from model import Transformer, Informer, Reformer, Flowformer, Flashformer, \
4 | iTransformer, iInformer, iReformer, iFlowformer, iFlashformer
5 |
6 |
7 | class Exp_Basic(object):
8 | def __init__(self, args):
9 | self.args = args
10 | self.model_dict = {
11 | 'Transformer': Transformer,
12 | 'Informer': Informer,
13 | 'Reformer': Reformer,
14 | 'Flowformer': Flowformer,
15 | 'Flashformer': Flashformer,
16 | 'iTransformer': iTransformer,
17 | 'iInformer': iInformer,
18 | 'iReformer': iReformer,
19 | 'iFlowformer': iFlowformer,
20 | 'iFlashformer': iFlashformer,
21 | }
22 | self.device = self._acquire_device()
23 | self.model = self._build_model().to(self.device)
24 |
25 | def _build_model(self):
26 | raise NotImplementedError
27 | return None
28 |
29 | def _acquire_device(self):
30 | if self.args.use_gpu:
31 | os.environ["CUDA_VISIBLE_DEVICES"] = str(
32 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
33 | device = torch.device('cuda:{}'.format(self.args.gpu))
34 | print('Use GPU: cuda:{}'.format(self.args.gpu))
35 | else:
36 | device = torch.device('cpu')
37 | print('Use CPU')
38 | return device
39 |
40 | def _get_data(self):
41 | pass
42 |
43 | def vali(self):
44 | pass
45 |
46 | def train(self):
47 | pass
48 |
49 | def test(self):
50 | pass
51 |
--------------------------------------------------------------------------------
/figures/ablations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/ablations.png
--------------------------------------------------------------------------------
/figures/algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/algorithm.png
--------------------------------------------------------------------------------
/figures/analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/analysis.png
--------------------------------------------------------------------------------
/figures/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/architecture.png
--------------------------------------------------------------------------------
/figures/boosting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/boosting.png
--------------------------------------------------------------------------------
/figures/boosting_trm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/boosting_trm.png
--------------------------------------------------------------------------------
/figures/datasets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/datasets.png
--------------------------------------------------------------------------------
/figures/datasets_mtsf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/datasets_mtsf.png
--------------------------------------------------------------------------------
/figures/efficiency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/efficiency.png
--------------------------------------------------------------------------------
/figures/efficient.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/efficient.png
--------------------------------------------------------------------------------
/figures/formulations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/formulations.png
--------------------------------------------------------------------------------
/figures/generability.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/generability.png
--------------------------------------------------------------------------------
/figures/groups.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/groups.png
--------------------------------------------------------------------------------
/figures/increase_lookback.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/increase_lookback.png
--------------------------------------------------------------------------------
/figures/layernorm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/layernorm.png
--------------------------------------------------------------------------------
/figures/main_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/main_results.png
--------------------------------------------------------------------------------
/figures/main_results_alipay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/main_results_alipay.png
--------------------------------------------------------------------------------
/figures/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/motivation.png
--------------------------------------------------------------------------------
/figures/pi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/pi.png
--------------------------------------------------------------------------------
/figures/pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/pt.png
--------------------------------------------------------------------------------
/figures/radar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/figures/radar.png
--------------------------------------------------------------------------------
/layers/Embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class PositionalEmbedding(nn.Module):
7 | def __init__(self, d_model, max_len=5000):
8 | super(PositionalEmbedding, self).__init__()
9 | # Compute the positional encodings once in log space.
10 | pe = torch.zeros(max_len, d_model).float()
11 | pe.requires_grad = False
12 |
13 | position = torch.arange(0, max_len).float().unsqueeze(1)
14 | div_term = (torch.arange(0, d_model, 2).float()
15 | * -(math.log(10000.0) / d_model)).exp()
16 |
17 | pe[:, 0::2] = torch.sin(position * div_term)
18 | pe[:, 1::2] = torch.cos(position * div_term)
19 |
20 | pe = pe.unsqueeze(0)
21 | self.register_buffer('pe', pe)
22 |
23 | def forward(self, x):
24 | return self.pe[:, :x.size(1)]
25 |
26 |
27 | class TokenEmbedding(nn.Module):
28 | def __init__(self, c_in, d_model):
29 | super(TokenEmbedding, self).__init__()
30 | padding = 1 if torch.__version__ >= '1.5.0' else 2
31 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
32 | kernel_size=3, padding=padding, padding_mode='circular', bias=False)
33 | for m in self.modules():
34 | if isinstance(m, nn.Conv1d):
35 | nn.init.kaiming_normal_(
36 | m.weight, mode='fan_in', nonlinearity='leaky_relu')
37 |
38 | def forward(self, x):
39 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
40 | return x
41 |
42 |
43 | class FixedEmbedding(nn.Module):
44 | def __init__(self, c_in, d_model):
45 | super(FixedEmbedding, self).__init__()
46 |
47 | w = torch.zeros(c_in, d_model).float()
48 | w.requires_grad = False
49 |
50 | position = torch.arange(0, c_in).float().unsqueeze(1)
51 | div_term = (torch.arange(0, d_model, 2).float()
52 | * -(math.log(10000.0) / d_model)).exp()
53 |
54 | w[:, 0::2] = torch.sin(position * div_term)
55 | w[:, 1::2] = torch.cos(position * div_term)
56 |
57 | self.emb = nn.Embedding(c_in, d_model)
58 | self.emb.weight = nn.Parameter(w, requires_grad=False)
59 |
60 | def forward(self, x):
61 | return self.emb(x).detach()
62 |
63 |
64 | class TemporalEmbedding(nn.Module):
65 | def __init__(self, d_model, embed_type='fixed', freq='h'):
66 | super(TemporalEmbedding, self).__init__()
67 |
68 | minute_size = 4
69 | hour_size = 24
70 | weekday_size = 7
71 | day_size = 32
72 | month_size = 13
73 |
74 | Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
75 | if freq == 't':
76 | self.minute_embed = Embed(minute_size, d_model)
77 | self.hour_embed = Embed(hour_size, d_model)
78 | self.weekday_embed = Embed(weekday_size, d_model)
79 | self.day_embed = Embed(day_size, d_model)
80 | self.month_embed = Embed(month_size, d_model)
81 |
82 | def forward(self, x):
83 | x = x.long()
84 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr(
85 | self, 'minute_embed') else 0.
86 | hour_x = self.hour_embed(x[:, :, 3])
87 | weekday_x = self.weekday_embed(x[:, :, 2])
88 | day_x = self.day_embed(x[:, :, 1])
89 | month_x = self.month_embed(x[:, :, 0])
90 |
91 | return hour_x + weekday_x + day_x + month_x + minute_x
92 |
93 |
94 | class TimeFeatureEmbedding(nn.Module):
95 | def __init__(self, d_model, embed_type='timeF', freq='h'):
96 | super(TimeFeatureEmbedding, self).__init__()
97 |
98 | freq_map = {'h': 4, 't': 5, 's': 6,
99 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
100 | d_inp = freq_map[freq]
101 | self.embed = nn.Linear(d_inp, d_model, bias=False)
102 |
103 | def forward(self, x):
104 | return self.embed(x)
105 |
106 |
107 | class DataEmbedding(nn.Module):
108 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
109 | super(DataEmbedding, self).__init__()
110 |
111 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
112 | self.position_embedding = PositionalEmbedding(d_model=d_model)
113 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
114 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
115 | d_model=d_model, embed_type=embed_type, freq=freq)
116 | self.dropout = nn.Dropout(p=dropout)
117 |
118 | def forward(self, x, x_mark):
119 | if x_mark is None:
120 | x = self.value_embedding(x) + self.position_embedding(x)
121 | else:
122 | x = self.value_embedding(
123 | x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
124 | return self.dropout(x)
125 |
126 |
127 | class DataEmbedding_inverted(nn.Module):
128 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
129 | super(DataEmbedding_inverted, self).__init__()
130 | self.value_embedding = nn.Linear(c_in, d_model)
131 | self.dropout = nn.Dropout(p=dropout)
132 |
133 | def forward(self, x, x_mark):
134 | x = x.permute(0, 2, 1)
135 | # x: [Batch Variate Time]
136 | if x_mark is None:
137 | x = self.value_embedding(x)
138 | else:
139 | # the potential to take covariates (e.g. timestamps) as tokens
140 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
141 | # x: [Batch Variate d_model]
142 | return self.dropout(x)
143 |
144 |
--------------------------------------------------------------------------------
/layers/SelfAttention_Family.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from math import sqrt
5 | from utils.masking import TriangularCausalMask, ProbMask
6 | from reformer_pytorch import LSHSelfAttention
7 | from einops import rearrange
8 |
9 |
10 | # Code implementation from https://github.com/thuml/Flowformer
11 | class FlowAttention(nn.Module):
12 | def __init__(self, attention_dropout=0.1):
13 | super(FlowAttention, self).__init__()
14 | self.dropout = nn.Dropout(attention_dropout)
15 |
16 | def kernel_method(self, x):
17 | return torch.sigmoid(x)
18 |
19 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
20 | queries = queries.transpose(1, 2)
21 | keys = keys.transpose(1, 2)
22 | values = values.transpose(1, 2)
23 | # kernel
24 | queries = self.kernel_method(queries)
25 | keys = self.kernel_method(keys)
26 | # incoming and outgoing
27 | normalizer_row = 1.0 / (torch.einsum("nhld,nhd->nhl", queries + 1e-6, keys.sum(dim=2) + 1e-6))
28 | normalizer_col = 1.0 / (torch.einsum("nhsd,nhd->nhs", keys + 1e-6, queries.sum(dim=2) + 1e-6))
29 | # reweighting
30 | normalizer_row_refine = (
31 | torch.einsum("nhld,nhd->nhl", queries + 1e-6, (keys * normalizer_col[:, :, :, None]).sum(dim=2) + 1e-6))
32 | normalizer_col_refine = (
33 | torch.einsum("nhsd,nhd->nhs", keys + 1e-6, (queries * normalizer_row[:, :, :, None]).sum(dim=2) + 1e-6))
34 | # competition and allocation
35 | normalizer_row_refine = torch.sigmoid(
36 | normalizer_row_refine * (float(queries.shape[2]) / float(keys.shape[2])))
37 | normalizer_col_refine = torch.softmax(normalizer_col_refine, dim=-1) * keys.shape[2] # B h L vis
38 | # multiply
39 | kv = keys.transpose(-2, -1) @ (values * normalizer_col_refine[:, :, :, None])
40 | x = (((queries @ kv) * normalizer_row[:, :, :, None]) * normalizer_row_refine[:, :, :, None]).transpose(1,
41 | 2).contiguous()
42 | return x, None
43 |
44 |
45 | # Code implementation from https://github.com/shreyansh26/FlashAttention-PyTorch
46 | class FlashAttention(nn.Module):
47 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
48 | super(FlashAttention, self).__init__()
49 | self.scale = scale
50 | self.mask_flag = mask_flag
51 | self.output_attention = output_attention
52 | self.dropout = nn.Dropout(attention_dropout)
53 |
54 | def flash_attention_forward(self, Q, K, V, mask=None):
55 | BLOCK_SIZE = 32
56 | NEG_INF = -1e10 # -infinity
57 | EPSILON = 1e-10
58 | # mask = torch.randint(0, 2, (128, 8)).to(device='cuda')
59 | O = torch.zeros_like(Q, requires_grad=True)
60 | l = torch.zeros(Q.shape[:-1])[..., None]
61 | m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF
62 |
63 | O = O.to(device='cuda')
64 | l = l.to(device='cuda')
65 | m = m.to(device='cuda')
66 |
67 | Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1])
68 | KV_BLOCK_SIZE = BLOCK_SIZE
69 |
70 | Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
71 | K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
72 | V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
73 | if mask is not None:
74 | mask_BLOCKS = list(torch.split(mask, KV_BLOCK_SIZE, dim=1))
75 |
76 | Tr = len(Q_BLOCKS)
77 | Tc = len(K_BLOCKS)
78 |
79 | O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
80 | l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
81 | m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))
82 |
83 | for j in range(Tc):
84 | Kj = K_BLOCKS[j]
85 | Vj = V_BLOCKS[j]
86 | if mask is not None:
87 | maskj = mask_BLOCKS[j]
88 |
89 | for i in range(Tr):
90 | Qi = Q_BLOCKS[i]
91 | Oi = O_BLOCKS[i]
92 | li = l_BLOCKS[i]
93 | mi = m_BLOCKS[i]
94 |
95 | scale = 1 / np.sqrt(Q.shape[-1])
96 | Qi_scaled = Qi * scale
97 |
98 | S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj)
99 | if mask is not None:
100 | # Masking
101 | maskj_temp = rearrange(maskj, 'b j -> b 1 1 j')
102 | S_ij = torch.where(maskj_temp > 0, S_ij, NEG_INF)
103 |
104 | m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
105 | P_ij = torch.exp(S_ij - m_block_ij)
106 | if mask is not None:
107 | # Masking
108 | P_ij = torch.where(maskj_temp > 0, P_ij, 0.)
109 |
110 | l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
111 |
112 | P_ij_Vj = torch.einsum('... i j, ... j d -> ... i d', P_ij, Vj)
113 |
114 | mi_new = torch.maximum(m_block_ij, mi)
115 | li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij
116 |
117 | O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi + (
118 | torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
119 | l_BLOCKS[i] = li_new
120 | m_BLOCKS[i] = mi_new
121 |
122 | O = torch.cat(O_BLOCKS, dim=2)
123 | l = torch.cat(l_BLOCKS, dim=2)
124 | m = torch.cat(m_BLOCKS, dim=2)
125 | return O, l, m
126 |
127 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
128 | res = \
129 | self.flash_attention_forward(queries.permute(0, 2, 1, 3), keys.permute(0, 2, 1, 3), values.permute(0, 2, 1, 3),
130 | attn_mask)[0]
131 | return res.permute(0, 2, 1, 3).contiguous(), None
132 |
133 |
134 | class FullAttention(nn.Module):
135 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
136 | super(FullAttention, self).__init__()
137 | self.scale = scale
138 | self.mask_flag = mask_flag
139 | self.output_attention = output_attention
140 | self.dropout = nn.Dropout(attention_dropout)
141 |
142 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
143 | B, L, H, E = queries.shape
144 | _, S, _, D = values.shape
145 | scale = self.scale or 1. / sqrt(E)
146 |
147 | scores = torch.einsum("blhe,bshe->bhls", queries, keys)
148 |
149 | if self.mask_flag:
150 | if attn_mask is None:
151 | attn_mask = TriangularCausalMask(B, L, device=queries.device)
152 |
153 | scores.masked_fill_(attn_mask.mask, -np.inf)
154 |
155 | A = self.dropout(torch.softmax(scale * scores, dim=-1))
156 | V = torch.einsum("bhls,bshd->blhd", A, values)
157 |
158 | if self.output_attention:
159 | return (V.contiguous(), A)
160 | else:
161 | return (V.contiguous(), None)
162 |
163 |
164 | # Code implementation from https://github.com/zhouhaoyi/Informer2020
165 | class ProbAttention(nn.Module):
166 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
167 | super(ProbAttention, self).__init__()
168 | self.factor = factor
169 | self.scale = scale
170 | self.mask_flag = mask_flag
171 | self.output_attention = output_attention
172 | self.dropout = nn.Dropout(attention_dropout)
173 |
174 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
175 | # Q [B, H, L, D]
176 | B, H, L_K, E = K.shape
177 | _, _, L_Q, _ = Q.shape
178 |
179 | # calculate the sampled Q_K
180 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
181 | # real U = U_part(factor*ln(L_k))*L_q
182 | index_sample = torch.randint(L_K, (L_Q, sample_k))
183 | K_sample = K_expand[:, :, torch.arange(
184 | L_Q).unsqueeze(1), index_sample, :]
185 | Q_K_sample = torch.matmul(
186 | Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
187 |
188 | # find the Top_k query with sparsity measurement
189 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
190 | M_top = M.topk(n_top, sorted=False)[1]
191 |
192 | # use the reduced Q to calculate Q_K
193 | Q_reduce = Q[torch.arange(B)[:, None, None],
194 | torch.arange(H)[None, :, None],
195 | M_top, :] # factor*ln(L_q)
196 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
197 |
198 | return Q_K, M_top
199 |
200 | def _get_initial_context(self, V, L_Q):
201 | B, H, L_V, D = V.shape
202 | if not self.mask_flag:
203 | # V_sum = V.sum(dim=-2)
204 | V_sum = V.mean(dim=-2)
205 | contex = V_sum.unsqueeze(-2).expand(B, H,
206 | L_Q, V_sum.shape[-1]).clone()
207 | else: # use mask
208 | # requires that L_Q == L_V, i.e. for self-attention only
209 | assert (L_Q == L_V)
210 | contex = V.cumsum(dim=-2)
211 | return contex
212 |
213 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
214 | B, H, L_V, D = V.shape
215 |
216 | if self.mask_flag:
217 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
218 | scores.masked_fill_(attn_mask.mask, -np.inf)
219 |
220 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
221 |
222 | context_in[torch.arange(B)[:, None, None],
223 | torch.arange(H)[None, :, None],
224 | index, :] = torch.matmul(attn, V).type_as(context_in)
225 | if self.output_attention:
226 | attns = (torch.ones([B, H, L_V, L_V]) /
227 | L_V).type_as(attn).to(attn.device)
228 | attns[torch.arange(B)[:, None, None], torch.arange(H)[
229 | None, :, None], index, :] = attn
230 | return (context_in, attns)
231 | else:
232 | return (context_in, None)
233 |
234 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
235 | B, L_Q, H, D = queries.shape
236 | _, L_K, _, _ = keys.shape
237 |
238 | queries = queries.transpose(2, 1)
239 | keys = keys.transpose(2, 1)
240 | values = values.transpose(2, 1)
241 |
242 | U_part = self.factor * \
243 | np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
244 | u = self.factor * \
245 | np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
246 |
247 | U_part = U_part if U_part < L_K else L_K
248 | u = u if u < L_Q else L_Q
249 |
250 | scores_top, index = self._prob_QK(
251 | queries, keys, sample_k=U_part, n_top=u)
252 |
253 | # add scale factor
254 | scale = self.scale or 1. / sqrt(D)
255 | if scale is not None:
256 | scores_top = scores_top * scale
257 | # get the context
258 | context = self._get_initial_context(values, L_Q)
259 | # update the context with selected top_k queries
260 | context, attn = self._update_context(
261 | context, values, scores_top, index, L_Q, attn_mask)
262 |
263 | return context.contiguous(), attn
264 |
265 |
266 | class AttentionLayer(nn.Module):
267 | def __init__(self, attention, d_model, n_heads, d_keys=None,
268 | d_values=None):
269 | super(AttentionLayer, self).__init__()
270 |
271 | d_keys = d_keys or (d_model // n_heads)
272 | d_values = d_values or (d_model // n_heads)
273 |
274 | self.inner_attention = attention
275 | self.query_projection = nn.Linear(d_model, d_keys * n_heads)
276 | self.key_projection = nn.Linear(d_model, d_keys * n_heads)
277 | self.value_projection = nn.Linear(d_model, d_values * n_heads)
278 | self.out_projection = nn.Linear(d_values * n_heads, d_model)
279 | self.n_heads = n_heads
280 |
281 | def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
282 | B, L, _ = queries.shape
283 | _, S, _ = keys.shape
284 | H = self.n_heads
285 |
286 | queries = self.query_projection(queries).view(B, L, H, -1)
287 | keys = self.key_projection(keys).view(B, S, H, -1)
288 | values = self.value_projection(values).view(B, S, H, -1)
289 |
290 | out, attn = self.inner_attention(
291 | queries,
292 | keys,
293 | values,
294 | attn_mask,
295 | tau=tau,
296 | delta=delta
297 | )
298 | out = out.view(B, L, -1)
299 |
300 | return self.out_projection(out), attn
301 |
302 |
303 | class ReformerLayer(nn.Module):
304 | def __init__(self, attention, d_model, n_heads, d_keys=None,
305 | d_values=None, causal=False, bucket_size=4, n_hashes=4):
306 | super().__init__()
307 | self.bucket_size = bucket_size
308 | self.attn = LSHSelfAttention(
309 | dim=d_model,
310 | heads=n_heads,
311 | bucket_size=bucket_size,
312 | n_hashes=n_hashes,
313 | causal=causal
314 | )
315 |
316 | def fit_length(self, queries):
317 | # inside reformer: assert N % (bucket_size * 2) == 0
318 | B, N, C = queries.shape
319 | if N % (self.bucket_size * 2) == 0:
320 | return queries
321 | else:
322 | # fill the time series
323 | fill_len = (self.bucket_size * 2) - (N % (self.bucket_size * 2))
324 | return torch.cat([queries, torch.zeros([B, fill_len, C]).to(queries.device)], dim=1)
325 |
326 | def forward(self, queries, keys, values, attn_mask, tau, delta):
327 | # in Reformer: default queries=keys
328 | B, N, C = queries.shape
329 | queries = self.attn(self.fit_length(queries))[:, :N, :]
330 | return queries, None
331 |
332 |
--------------------------------------------------------------------------------
/layers/Transformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class ConvLayer(nn.Module):
6 | def __init__(self, c_in):
7 | super(ConvLayer, self).__init__()
8 | self.downConv = nn.Conv1d(in_channels=c_in,
9 | out_channels=c_in,
10 | kernel_size=3,
11 | padding=2,
12 | padding_mode='circular')
13 | self.norm = nn.BatchNorm1d(c_in)
14 | self.activation = nn.ELU()
15 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
16 |
17 | def forward(self, x):
18 | x = self.downConv(x.permute(0, 2, 1))
19 | x = self.norm(x)
20 | x = self.activation(x)
21 | x = self.maxPool(x)
22 | x = x.transpose(1, 2)
23 | return x
24 |
25 |
26 | class EncoderLayer(nn.Module):
27 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
28 | super(EncoderLayer, self).__init__()
29 | d_ff = d_ff or 4 * d_model
30 | self.attention = attention
31 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
32 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
33 | self.norm1 = nn.LayerNorm(d_model)
34 | self.norm2 = nn.LayerNorm(d_model)
35 | self.dropout = nn.Dropout(dropout)
36 | self.activation = F.relu if activation == "relu" else F.gelu
37 |
38 | def forward(self, x, attn_mask=None, tau=None, delta=None):
39 | new_x, attn = self.attention(
40 | x, x, x,
41 | attn_mask=attn_mask,
42 | tau=tau, delta=delta
43 | )
44 | x = x + self.dropout(new_x)
45 |
46 | y = x = self.norm1(x)
47 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
48 | y = self.dropout(self.conv2(y).transpose(-1, 1))
49 |
50 | return self.norm2(x + y), attn
51 |
52 |
53 | class Encoder(nn.Module):
54 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
55 | super(Encoder, self).__init__()
56 | self.attn_layers = nn.ModuleList(attn_layers)
57 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
58 | self.norm = norm_layer
59 |
60 | def forward(self, x, attn_mask=None, tau=None, delta=None):
61 | # x [B, L, D]
62 | attns = []
63 | if self.conv_layers is not None:
64 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):
65 | delta = delta if i == 0 else None
66 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
67 | x = conv_layer(x)
68 | attns.append(attn)
69 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
70 | attns.append(attn)
71 | else:
72 | for attn_layer in self.attn_layers:
73 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
74 | attns.append(attn)
75 |
76 | if self.norm is not None:
77 | x = self.norm(x)
78 |
79 | return x, attns
80 |
81 |
82 | class DecoderLayer(nn.Module):
83 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
84 | dropout=0.1, activation="relu"):
85 | super(DecoderLayer, self).__init__()
86 | d_ff = d_ff or 4 * d_model
87 | self.self_attention = self_attention
88 | self.cross_attention = cross_attention
89 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
90 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
91 | self.norm1 = nn.LayerNorm(d_model)
92 | self.norm2 = nn.LayerNorm(d_model)
93 | self.norm3 = nn.LayerNorm(d_model)
94 | self.dropout = nn.Dropout(dropout)
95 | self.activation = F.relu if activation == "relu" else F.gelu
96 |
97 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
98 | x = x + self.dropout(self.self_attention(
99 | x, x, x,
100 | attn_mask=x_mask,
101 | tau=tau, delta=None
102 | )[0])
103 | x = self.norm1(x)
104 |
105 | x = x + self.dropout(self.cross_attention(
106 | x, cross, cross,
107 | attn_mask=cross_mask,
108 | tau=tau, delta=delta
109 | )[0])
110 |
111 | y = x = self.norm2(x)
112 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
113 | y = self.dropout(self.conv2(y).transpose(-1, 1))
114 |
115 | return self.norm3(x + y)
116 |
117 |
118 | class Decoder(nn.Module):
119 | def __init__(self, layers, norm_layer=None, projection=None):
120 | super(Decoder, self).__init__()
121 | self.layers = nn.ModuleList(layers)
122 | self.norm = norm_layer
123 | self.projection = projection
124 |
125 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
126 | for layer in self.layers:
127 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta)
128 |
129 | if self.norm is not None:
130 | x = self.norm(x)
131 |
132 | if self.projection is not None:
133 | x = self.projection(x)
134 | return x
135 |
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/layers/__init__.py
--------------------------------------------------------------------------------
/model/Flashformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
5 | from layers.SelfAttention_Family import FlashAttention, AttentionLayer, FullAttention
6 | from layers.Embed import DataEmbedding
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 | Transformer with FlashAttention
13 | (tiled, IO-aware computation of exact attention; compute remains O(L^2))
14 | Code reference: https://github.com/shreyansh26/FlashAttention-PyTorch
15 | """
16 |
17 | def __init__(self, configs):
18 | super(Model, self).__init__()
19 | self.pred_len = configs.pred_len
20 | self.output_attention = configs.output_attention
21 | # Embedding
22 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
23 | configs.dropout)
24 | # Encoder
25 | self.encoder = Encoder(
26 | [
27 | EncoderLayer(
28 | AttentionLayer(
29 | FlashAttention(False, configs.factor, attention_dropout=configs.dropout,
30 | output_attention=configs.output_attention), configs.d_model, configs.n_heads),
31 | configs.d_model,
32 | configs.d_ff,
33 | dropout=configs.dropout,
34 | activation=configs.activation
35 | ) for l in range(configs.e_layers)
36 | ],
37 | norm_layer=torch.nn.LayerNorm(configs.d_model)
38 | )
39 | # Decoder
40 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
41 | configs.dropout)
42 | self.decoder = Decoder(
43 | [
44 | DecoderLayer(
45 | AttentionLayer(
46 | FullAttention(True, configs.factor, attention_dropout=configs.dropout,
47 | output_attention=False),
48 | configs.d_model, configs.n_heads),
49 | AttentionLayer(
50 | FullAttention(False, configs.factor, attention_dropout=configs.dropout,
51 | output_attention=False),
52 | configs.d_model, configs.n_heads),
53 | configs.d_model,
54 | configs.d_ff,
55 | dropout=configs.dropout,
56 | activation=configs.activation,
57 | )
58 | for l in range(configs.d_layers)
59 | ],
60 | norm_layer=torch.nn.LayerNorm(configs.d_model),
61 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
62 | )
63 |
64 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
65 | # Embedding
66 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
67 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
68 |
69 | dec_out = self.dec_embedding(x_dec, x_mark_dec)
70 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
71 | return dec_out
72 |
73 |
74 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
75 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
76 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
77 |
--------------------------------------------------------------------------------
/model/Flowformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer, FlowAttention
6 | from layers.Embed import DataEmbedding
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 | Transformer with Flow-Attention
13 | (linear-complexity attention)
14 | Code reference: https://github.com/thuml/Flowformer
15 | """
16 |
17 | def __init__(self, configs):
18 | super(Model, self).__init__()
19 | self.pred_len = configs.pred_len
20 | self.output_attention = configs.output_attention
21 |
22 | if configs.channel_independence:
23 | self.enc_in = 1
24 | self.dec_in = 1
25 | self.c_out = 1
26 | else:
27 | self.enc_in = configs.enc_in
28 | self.dec_in = configs.dec_in
29 | self.c_out = configs.c_out
30 |
31 | # Embedding
32 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq,
33 | configs.dropout)
34 | # Encoder
35 | self.encoder = Encoder(
36 | [
37 | EncoderLayer(
38 | AttentionLayer(
39 | FlowAttention(attention_dropout=configs.dropout), configs.d_model, configs.n_heads),
40 | configs.d_model,
41 | configs.d_ff,
42 | dropout=configs.dropout,
43 | activation=configs.activation
44 | ) for l in range(configs.e_layers)
45 | ],
46 | norm_layer=torch.nn.LayerNorm(configs.d_model)
47 | )
48 | # Decoder
49 | self.dec_embedding = DataEmbedding(self.dec_in, configs.d_model, configs.embed, configs.freq,
50 | configs.dropout)
51 | self.decoder = Decoder(
52 | [
53 | DecoderLayer(
54 | AttentionLayer(
55 | FullAttention(True, configs.factor, attention_dropout=configs.dropout,
56 | output_attention=False),
57 | configs.d_model, configs.n_heads),
58 | AttentionLayer(
59 | FullAttention(False, configs.factor, attention_dropout=configs.dropout,
60 | output_attention=False),
61 | configs.d_model, configs.n_heads),
62 | configs.d_model,
63 | configs.d_ff,
64 | dropout=configs.dropout,
65 | activation=configs.activation,
66 | )
67 | for l in range(configs.d_layers)
68 | ],
69 | norm_layer=torch.nn.LayerNorm(configs.d_model),
70 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
71 | )
72 |
73 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
74 | # Embedding
75 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
76 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
77 |
78 | dec_out = self.dec_embedding(x_dec, x_mark_dec)
79 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
80 | return dec_out
81 |
82 |
83 |
84 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
85 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
86 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
87 |
--------------------------------------------------------------------------------
/model/Informer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
5 | from layers.SelfAttention_Family import ProbAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding
7 |
8 |
9 | class Model(nn.Module):
10 | """
11 | Informer with ProbSparse attention in O(L log L) complexity
12 | Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132
13 | """
14 |
15 | def __init__(self, configs):
16 | super(Model, self).__init__()
17 |
18 | self.pred_len = configs.pred_len
19 | self.label_len = configs.label_len
20 |
21 | if configs.channel_independence:
22 | self.enc_in = 1
23 | self.dec_in = 1
24 | self.c_out = 1
25 | else:
26 | self.enc_in = configs.enc_in
27 | self.dec_in = configs.dec_in
28 | self.c_out = configs.c_out
29 |
30 | # Embedding
31 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq,
32 | configs.dropout)
33 | self.dec_embedding = DataEmbedding(self.dec_in, configs.d_model, configs.embed, configs.freq,
34 | configs.dropout)
35 |
36 | # Encoder
37 | self.encoder = Encoder(
38 | [
39 | EncoderLayer(
40 | AttentionLayer(
41 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
42 | output_attention=configs.output_attention),
43 | configs.d_model, configs.n_heads),
44 | configs.d_model,
45 | configs.d_ff,
46 | dropout=configs.dropout,
47 | activation=configs.activation
48 | ) for l in range(configs.e_layers)
49 | ],
50 | [
51 | ConvLayer(
52 | configs.d_model
53 | ) for l in range(configs.e_layers - 1)
54 | ] if configs.distil else None,
55 | norm_layer=torch.nn.LayerNorm(configs.d_model)
56 | )
57 | # Decoder
58 | self.decoder = Decoder(
59 | [
60 | DecoderLayer(
61 | AttentionLayer(
62 | ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False),
63 | configs.d_model, configs.n_heads),
64 | AttentionLayer(
65 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
66 | configs.d_model, configs.n_heads),
67 | configs.d_model,
68 | configs.d_ff,
69 | dropout=configs.dropout,
70 | activation=configs.activation,
71 | )
72 | for l in range(configs.d_layers)
73 | ],
74 | norm_layer=torch.nn.LayerNorm(configs.d_model),
75 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
76 | )
77 |
78 |
79 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
80 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
81 | dec_out = self.dec_embedding(x_dec, x_mark_dec)
82 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
83 |
84 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
85 |
86 | return dec_out # [B, L, D]
87 |
88 |
89 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
90 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
91 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
92 |
--------------------------------------------------------------------------------
/model/Reformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Encoder, EncoderLayer
5 | from layers.SelfAttention_Family import ReformerLayer
6 | from layers.Embed import DataEmbedding
7 |
8 |
9 | class Model(nn.Module):
10 | """
11 | Reformer with O(L log L) complexity
12 | Paper link: https://openreview.net/forum?id=rkgNKkHtvB
13 | """
14 |
15 | def __init__(self, configs, bucket_size=4, n_hashes=4):
16 | """
17 | bucket_size: int,
18 | n_hashes: int,
19 | """
20 | super(Model, self).__init__()
21 | self.pred_len = configs.pred_len
22 | self.seq_len = configs.seq_len
23 |
24 | if configs.channel_independence:
25 | self.enc_in = 1
26 | self.dec_in = 1
27 | self.c_out = 1
28 | else:
29 | self.enc_in = configs.enc_in
30 | self.dec_in = configs.dec_in
31 | self.c_out = configs.c_out
32 |
33 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq,
34 | configs.dropout)
35 | # Encoder
36 | self.encoder = Encoder(
37 | [
38 | EncoderLayer(
39 | ReformerLayer(None, configs.d_model, configs.n_heads,
40 | bucket_size=bucket_size, n_hashes=n_hashes),
41 | configs.d_model,
42 | configs.d_ff,
43 | dropout=configs.dropout,
44 | activation=configs.activation
45 | ) for l in range(configs.e_layers)
46 | ],
47 | norm_layer=torch.nn.LayerNorm(configs.d_model)
48 | )
49 |
50 | self.projection = nn.Linear(
51 | configs.d_model, configs.c_out, bias=True)
52 |
53 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
54 | # add placeholder
55 | x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1)
56 | if x_mark_enc is not None:
57 | x_mark_enc = torch.cat(
58 | [x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1)
59 |
60 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
61 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
62 | dec_out = self.projection(enc_out)
63 |
64 | return dec_out # [B, L, D]
65 |
66 |
67 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
68 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
69 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
70 |
--------------------------------------------------------------------------------
/model/Transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 | Vanilla Transformer
13 | with O(L^2) complexity
14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
15 | """
16 |
17 | def __init__(self, configs):
18 | super(Model, self).__init__()
19 | self.pred_len = configs.pred_len
20 | self.output_attention = configs.output_attention
21 |
22 | if configs.channel_independence:
23 | self.enc_in = 1
24 | self.dec_in = 1
25 | self.c_out = 1
26 | else:
27 | self.enc_in = configs.enc_in
28 | self.dec_in = configs.dec_in
29 | self.c_out = configs.c_out
30 |
31 | # Embedding
32 | self.enc_embedding = DataEmbedding(self.enc_in, configs.d_model, configs.embed, configs.freq,
33 | configs.dropout)
34 | # Encoder
35 | self.encoder = Encoder(
36 | [
37 | EncoderLayer(
38 | AttentionLayer(
39 | FullAttention(False, configs.factor, attention_dropout=configs.dropout,
40 | output_attention=configs.output_attention), configs.d_model, configs.n_heads),
41 | configs.d_model,
42 | configs.d_ff,
43 | dropout=configs.dropout,
44 | activation=configs.activation
45 | ) for l in range(configs.e_layers)
46 | ],
47 | norm_layer=torch.nn.LayerNorm(configs.d_model)
48 | )
49 | # Decoder
50 | self.dec_embedding = DataEmbedding(self.dec_in, configs.d_model, configs.embed, configs.freq,
51 | configs.dropout)
52 | self.decoder = Decoder(
53 | [
54 | DecoderLayer(
55 | AttentionLayer(
56 | FullAttention(True, configs.factor, attention_dropout=configs.dropout,
57 | output_attention=False),
58 | configs.d_model, configs.n_heads),
59 | AttentionLayer(
60 | FullAttention(False, configs.factor, attention_dropout=configs.dropout,
61 | output_attention=False),
62 | configs.d_model, configs.n_heads),
63 | configs.d_model,
64 | configs.d_ff,
65 | dropout=configs.dropout,
66 | activation=configs.activation,
67 | )
68 | for l in range(configs.d_layers)
69 | ],
70 | norm_layer=torch.nn.LayerNorm(configs.d_model),
71 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
72 | )
73 |
74 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
75 | # Embedding
76 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
77 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
78 |
79 | dec_out = self.dec_embedding(x_dec, x_mark_dec)
80 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None)
81 | return dec_out
82 |
83 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
84 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
85 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
86 |
--------------------------------------------------------------------------------
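
All of the forecasters above receive an x_dec / x_mark_dec pair: the encoder-decoder models (Informer, Transformer) feed it to their decoders, while Reformer appends its last pred_len steps to the encoder input as a placeholder. In this codebase family the decoder input is conventionally built from label_len known steps followed by zeros for the horizon (see the --label_len flag in run.py further below). The sketch below illustrates that convention; the actual construction lives in the experiments package, which is not shown in this excerpt, so treat the details as an assumption:

```python
import torch

B, seq_len, label_len, pred_len, N = 8, 96, 48, 96, 7
batch_x = torch.randn(B, seq_len, N)               # lookback window fed to the encoder
batch_y = torch.randn(B, label_len + pred_len, N)  # start tokens + ground-truth horizon

# Decoder input: the known label_len steps, then zeros where the forecast goes
dec_inp = torch.zeros(B, pred_len, N)
dec_inp = torch.cat([batch_y[:, :label_len, :], dec_inp], dim=1)
print(dec_inp.shape)                               # torch.Size([8, 144, 7])
```
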
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/model/__init__.py
--------------------------------------------------------------------------------
/model/iFlashformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Encoder, EncoderLayer
5 | from layers.SelfAttention_Family import FlashAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding_inverted
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 |     Inverted Transformer (iTransformer) equipped with FlashAttention
13 |     iTransformer paper link: https://arxiv.org/abs/2310.06625
14 |     FlashAttention paper link: https://arxiv.org/abs/2205.14135
15 | """
16 |
17 | def __init__(self, configs):
18 | super(Model, self).__init__()
19 | self.seq_len = configs.seq_len
20 | self.pred_len = configs.pred_len
21 | self.output_attention = configs.output_attention
22 | # Embedding
23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
24 | configs.dropout)
25 | # Encoder-only architecture
26 | self.encoder = Encoder(
27 | [
28 | EncoderLayer(
29 | AttentionLayer(
30 | FlashAttention(False, configs.factor, attention_dropout=configs.dropout,
31 | output_attention=configs.output_attention), configs.d_model, configs.n_heads),
32 | configs.d_model,
33 | configs.d_ff,
34 | dropout=configs.dropout,
35 | activation=configs.activation
36 | ) for l in range(configs.e_layers)
37 | ],
38 | norm_layer=torch.nn.LayerNorm(configs.d_model)
39 | )
40 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)
41 |
42 |
43 |
44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
45 | # Normalization from Non-stationary Transformer
46 | means = x_enc.mean(1, keepdim=True).detach()
47 | x_enc = x_enc - means
48 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
49 | x_enc /= stdev
50 |
51 | _, _, N = x_enc.shape
52 |
53 | # Embedding
54 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
55 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
56 |
57 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]
58 | # De-Normalization from Non-stationary Transformer
59 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
60 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
61 | return dec_out, attns
62 |
63 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
64 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
65 |
66 | if self.output_attention:
67 | return dec_out[:, -self.pred_len:, :], attns
68 | else:
69 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
70 |
--------------------------------------------------------------------------------
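
The series-wise normalization and de-normalization in forecast above (and in the other inverted models below) follows the Non-stationary Transformer trick: each variate is standardized over its lookback window, the model forecasts in the normalized space, and the stored statistics are re-applied to the output. A minimal standalone sketch of just this step (not a repository file; tensor sizes are illustrative):

```python
import torch

B, L, S, N = 4, 96, 96, 7                        # batch, lookback, horizon, variates
x_enc = torch.randn(B, L, N)

# Normalize each variate over its own lookback window
means = x_enc.mean(1, keepdim=True).detach()      # [B, 1, N]
x_norm = x_enc - means
stdev = torch.sqrt(torch.var(x_norm, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_norm = x_norm / stdev

# Stand-in for the model's normalized forecast
pred_norm = torch.randn(B, S, N)                  # [B, S, N]

# De-normalize: broadcast the per-variate statistics over the horizon
pred = pred_norm * stdev[:, 0, :].unsqueeze(1).repeat(1, S, 1) \
       + means[:, 0, :].unsqueeze(1).repeat(1, S, 1)
print(pred.shape)                                 # torch.Size([4, 96, 7])
```
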
/model/iFlowformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Encoder, EncoderLayer
5 | from layers.SelfAttention_Family import FlowAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding_inverted
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 |     Inverted Transformer (iTransformer) equipped with the linear-complexity FlowAttention
13 |     iTransformer paper link: https://arxiv.org/abs/2310.06625
14 |     Flowformer paper link: https://arxiv.org/abs/2202.06258
15 | """
16 |
17 | def __init__(self, configs):
18 | super(Model, self).__init__()
19 | self.seq_len = configs.seq_len
20 | self.pred_len = configs.pred_len
21 | self.output_attention = configs.output_attention
22 | # Embedding
23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
24 | configs.dropout)
25 | # Encoder-only architecture
26 | self.encoder = Encoder(
27 | [
28 | EncoderLayer(
29 | AttentionLayer(
30 | FlowAttention(attention_dropout=configs.dropout), configs.d_model, configs.n_heads),
31 | configs.d_model,
32 | configs.d_ff,
33 | dropout=configs.dropout,
34 | activation=configs.activation
35 | ) for l in range(configs.e_layers)
36 | ],
37 | norm_layer=torch.nn.LayerNorm(configs.d_model)
38 | )
39 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)
40 |
41 |
42 |
43 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
44 | # Normalization from Non-stationary Transformer
45 | means = x_enc.mean(1, keepdim=True).detach()
46 | x_enc = x_enc - means
47 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
48 | x_enc /= stdev
49 |
50 | _, _, N = x_enc.shape
51 |
52 | # Embedding
53 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
54 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
55 |
56 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]
57 | # De-Normalization from Non-stationary Transformer
58 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
59 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
60 | return dec_out, attns
61 |
62 |
63 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
64 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
65 |
66 | if self.output_attention:
67 | return dec_out[:, -self.pred_len:, :], attns
68 | else:
69 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
70 |
--------------------------------------------------------------------------------
/model/iInformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Encoder, EncoderLayer
5 | from layers.SelfAttention_Family import ProbAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding_inverted
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 |     Inverted Transformer (iTransformer) equipped with the ProbSparse attention of Informer
13 |     iTransformer paper link: https://arxiv.org/abs/2310.06625
14 |     Informer paper link: https://arxiv.org/abs/2012.07436
15 | """
16 |
17 | def __init__(self, configs):
18 | super(Model, self).__init__()
19 | self.seq_len = configs.seq_len
20 | self.pred_len = configs.pred_len
21 | self.output_attention = configs.output_attention
22 | # Embedding
23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
24 | configs.dropout)
25 | # Encoder-only architecture
26 | self.encoder = Encoder(
27 | [
28 | EncoderLayer(
29 | AttentionLayer(
30 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
31 | output_attention=configs.output_attention), configs.d_model, configs.n_heads),
32 | configs.d_model,
33 | configs.d_ff,
34 | dropout=configs.dropout,
35 | activation=configs.activation
36 | ) for l in range(configs.e_layers)
37 | ],
38 | norm_layer=torch.nn.LayerNorm(configs.d_model)
39 | )
40 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)
41 |
42 |
43 |
44 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
45 | # Normalization from Non-stationary Transformer
46 | means = x_enc.mean(1, keepdim=True).detach()
47 | x_enc = x_enc - means
48 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
49 | x_enc /= stdev
50 |
51 | _, _, N = x_enc.shape
52 |
53 | # Embedding
54 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
55 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
56 |
57 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]
58 | # De-Normalization from Non-stationary Transformer
59 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
60 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
61 | return dec_out, attns
62 |
63 |
64 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
65 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
66 |
67 | if self.output_attention:
68 | return dec_out[:, -self.pred_len:, :], attns
69 | else:
70 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
71 |
--------------------------------------------------------------------------------
/model/iReformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Encoder, EncoderLayer
5 | from layers.SelfAttention_Family import ReformerLayer
6 | from layers.Embed import DataEmbedding_inverted
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 |     Inverted Transformer (iTransformer) equipped with the LSH attention of Reformer
13 |     iTransformer paper link: https://arxiv.org/abs/2310.06625
14 |     Reformer paper link: https://openreview.net/forum?id=rkgNKkHtvB
15 | """
16 |
17 | def __init__(self, configs):
18 | super(Model, self).__init__()
19 | self.seq_len = configs.seq_len
20 | self.pred_len = configs.pred_len
21 | self.output_attention = configs.output_attention
22 | # Embedding
23 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
24 | configs.dropout)
25 | # Encoder-only architecture
26 | self.encoder = Encoder(
27 | [
28 | EncoderLayer(
29 | ReformerLayer(None, configs.d_model, configs.n_heads,
30 | bucket_size=4, n_hashes=4),
31 | configs.d_model,
32 | configs.d_ff,
33 | dropout=configs.dropout,
34 | activation=configs.activation
35 | ) for l in range(configs.e_layers)
36 | ],
37 | norm_layer=torch.nn.LayerNorm(configs.d_model)
38 | )
39 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)
40 |
41 |
42 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
43 | # Normalization from Non-stationary Transformer
44 | means = x_enc.mean(1, keepdim=True).detach()
45 | x_enc = x_enc - means
46 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
47 | x_enc /= stdev
48 |
49 | _, _, N = x_enc.shape
50 |
51 | # Embedding
52 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
53 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
54 |
55 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]
56 | # De-Normalization from Non-stationary Transformer
57 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
58 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
59 | return dec_out, attns
60 |
61 |
62 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
63 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
64 |
65 | if self.output_attention:
66 | return dec_out[:, -self.pred_len:, :], attns
67 | else:
68 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
--------------------------------------------------------------------------------
/model/iTransformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Encoder, EncoderLayer
5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding_inverted
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 | Paper link: https://arxiv.org/abs/2310.06625
13 | """
14 |
15 | def __init__(self, configs):
16 | super(Model, self).__init__()
17 | self.seq_len = configs.seq_len
18 | self.pred_len = configs.pred_len
19 | self.output_attention = configs.output_attention
20 | self.use_norm = configs.use_norm
21 | # Embedding
22 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
23 | configs.dropout)
24 | self.class_strategy = configs.class_strategy
25 | # Encoder-only architecture
26 | self.encoder = Encoder(
27 | [
28 | EncoderLayer(
29 | AttentionLayer(
30 | FullAttention(False, configs.factor, attention_dropout=configs.dropout,
31 | output_attention=configs.output_attention), configs.d_model, configs.n_heads),
32 | configs.d_model,
33 | configs.d_ff,
34 | dropout=configs.dropout,
35 | activation=configs.activation
36 | ) for l in range(configs.e_layers)
37 | ],
38 | norm_layer=torch.nn.LayerNorm(configs.d_model)
39 | )
40 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)
41 |
42 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
43 | if self.use_norm:
44 | # Normalization from Non-stationary Transformer
45 | means = x_enc.mean(1, keepdim=True).detach()
46 | x_enc = x_enc - means
47 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
48 | x_enc /= stdev
49 |
50 | _, _, N = x_enc.shape # B L N
51 | # B: batch_size; E: d_model;
52 | # L: seq_len; S: pred_len;
53 |         # N: number of variates (tokens), which can also include covariates
54 |
55 | # Embedding
56 | # B L N -> B N E (B L N -> B L E in the vanilla Transformer)
57 |         enc_out = self.enc_embedding(x_enc, x_mark_enc) # covariates (e.g. timestamp) can also be embedded as tokens
58 |
59 | # B N E -> B N E (B L E -> B L E in the vanilla Transformer)
60 |         # the dimensions of the embedded time series have been inverted, and are then processed by the native attention, layernorm and FFN modules
61 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
62 |
63 | # B N E -> B N S -> B S N
64 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates
65 |
66 | if self.use_norm:
67 | # De-Normalization from Non-stationary Transformer
68 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
69 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
70 |
71 | return dec_out, attns
72 |
73 |
74 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
75 | dec_out, attns = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
76 |
77 | if self.output_attention:
78 | return dec_out[:, -self.pred_len:, :], attns
79 | else:
80 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
--------------------------------------------------------------------------------
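
A minimal usage sketch for the encoder-only Model above (not a repository file). It assumes the repository root is on the Python path and that DataEmbedding_inverted in layers/Embed.py accepts None for the time-feature marks; the config fields mirror exactly those read in Model.__init__:

```python
from types import SimpleNamespace

import torch

from model.iTransformer import Model

# Hypothetical config object; every field below is read somewhere in Model.__init__.
configs = SimpleNamespace(
    seq_len=96, pred_len=96, output_attention=False, use_norm=True,
    d_model=64, n_heads=4, e_layers=2, d_ff=128, factor=1,
    dropout=0.1, activation='gelu', embed='timeF', freq='h',
    class_strategy='projection',
)

model = Model(configs)
x_enc = torch.randn(8, configs.seq_len, 21)   # [B, L, N]: 21 variates, as in the Weather scripts
out = model(x_enc, None, None, None)          # marks and decoder inputs are unused here
print(out.shape)                              # torch.Size([8, 96, 21]) -> [B, S, N]
```
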
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.5.3
2 | scikit-learn==1.2.2
3 | numpy==1.23.5
4 | matplotlib==3.7.0
5 | torch==2.0.0
6 | reformer-pytorch==1.4.4
7 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from experiments.exp_long_term_forecasting import Exp_Long_Term_Forecast
4 | from experiments.exp_long_term_forecasting_partial import Exp_Long_Term_Forecast_Partial
5 | import random
6 | import numpy as np
7 |
8 | if __name__ == '__main__':
9 | fix_seed = 2023
10 | random.seed(fix_seed)
11 | torch.manual_seed(fix_seed)
12 | np.random.seed(fix_seed)
13 |
14 | parser = argparse.ArgumentParser(description='iTransformer')
15 |
16 | # basic config
17 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
18 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
19 | parser.add_argument('--model', type=str, required=True, default='iTransformer',
20 | help='model name, options: [iTransformer, iInformer, iReformer, iFlowformer, iFlashformer]')
21 |
22 | # data loader
23 | parser.add_argument('--data', type=str, required=True, default='custom', help='dataset type')
24 | parser.add_argument('--root_path', type=str, default='./data/electricity/', help='root path of the data file')
25 | parser.add_argument('--data_path', type=str, default='electricity.csv', help='data csv file')
26 | parser.add_argument('--features', type=str, default='M',
27 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
28 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
29 | parser.add_argument('--freq', type=str, default='h',
30 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
31 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
32 |
33 | # forecasting task
34 | parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
35 | parser.add_argument('--label_len', type=int, default=48, help='start token length') # no longer needed in inverted Transformers
36 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
37 |
38 | # model define
39 | parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
40 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
41 | parser.add_argument('--c_out', type=int, default=7, help='output size') # applicable on arbitrary number of variates in inverted Transformers
42 | parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
43 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
44 | parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
45 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
46 | parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
47 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
48 | parser.add_argument('--factor', type=int, default=1, help='attn factor')
49 | parser.add_argument('--distil', action='store_false',
50 | help='whether to use distilling in encoder, using this argument means not using distilling',
51 | default=True)
52 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
53 | parser.add_argument('--embed', type=str, default='timeF',
54 | help='time features encoding, options:[timeF, fixed, learned]')
55 | parser.add_argument('--activation', type=str, default='gelu', help='activation')
56 |     parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
57 | parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
58 |
59 | # optimization
60 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
61 |     parser.add_argument('--itr', type=int, default=1, help='number of experiment repetitions')
62 | parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
63 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
64 | parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
65 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
66 | parser.add_argument('--des', type=str, default='test', help='exp description')
67 | parser.add_argument('--loss', type=str, default='MSE', help='loss function')
68 | parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
69 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
70 |
71 | # GPU
72 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
73 | parser.add_argument('--gpu', type=int, default=0, help='gpu')
74 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
75 |     parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multiple gpus')
76 |
77 | # iTransformer
78 | parser.add_argument('--exp_name', type=str, required=False, default='MTSF',
79 |                         help='experiment name, options:[MTSF, partial_train]')
80 | parser.add_argument('--channel_independence', type=bool, default=False, help='whether to use channel_independence mechanism')
81 | parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)
82 | parser.add_argument('--class_strategy', type=str, default='projection', help='projection/average/cls_token')
83 | parser.add_argument('--target_root_path', type=str, default='./data/electricity/', help='root path of the data file')
84 | parser.add_argument('--target_data_path', type=str, default='electricity.csv', help='data file')
85 |     parser.add_argument('--efficient_training', type=bool, default=False, help='whether to use efficient_training (exp_name should be partial_train)') # See Figure 8 of our paper for details
86 | parser.add_argument('--use_norm', type=int, default=True, help='use norm and denorm')
87 | parser.add_argument('--partial_start_index', type=int, default=0, help='the start index of variates for partial training, '
88 | 'you can select [partial_start_index, min(enc_in + partial_start_index, N)]')
89 |
90 | args = parser.parse_args()
91 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
92 |
93 | if args.use_gpu and args.use_multi_gpu:
94 | args.devices = args.devices.replace(' ', '')
95 | device_ids = args.devices.split(',')
96 | args.device_ids = [int(id_) for id_ in device_ids]
97 | args.gpu = args.device_ids[0]
98 |
99 | print('Args in experiment:')
100 | print(args)
101 |
102 | if args.exp_name == 'partial_train': # See Figure 8 of our paper, for the detail
103 | Exp = Exp_Long_Term_Forecast_Partial
104 | else: # MTSF: multivariate time series forecasting
105 | Exp = Exp_Long_Term_Forecast
106 |
107 |
108 | if args.is_training:
109 | for ii in range(args.itr):
110 | # setting record of experiments
111 | setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(
112 | args.model_id,
113 | args.model,
114 | args.data,
115 | args.features,
116 | args.seq_len,
117 | args.label_len,
118 | args.pred_len,
119 | args.d_model,
120 | args.n_heads,
121 | args.e_layers,
122 | args.d_layers,
123 | args.d_ff,
124 | args.factor,
125 | args.embed,
126 | args.distil,
127 | args.des,
128 | args.class_strategy, ii)
129 |
130 | exp = Exp(args) # set experiments
131 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
132 | exp.train(setting)
133 |
134 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
135 | exp.test(setting)
136 |
137 | if args.do_predict:
138 | print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
139 | exp.predict(setting, True)
140 |
141 | torch.cuda.empty_cache()
142 | else:
143 | ii = 0
144 | setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(
145 | args.model_id,
146 | args.model,
147 | args.data,
148 | args.features,
149 | args.seq_len,
150 | args.label_len,
151 | args.pred_len,
152 | args.d_model,
153 | args.n_heads,
154 | args.e_layers,
155 | args.d_layers,
156 | args.d_ff,
157 | args.factor,
158 | args.embed,
159 | args.distil,
160 | args.des,
161 | args.class_strategy, ii)
162 |
163 | exp = Exp(args) # set experiments
164 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
165 | exp.test(setting, test=1)
166 | torch.cuda.empty_cache()
167 |
--------------------------------------------------------------------------------
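
run.py is normally launched through the shell scripts under scripts/ below; the only required flags are --is_training, --model_id, --model and --data, and everything else falls back to the parser defaults above. An equivalent Python-side launch of the first run in scripts/boost_performance/ECL/iTransformer.sh, as a sketch assuming the repository root is the working directory and the electricity CSV has been placed under ./dataset/electricity/:

```python
import subprocess

# Mirrors the flags used by the ECL scripts; values are only examples.
subprocess.run([
    "python", "-u", "run.py",
    "--is_training", "1",
    "--model_id", "ECL_96_96",
    "--model", "iTransformer",
    "--data", "custom",
    "--root_path", "./dataset/electricity/",
    "--data_path", "electricity.csv",
    "--features", "M",
    "--seq_len", "96", "--pred_len", "96",
    "--e_layers", "2",
    "--enc_in", "321", "--dec_in", "321", "--c_out", "321",
    "--des", "Exp",
    "--itr", "1",
], check=True)
```
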
/scripts/boost_performance/ECL/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=iFlowformer
4 | # model_name=Flowformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --itr 1
22 |
23 | python -u run.py \
24 | --is_training 1 \
25 | --root_path ./dataset/electricity/ \
26 | --data_path electricity.csv \
27 | --model_id ECL_96_192 \
28 | --model $model_name \
29 | --data custom \
30 | --features M \
31 | --seq_len 96 \
32 | --pred_len 192 \
33 | --e_layers 2 \
34 | --enc_in 321 \
35 | --dec_in 321 \
36 | --c_out 321 \
37 | --des 'Exp' \
38 | --itr 1
39 |
40 | python -u run.py \
41 | --is_training 1 \
42 | --root_path ./dataset/electricity/ \
43 | --data_path electricity.csv \
44 | --model_id ECL_96_336 \
45 | --model $model_name \
46 | --data custom \
47 | --features M \
48 | --seq_len 96 \
49 | --pred_len 336 \
50 | --e_layers 2 \
51 | --enc_in 321 \
52 | --dec_in 321 \
53 | --c_out 321 \
54 | --des 'Exp' \
55 | --itr 1
56 |
57 | python -u run.py \
58 | --is_training 1 \
59 | --root_path ./dataset/electricity/ \
60 | --data_path electricity.csv \
61 | --model_id ECL_96_720 \
62 | --model $model_name \
63 | --data custom \
64 | --features M \
65 | --seq_len 96 \
66 | --pred_len 720 \
67 | --e_layers 2 \
68 | --enc_in 321 \
69 | --dec_in 321 \
70 | --c_out 321 \
71 | --des 'Exp' \
72 | --itr 1
73 |
--------------------------------------------------------------------------------
/scripts/boost_performance/ECL/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=iInformer
4 | # model_name=Informer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --itr 1
22 |
23 | python -u run.py \
24 | --is_training 1 \
25 | --root_path ./dataset/electricity/ \
26 | --data_path electricity.csv \
27 | --model_id ECL_96_192 \
28 | --model $model_name \
29 | --data custom \
30 | --features M \
31 | --seq_len 96 \
32 | --pred_len 192 \
33 | --e_layers 2 \
34 | --enc_in 321 \
35 | --dec_in 321 \
36 | --c_out 321 \
37 | --des 'Exp' \
38 | --itr 1
39 |
40 | python -u run.py \
41 | --is_training 1 \
42 | --root_path ./dataset/electricity/ \
43 | --data_path electricity.csv \
44 | --model_id ECL_96_336 \
45 | --model $model_name \
46 | --data custom \
47 | --features M \
48 | --seq_len 96 \
49 | --pred_len 336 \
50 | --e_layers 2 \
51 | --enc_in 321 \
52 | --dec_in 321 \
53 | --c_out 321 \
54 | --des 'Exp' \
55 | --itr 1
56 |
57 | python -u run.py \
58 | --is_training 1 \
59 | --root_path ./dataset/electricity/ \
60 | --data_path electricity.csv \
61 | --model_id ECL_96_720 \
62 | --model $model_name \
63 | --data custom \
64 | --features M \
65 | --seq_len 96 \
66 | --pred_len 720 \
67 | --e_layers 2 \
68 | --enc_in 321 \
69 | --dec_in 321 \
70 | --c_out 321 \
71 | --des 'Exp' \
72 | --itr 1
73 |
--------------------------------------------------------------------------------
/scripts/boost_performance/ECL/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=iReformer
4 | # model_name=Reformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --itr 1
22 |
23 | python -u run.py \
24 | --is_training 1 \
25 | --root_path ./dataset/electricity/ \
26 | --data_path electricity.csv \
27 | --model_id ECL_96_192 \
28 | --model $model_name \
29 | --data custom \
30 | --features M \
31 | --seq_len 96 \
32 | --pred_len 192 \
33 | --e_layers 2 \
34 | --enc_in 321 \
35 | --dec_in 321 \
36 | --c_out 321 \
37 | --des 'Exp' \
38 | --itr 1
39 |
40 | python -u run.py \
41 | --is_training 1 \
42 | --root_path ./dataset/electricity/ \
43 | --data_path electricity.csv \
44 | --model_id ECL_96_336 \
45 | --model $model_name \
46 | --data custom \
47 | --features M \
48 | --seq_len 96 \
49 | --pred_len 336 \
50 | --e_layers 2 \
51 | --enc_in 321 \
52 | --dec_in 321 \
53 | --c_out 321 \
54 | --des 'Exp' \
55 | --itr 1
56 |
57 | python -u run.py \
58 | --is_training 1 \
59 | --root_path ./dataset/electricity/ \
60 | --data_path electricity.csv \
61 | --model_id ECL_96_720 \
62 | --model $model_name \
63 | --data custom \
64 | --features M \
65 | --seq_len 96 \
66 | --pred_len 720 \
67 | --e_layers 2 \
68 | --enc_in 321 \
69 | --dec_in 321 \
70 | --c_out 321 \
71 | --des 'Exp' \
72 | --itr 1
73 |
--------------------------------------------------------------------------------
/scripts/boost_performance/ECL/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=iTransformer
4 | # model_name=Transformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --itr 1
22 |
23 | python -u run.py \
24 | --is_training 1 \
25 | --root_path ./dataset/electricity/ \
26 | --data_path electricity.csv \
27 | --model_id ECL_96_192 \
28 | --model $model_name \
29 | --data custom \
30 | --features M \
31 | --seq_len 96 \
32 | --pred_len 192 \
33 | --e_layers 2 \
34 | --enc_in 321 \
35 | --dec_in 321 \
36 | --c_out 321 \
37 | --des 'Exp' \
38 | --itr 1
39 |
40 | python -u run.py \
41 | --is_training 1 \
42 | --root_path ./dataset/electricity/ \
43 | --data_path electricity.csv \
44 | --model_id ECL_96_336 \
45 | --model $model_name \
46 | --data custom \
47 | --features M \
48 | --seq_len 96 \
49 | --pred_len 336 \
50 | --e_layers 2 \
51 | --enc_in 321 \
52 | --dec_in 321 \
53 | --c_out 321 \
54 | --des 'Exp' \
55 | --itr 1
56 |
57 | python -u run.py \
58 | --is_training 1 \
59 | --root_path ./dataset/electricity/ \
60 | --data_path electricity.csv \
61 | --model_id ECL_96_720 \
62 | --model $model_name \
63 | --data custom \
64 | --features M \
65 | --seq_len 96 \
66 | --pred_len 720 \
67 | --e_layers 2 \
68 | --enc_in 321 \
69 | --dec_in 321 \
70 | --c_out 321 \
71 | --des 'Exp' \
72 | --itr 1
73 |
--------------------------------------------------------------------------------
/scripts/boost_performance/README.md:
--------------------------------------------------------------------------------
1 | # Inverted Transformers Work Better for Time Series Forecasting
2 |
3 | This folder contains the comparison between the vanilla Transformer-based forecasters and their inverted versions. If you are new to this repo, we recommend having a look at this [README](../multivariate_forecasting/README.md) first.
4 |
5 | ## Scripts
6 |
7 | In each folder named after the dataset, we compare the performance of iTransformers and the vanilla Transformers.
8 |
9 | ```
10 | # iTransformer on the Traffic dataset; switch model_name in the script to run the vanilla Transformer for comparison.
11 |
12 | bash ./scripts/boost_performance/Traffic/iTransformer.sh
13 | ```
14 |
15 | You can change the ```model_name``` in the script to switch between a Transformer variant and its inverted version.
16 |
17 | ## Results
18 | We compare the performance of Transformer and iTransformer on all six datasets. The results indicate that applying the attention and feed-forward network on the
19 | inverted dimensions greatly empowers Transformers in multivariate time series forecasting.
20 |
21 |
22 |
23 |
24 |
25 | We apply the proposed inverted framework to Transformer and its variants. The results demonstrate that our iTransformers framework consistently improves these Transformer variants
26 | and takes advantage of efficient attention mechanisms.
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Traffic/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iFlowformer
4 | # model_name=Flowformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --train_epochs 3
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/traffic/ \
27 | --data_path traffic.csv \
28 | --model_id traffic_96_192 \
29 | --model $model_name \
30 | --data custom \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 862 \
36 | --dec_in 862 \
37 | --c_out 862 \
38 | --des 'Exp' \
39 | --itr 1 \
40 | --train_epochs 3
41 |
42 | python -u run.py \
43 | --is_training 1 \
44 | --root_path ./dataset/traffic/ \
45 | --data_path traffic.csv \
46 | --model_id traffic_96_336 \
47 | --model $model_name \
48 | --data custom \
49 | --features M \
50 | --seq_len 96 \
51 | --pred_len 336 \
52 | --e_layers 2 \
53 | --enc_in 862 \
54 | --dec_in 862 \
55 | --c_out 862 \
56 | --des 'Exp' \
57 | --itr 1 \
58 | --train_epochs 3
59 |
60 | python -u run.py \
61 | --is_training 1 \
62 | --root_path ./dataset/traffic/ \
63 | --data_path traffic.csv \
64 | --model_id traffic_96_720 \
65 | --model $model_name \
66 | --data custom \
67 | --features M \
68 | --seq_len 96 \
69 | --pred_len 720 \
70 | --e_layers 2 \
71 | --enc_in 862 \
72 | --dec_in 862 \
73 | --c_out 862 \
74 | --des 'Exp' \
75 | --itr 1 \
76 | --train_epochs 3
77 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Traffic/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iInformer
4 | # model_name=Informer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --train_epochs 3
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/traffic/ \
27 | --data_path traffic.csv \
28 | --model_id traffic_96_192 \
29 | --model $model_name \
30 | --data custom \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 862 \
36 | --dec_in 862 \
37 | --c_out 862 \
38 | --des 'Exp' \
39 | --itr 1 \
40 | --train_epochs 3
41 |
42 | python -u run.py \
43 | --is_training 1 \
44 | --root_path ./dataset/traffic/ \
45 | --data_path traffic.csv \
46 | --model_id traffic_96_336 \
47 | --model $model_name \
48 | --data custom \
49 | --features M \
50 | --seq_len 96 \
51 | --pred_len 336 \
52 | --e_layers 2 \
53 | --enc_in 862 \
54 | --dec_in 862 \
55 | --c_out 862 \
56 | --des 'Exp' \
57 | --itr 1 \
58 | --train_epochs 3
59 |
60 | python -u run.py \
61 | --is_training 1 \
62 | --root_path ./dataset/traffic/ \
63 | --data_path traffic.csv \
64 | --model_id traffic_96_720 \
65 | --model $model_name \
66 | --data custom \
67 | --features M \
68 | --seq_len 96 \
69 | --pred_len 720 \
70 | --e_layers 2 \
71 | --enc_in 862 \
72 | --dec_in 862 \
73 | --c_out 862 \
74 | --des 'Exp' \
75 | --itr 1 \
76 | --train_epochs 3
77 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Traffic/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iReformer
4 | # model_name=Reformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --train_epochs 3
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/traffic/ \
27 | --data_path traffic.csv \
28 | --model_id traffic_96_192 \
29 | --model $model_name \
30 | --data custom \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 862 \
36 | --dec_in 862 \
37 | --c_out 862 \
38 | --des 'Exp' \
39 | --itr 1 \
40 | --train_epochs 3
41 |
42 | python -u run.py \
43 | --is_training 1 \
44 | --root_path ./dataset/traffic/ \
45 | --data_path traffic.csv \
46 | --model_id traffic_96_336 \
47 | --model $model_name \
48 | --data custom \
49 | --features M \
50 | --seq_len 96 \
51 | --pred_len 336 \
52 | --e_layers 2 \
53 | --enc_in 862 \
54 | --dec_in 862 \
55 | --c_out 862 \
56 | --des 'Exp' \
57 | --itr 1 \
58 | --train_epochs 3
59 |
60 | python -u run.py \
61 | --is_training 1 \
62 | --root_path ./dataset/traffic/ \
63 | --data_path traffic.csv \
64 | --model_id traffic_96_720 \
65 | --model $model_name \
66 | --data custom \
67 | --features M \
68 | --seq_len 96 \
69 | --pred_len 720 \
70 | --e_layers 2 \
71 | --enc_in 862 \
72 | --dec_in 862 \
73 | --c_out 862 \
74 | --des 'Exp' \
75 | --itr 1 \
76 | --train_epochs 3
77 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Traffic/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iTransformer
4 | #model_name=Transformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --train_epochs 3
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/traffic/ \
27 | --data_path traffic.csv \
28 | --model_id traffic_96_192 \
29 | --model $model_name \
30 | --data custom \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 862 \
36 | --dec_in 862 \
37 | --c_out 862 \
38 | --des 'Exp' \
39 | --itr 1 \
40 | --train_epochs 3
41 |
42 | python -u run.py \
43 | --is_training 1 \
44 | --root_path ./dataset/traffic/ \
45 | --data_path traffic.csv \
46 | --model_id traffic_96_336 \
47 | --model $model_name \
48 | --data custom \
49 | --features M \
50 | --seq_len 96 \
51 | --pred_len 336 \
52 | --e_layers 2 \
53 | --enc_in 862 \
54 | --dec_in 862 \
55 | --c_out 862 \
56 | --des 'Exp' \
57 | --itr 1 \
58 | --train_epochs 3
59 |
60 | python -u run.py \
61 | --is_training 1 \
62 | --root_path ./dataset/traffic/ \
63 | --data_path traffic.csv \
64 | --model_id traffic_96_720 \
65 | --model $model_name \
66 | --data custom \
67 | --features M \
68 | --seq_len 96 \
69 | --pred_len 720 \
70 | --e_layers 2 \
71 | --enc_in 862 \
72 | --dec_in 862 \
73 | --c_out 862 \
74 | --des 'Exp' \
75 | --itr 1 \
76 | --train_epochs 3
77 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Weather/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iFlowformer
4 | # model_name=Flowformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/weather/ \
9 | --data_path weather.csv \
10 | --model_id weather_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 21 \
18 | --dec_in 21 \
19 | --c_out 21 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --batch_size 128 \
23 | --train_epochs 3
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/weather/ \
28 | --data_path weather.csv \
29 | --model_id weather_96_192 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 192 \
35 | --e_layers 2 \
36 | --enc_in 21 \
37 | --dec_in 21 \
38 | --c_out 21 \
39 | --des 'Exp' \
40 | --batch_size 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/weather/ \
46 | --data_path weather.csv \
47 | --model_id weather_96_336 \
48 | --model $model_name \
49 | --data custom \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 21 \
55 | --dec_in 21 \
56 | --c_out 21 \
57 | --des 'Exp' \
58 | --batch_size 128 \
59 | --itr 1
60 |
61 | python -u run.py \
62 | --is_training 1 \
63 | --root_path ./dataset/weather/ \
64 | --data_path weather.csv \
65 | --model_id weather_96_720 \
66 | --model $model_name \
67 | --data custom \
68 | --features M \
69 | --seq_len 96 \
70 | --pred_len 720 \
71 | --e_layers 2 \
72 | --enc_in 21 \
73 | --dec_in 21 \
74 | --c_out 21 \
75 | --des 'Exp' \
76 | --batch_size 128 \
77 | --itr 1
78 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Weather/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iInformer
4 | #model_name=Informer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/weather/ \
9 | --data_path weather.csv \
10 | --model_id weather_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 21 \
18 | --dec_in 21 \
19 | --c_out 21 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --batch_size 128 \
23 | --train_epochs 3
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/weather/ \
28 | --data_path weather.csv \
29 | --model_id weather_96_192 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 192 \
35 | --e_layers 2 \
36 | --enc_in 21 \
37 | --dec_in 21 \
38 | --c_out 21 \
39 | --des 'Exp' \
40 | --batch_size 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/weather/ \
46 | --data_path weather.csv \
47 | --model_id weather_96_336 \
48 | --model $model_name \
49 | --data custom \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 21 \
55 | --dec_in 21 \
56 | --c_out 21 \
57 | --des 'Exp' \
58 | --batch_size 128 \
59 | --itr 1
60 |
61 | python -u run.py \
62 | --is_training 1 \
63 | --root_path ./dataset/weather/ \
64 | --data_path weather.csv \
65 | --model_id weather_96_720 \
66 | --model $model_name \
67 | --data custom \
68 | --features M \
69 | --seq_len 96 \
70 | --pred_len 720 \
71 | --e_layers 2 \
72 | --enc_in 21 \
73 | --dec_in 21 \
74 | --c_out 21 \
75 | --des 'Exp' \
76 | --batch_size 128 \
77 | --itr 1
78 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Weather/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iReformer
4 | #model_name=Reformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/weather/ \
9 | --data_path weather.csv \
10 | --model_id weather_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 21 \
18 | --dec_in 21 \
19 | --c_out 21 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --batch_size 128 \
23 | --train_epochs 3
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/weather/ \
28 | --data_path weather.csv \
29 | --model_id weather_96_192 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 192 \
35 | --e_layers 2 \
36 | --enc_in 21 \
37 | --dec_in 21 \
38 | --c_out 21 \
39 | --des 'Exp' \
40 | --batch_size 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/weather/ \
46 | --data_path weather.csv \
47 | --model_id weather_96_336 \
48 | --model $model_name \
49 | --data custom \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 21 \
55 | --dec_in 21 \
56 | --c_out 21 \
57 | --des 'Exp' \
58 | --batch_size 128 \
59 | --itr 1
60 |
61 | python -u run.py \
62 | --is_training 1 \
63 | --root_path ./dataset/weather/ \
64 | --data_path weather.csv \
65 | --model_id weather_96_720 \
66 | --model $model_name \
67 | --data custom \
68 | --features M \
69 | --seq_len 96 \
70 | --pred_len 720 \
71 | --e_layers 2 \
72 | --enc_in 21 \
73 | --dec_in 21 \
74 | --c_out 21 \
75 | --des 'Exp' \
76 | --batch_size 128 \
77 | --itr 1
78 |
--------------------------------------------------------------------------------
/scripts/boost_performance/Weather/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iTransformer
4 | #model_name=Transformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/weather/ \
9 | --data_path weather.csv \
10 | --model_id weather_96_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 96 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --enc_in 21 \
18 | --dec_in 21 \
19 | --c_out 21 \
20 | --des 'Exp' \
21 | --itr 1 \
22 | --batch_size 128 \
23 | --train_epochs 3
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/weather/ \
28 | --data_path weather.csv \
29 | --model_id weather_96_192 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 192 \
35 | --e_layers 2 \
36 | --enc_in 21 \
37 | --dec_in 21 \
38 | --c_out 21 \
39 | --des 'Exp' \
40 | --batch_size 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/weather/ \
46 | --data_path weather.csv \
47 | --model_id weather_96_336 \
48 | --model $model_name \
49 | --data custom \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 21 \
55 | --dec_in 21 \
56 | --c_out 21 \
57 | --des 'Exp' \
58 | --batch_size 128 \
59 | --itr 1
60 |
61 | python -u run.py \
62 | --is_training 1 \
63 | --root_path ./dataset/weather/ \
64 | --data_path weather.csv \
65 | --model_id weather_96_720 \
66 | --model $model_name \
67 | --data custom \
68 | --features M \
69 | --seq_len 96 \
70 | --pred_len 720 \
71 | --e_layers 2 \
72 | --enc_in 21 \
73 | --dec_in 21 \
74 | --c_out 21 \
75 | --des 'Exp' \
76 | --batch_size 128 \
77 | --itr 1
78 |
--------------------------------------------------------------------------------
/scripts/increasing_lookback/ECL/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | # model_name=Flowformer
4 | model_name=iFlowformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 3 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 |   --learning_rate 0.0005 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/electricity/ \
30 | --data_path electricity.csv \
31 | --model_id ECL_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 | --e_layers 3 \
38 | --enc_in 321 \
39 | --dec_in 321 \
40 | --c_out 321 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 |   --learning_rate 0.0005 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/electricity/ \
51 | --data_path electricity.csv \
52 | --model_id ECL_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 3 \
59 | --enc_in 321 \
60 | --dec_in 321 \
61 | --c_out 321 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.0005 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/electricity/ \
72 | --data_path electricity.csv \
73 | --model_id ECL_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 3 \
80 | --enc_in 321 \
81 | --dec_in 321 \
82 | --c_out 321 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.0005 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/electricity/ \
93 | --data_path electricity.csv \
94 | --model_id ECL_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 3 \
101 | --enc_in 321 \
102 | --dec_in 321 \
103 | --c_out 321 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.0005 \
109 | --itr 1
110 |
--------------------------------------------------------------------------------
/scripts/increasing_lookback/ECL/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | # model_name=Informer
4 | model_name=iInformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 3 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 | --learning_rate 0.0005 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/electricity/ \
30 | --data_path electricity.csv \
31 | --model_id ECL_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 | --e_layers 3 \
38 | --enc_in 321 \
39 | --dec_in 321 \
40 | --c_out 321 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 | --learning_rate 0.0005 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/electricity/ \
51 | --data_path electricity.csv \
52 | --model_id ECL_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 3 \
59 | --enc_in 321 \
60 | --dec_in 321 \
61 | --c_out 321 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.0005 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/electricity/ \
72 | --data_path electricity.csv \
73 | --model_id ECL_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 3 \
80 | --enc_in 321 \
81 | --dec_in 321 \
82 | --c_out 321 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.0005 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/electricity/ \
93 | --data_path electricity.csv \
94 | --model_id ECL_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 3 \
101 | --enc_in 321 \
102 | --dec_in 321 \
103 | --c_out 321 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.0005 \
109 | --itr 1
110 |
--------------------------------------------------------------------------------
/scripts/increasing_lookback/ECL/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | # model_name=Reformer
4 | model_name=iReformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 3 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 | --learning_rate 0.0005 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/electricity/ \
30 | --data_path electricity.csv \
31 | --model_id ECL_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 | --e_layers 3 \
38 | --enc_in 321 \
39 | --dec_in 321 \
40 | --c_out 321 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 | --learning_rate 0.0005 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/electricity/ \
51 | --data_path electricity.csv \
52 | --model_id ECL_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 3 \
59 | --enc_in 321 \
60 | --dec_in 321 \
61 | --c_out 321 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.0005 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/electricity/ \
72 | --data_path electricity.csv \
73 | --model_id ECL_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 3 \
80 | --enc_in 321 \
81 | --dec_in 321 \
82 | --c_out 321 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.0005 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/electricity/ \
93 | --data_path electricity.csv \
94 | --model_id ECL_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 3 \
101 | --enc_in 321 \
102 | --dec_in 321 \
103 | --c_out 321 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.0005 \
109 | --itr 1
110 |
--------------------------------------------------------------------------------
/scripts/increasing_lookback/ECL/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | # model_name=Transformer
4 | model_name=iTransformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/electricity/ \
9 | --data_path electricity.csv \
10 | --model_id ECL_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 3 \
17 | --enc_in 321 \
18 | --dec_in 321 \
19 | --c_out 321 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 | --learning_rate 0.0005 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/electricity/ \
30 | --data_path electricity.csv \
31 | --model_id ECL_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 | --e_layers 3 \
38 | --enc_in 321 \
39 | --dec_in 321 \
40 | --c_out 321 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 | --learning_rate 0.0005 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/electricity/ \
51 | --data_path electricity.csv \
52 | --model_id ECL_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 3 \
59 | --enc_in 321 \
60 | --dec_in 321 \
61 | --c_out 321 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.0005 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/electricity/ \
72 | --data_path electricity.csv \
73 | --model_id ECL_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 3 \
80 | --enc_in 321 \
81 | --dec_in 321 \
82 | --c_out 321 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.0005 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/electricity/ \
93 | --data_path electricity.csv \
94 | --model_id ECL_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 3 \
101 | --enc_in 321 \
102 | --dec_in 321 \
103 | --c_out 321 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.0005 \
109 | --itr 1
110 |
--------------------------------------------------------------------------------
/scripts/increasing_lookback/README.md:
--------------------------------------------------------------------------------
1 | # iTransformer for Enlarged Lookback Window
2 |
3 | This folder contains the experiments of iTransformer with enlarged lookback windows. If you are new to this repo, we recommend reading this [README](../multivariate_forecasting/README.md) first.
4 |
5 | ## Scripts
6 |
7 | In each folder named after the dataset, we provide experiments for the iTransformers and the vanilla Transformers under five increasing lookback lengths (48, 96, 192, 336, 720), with the prediction length fixed at 96.
8 |
9 | ```
10 | # iTransformer on the Traffic Dataset with gradually enlarged lookback windows.
11 |
12 | bash ./scripts/increasing_lookback/Traffic/iTransformer.sh
13 | ```
14 |
15 | You can change the ```model_name``` in the script to switch between the vanilla Transformer and its inverted version.
16 |
17 | ## Results
18 |
19 |
20 |
21 |
22 |
23 | The inverted framework enables Transformers to benefit from enlarged lookback windows with consistently improved performance.
--------------------------------------------------------------------------------
/scripts/increasing_lookback/Traffic/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | # model_name=Flowformer
4 | model_name=iFlowformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 4 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 | --learning_rate 0.001 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/traffic/ \
30 | --data_path traffic.csv \
31 | --model_id traffic_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 | --e_layers 4 \
38 | --enc_in 862 \
39 | --dec_in 862 \
40 | --c_out 862 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 | --learning_rate 0.001 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/traffic/ \
51 | --data_path traffic.csv \
52 | --model_id traffic_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 4 \
59 | --enc_in 862 \
60 | --dec_in 862 \
61 | --c_out 862 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.001 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/traffic/ \
72 | --data_path traffic.csv \
73 | --model_id traffic_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 4 \
80 | --enc_in 862 \
81 | --dec_in 862 \
82 | --c_out 862 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.001 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/traffic/ \
93 | --data_path traffic.csv \
94 | --model_id traffic_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 4 \
101 | --enc_in 862 \
102 | --dec_in 862 \
103 | --c_out 862 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.001 \
109 | --itr 1
110 |
--------------------------------------------------------------------------------
/scripts/increasing_lookback/Traffic/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | # model_name=Informer
4 | model_name=iInformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 4 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 | --learning_rate 0.001 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/traffic/ \
30 | --data_path traffic.csv \
31 | --model_id traffic_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 | --e_layers 4 \
38 | --enc_in 862 \
39 | --dec_in 862 \
40 | --c_out 862 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 | --learning_rate 0.001 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/traffic/ \
51 | --data_path traffic.csv \
52 | --model_id traffic_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 4 \
59 | --enc_in 862 \
60 | --dec_in 862 \
61 | --c_out 862 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.001 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/traffic/ \
72 | --data_path traffic.csv \
73 | --model_id traffic_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 4 \
80 | --enc_in 862 \
81 | --dec_in 862 \
82 | --c_out 862 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.001 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/traffic/ \
93 | --data_path traffic.csv \
94 | --model_id traffic_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 4 \
101 | --enc_in 862 \
102 | --dec_in 862 \
103 | --c_out 862 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.001 \
109 | --itr 1
110 |
--------------------------------------------------------------------------------
/scripts/increasing_lookback/Traffic/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | # model_name=Reformer
4 | model_name=iReformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 4 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 | --learning_rate 0.001 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/traffic/ \
30 | --data_path traffic.csv \
31 | --model_id traffic_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 | --e_layers 4 \
38 | --enc_in 862 \
39 | --dec_in 862 \
40 | --c_out 862 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 | --learning_rate 0.001 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/traffic/ \
51 | --data_path traffic.csv \
52 | --model_id traffic_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 4 \
59 | --enc_in 862 \
60 | --dec_in 862 \
61 | --c_out 862 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.001 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/traffic/ \
72 | --data_path traffic.csv \
73 | --model_id traffic_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 4 \
80 | --enc_in 862 \
81 | --dec_in 862 \
82 | --c_out 862 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.001 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/traffic/ \
93 | --data_path traffic.csv \
94 | --model_id traffic_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 4 \
101 | --enc_in 862 \
102 | --dec_in 862 \
103 | --c_out 862 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.001 \
109 | --itr 1
--------------------------------------------------------------------------------
/scripts/increasing_lookback/Traffic/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | # model_name=Transformer
4 | model_name=iTransformer
5 |
6 | python -u run.py \
7 | --is_training 1 \
8 | --root_path ./dataset/traffic/ \
9 | --data_path traffic.csv \
10 | --model_id traffic_48_96 \
11 | --model $model_name \
12 | --data custom \
13 | --features M \
14 | --seq_len 48 \
15 | --pred_len 96 \
16 | --e_layers 4 \
17 | --enc_in 862 \
18 | --dec_in 862 \
19 | --c_out 862 \
20 | --des 'Exp' \
21 | --d_model 512 \
22 | --d_ff 512 \
23 | --batch_size 16 \
24 | --learning_rate 0.001 \
25 | --itr 1
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/traffic/ \
30 | --data_path traffic.csv \
31 | --model_id traffic_96_96 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --pred_len 96 \
37 |   --e_layers 4 \
38 | --enc_in 862 \
39 | --dec_in 862 \
40 | --c_out 862 \
41 | --des 'Exp' \
42 | --d_model 512 \
43 | --d_ff 512 \
44 | --batch_size 16 \
45 | --learning_rate 0.001 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/traffic/ \
51 | --data_path traffic.csv \
52 | --model_id traffic_192_96 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 192 \
57 | --pred_len 96 \
58 | --e_layers 4 \
59 | --enc_in 862 \
60 | --dec_in 862 \
61 | --c_out 862 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.001 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/traffic/ \
72 | --data_path traffic.csv \
73 | --model_id traffic_336_96 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 336 \
78 | --pred_len 96 \
79 | --e_layers 4 \
80 | --enc_in 862 \
81 | --dec_in 862 \
82 | --c_out 862 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 | --batch_size 16 \
87 | --learning_rate 0.001 \
88 | --itr 1
89 |
90 | python -u run.py \
91 | --is_training 1 \
92 | --root_path ./dataset/traffic/ \
93 | --data_path traffic.csv \
94 | --model_id traffic_720_96 \
95 | --model $model_name \
96 | --data custom \
97 | --features M \
98 | --seq_len 720 \
99 | --pred_len 96 \
100 | --e_layers 4 \
101 | --enc_in 862 \
102 | --dec_in 862 \
103 | --c_out 862 \
104 | --des 'Exp' \
105 | --d_model 512 \
106 | --d_ff 512 \
107 | --batch_size 16 \
108 | --learning_rate 0.001 \
109 | --itr 1
110 |
--------------------------------------------------------------------------------
/scripts/model_efficiency/ECL/iFlashTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=Flashformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/electricity/ \
8 | --data_path electricity.csv \
9 | --model_id ECL_96_96 \
10 | --model $model_name \
11 | --data custom \
12 | --features M \
13 | --seq_len 96 \
14 | --label_len 48 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --d_layers 1 \
18 | --factor 3 \
19 | --enc_in 321 \
20 | --dec_in 321 \
21 | --c_out 321 \
22 | --des 'Exp' \
23 | --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/electricity/ \
28 | --data_path electricity.csv \
29 | --model_id ECL_96_192 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --label_len 48 \
35 | --pred_len 192 \
36 | --e_layers 2 \
37 | --d_layers 1 \
38 | --factor 3 \
39 | --enc_in 321 \
40 | --dec_in 321 \
41 | --c_out 321 \
42 | --des 'Exp' \
43 | --itr 1
44 |
45 | python -u run.py \
46 | --is_training 1 \
47 | --root_path ./dataset/electricity/ \
48 | --data_path electricity.csv \
49 | --model_id ECL_96_336 \
50 | --model $model_name \
51 | --data custom \
52 | --features M \
53 | --seq_len 96 \
54 | --label_len 48 \
55 | --pred_len 336 \
56 | --e_layers 2 \
57 | --d_layers 1 \
58 | --factor 3 \
59 | --enc_in 321 \
60 | --dec_in 321 \
61 | --c_out 321 \
62 | --des 'Exp' \
63 | --itr 1
64 |
65 | python -u run.py \
66 | --is_training 1 \
67 | --root_path ./dataset/electricity/ \
68 | --data_path electricity.csv \
69 | --model_id ECL_96_720 \
70 | --model $model_name \
71 | --data custom \
72 | --features M \
73 | --seq_len 96 \
74 | --label_len 48 \
75 | --pred_len 720 \
76 | --e_layers 2 \
77 | --d_layers 1 \
78 | --factor 3 \
79 | --enc_in 321 \
80 | --dec_in 321 \
81 | --c_out 321 \
82 | --des 'Exp' \
83 | --itr 1
84 |
85 | model_name=iFlashformer
86 |
87 | python -u run.py \
88 | --is_training 1 \
89 | --root_path ./dataset/electricity/ \
90 | --data_path electricity.csv \
91 | --model_id ECL_96_96 \
92 | --model $model_name \
93 | --data custom \
94 | --features M \
95 | --seq_len 96 \
96 | --label_len 48 \
97 | --pred_len 96 \
98 | --e_layers 2 \
99 | --d_layers 1 \
100 | --factor 3 \
101 | --enc_in 321 \
102 | --dec_in 321 \
103 | --c_out 321 \
104 | --des 'Exp' \
105 | --itr 1
106 |
107 | python -u run.py \
108 | --is_training 1 \
109 | --root_path ./dataset/electricity/ \
110 | --data_path electricity.csv \
111 | --model_id ECL_96_192 \
112 | --model $model_name \
113 | --data custom \
114 | --features M \
115 | --seq_len 96 \
116 | --label_len 48 \
117 | --pred_len 192 \
118 | --e_layers 2 \
119 | --d_layers 1 \
120 | --factor 3 \
121 | --enc_in 321 \
122 | --dec_in 321 \
123 | --c_out 321 \
124 | --des 'Exp' \
125 | --itr 1
126 |
127 | python -u run.py \
128 | --is_training 1 \
129 | --root_path ./dataset/electricity/ \
130 | --data_path electricity.csv \
131 | --model_id ECL_96_336 \
132 | --model $model_name \
133 | --data custom \
134 | --features M \
135 | --seq_len 96 \
136 | --label_len 48 \
137 | --pred_len 336 \
138 | --e_layers 2 \
139 | --d_layers 1 \
140 | --factor 3 \
141 | --enc_in 321 \
142 | --dec_in 321 \
143 | --c_out 321 \
144 | --des 'Exp' \
145 | --itr 1
146 |
147 | python -u run.py \
148 | --is_training 1 \
149 | --root_path ./dataset/electricity/ \
150 | --data_path electricity.csv \
151 | --model_id ECL_96_720 \
152 | --model $model_name \
153 | --data custom \
154 | --features M \
155 | --seq_len 96 \
156 | --label_len 48 \
157 | --pred_len 720 \
158 | --e_layers 2 \
159 | --d_layers 1 \
160 | --factor 3 \
161 | --enc_in 321 \
162 | --dec_in 321 \
163 | --c_out 321 \
164 | --des 'Exp' \
165 | --itr 1
166 |
--------------------------------------------------------------------------------
/scripts/model_efficiency/README.md:
--------------------------------------------------------------------------------
1 | # Efficiency Improvement of iTransformer
2 |
3 | Suppose the input multivariate time series has a shape of $T \times N$, where $T$ is the number of time steps and $N$ is the number of variates. The vanilla attention module has a complexity of $\mathcal{O}(L^2)$, where $L$ is the number of tokens.
4 |
5 | * In Transformer, we have $L=T$ because time points are treated as tokens.
6 | * In iTransformer, we have $L=N$ because variates are treated as tokens. For example, with a lookback of $T=96$ on the Weather dataset ($N=21$ variates), the token number drops from $96$ to $21$ after inversion.
7 |
8 | ## Benefit from Efficient Attention
9 |
10 | Since the attention mechanism is applied along the variate dimension in the inverted structure, efficient attention with reduced complexity directly addresses the problem of numerous variates, which is ubiquitous in real-world applications.
11 |
12 | We currently try out the linear-complexity attention from [Flowformer](https://github.com/thuml/Flowformer) and the hardware-accelerated attention from [FlashAttention](https://github.com/shreyansh26/FlashAttention-PyTorch). Adopting these attention mechanisms brings a clear efficiency improvement.
13 |
14 | ### Scripts
15 | We provide the iTransformers with the FlashAttention module:
16 |
17 | ```
18 | # iTransformer on the Traffic Dataset with hardware-friendly FlashAttention for acceleration
19 |
20 | bash ./scripts/model_efficiency/Traffic/iFlashTransformer.sh
21 | ```
22 |
23 |
24 | ## Efficient Training Strategy
25 | Because attention accepts a flexible number of input tokens, the token number can differ between training and inference; **our model is the first one capable of training on arbitrary numbers of series**. We propose a novel training strategy for high-dimensional multivariate series by taking advantage of the [variate generalization capability](../variate_generalization/README.md).
26 |
27 | Concretely, we randomly sample a subset of the variates in each batch and train the model only on the selected variates. Since the inverted structure keeps the number of variate channels flexible, the trained model can still predict all the variates at inference time (a minimal sketch of this strategy is given at the end of this README).
28 |
29 |
30 | ## Results
31 |
32 | **Environments**: The training batch size is fixed to 16 with comparable model hyperparameters. The experiments run on a P100 GPU (16GB). We comprehensively compare the training speed, memory footprint, and forecasting performance of the following models.
33 |
34 |
35 |
36 |
37 |
38 | * iTransformer is more efficient than the other Transformers on Weather (21 variates). On Traffic (862 variates), the memory footprints are basically the same, but iTransformer can be trained faster.
39 | * iTransformer achieves particularly strong performance on datasets with numerous variates, since the multivariate correlations can be explicitly utilized.
40 | * By adopting an efficient attention module or our proposed training strategy on partial variates, iTransformer can reach the same level of speed and memory footprint as linear forecasters.
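41 | 
42 | Below is a minimal sketch of the partial-variate training step described above. It is for illustration only, not the repository's exact implementation: ```model```, ```optimizer```, ```criterion```, the tensor layouts, and the sampling ```ratio``` are assumed placeholders, and the model is called with a simplified interface.
43 | 
44 | ```
45 | import torch
46 | 
47 | def train_step_partial_variates(model, batch_x, batch_y, optimizer, criterion, ratio=0.5):
48 |     # batch_x: [B, seq_len, N], batch_y: [B, pred_len, N] (assumed layouts)
49 |     n_vars = batch_x.shape[-1]
50 |     n_keep = max(1, int(n_vars * ratio))
51 |     # Randomly sample a subset of variate channels for this batch.
52 |     idx = torch.randperm(n_vars)[:n_keep]
53 |     x_part, y_part = batch_x[..., idx], batch_y[..., idx]
54 | 
55 |     optimizer.zero_grad()
56 |     # With variates as tokens, the inverted model accepts an arbitrary channel count,
57 |     # so training on a subset while predicting all variates at inference remains valid.
58 |     pred = model(x_part)
59 |     loss = criterion(pred, y_part)
60 |     loss.backward()
61 |     optimizer.step()
62 |     return loss.item()
63 | ```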
--------------------------------------------------------------------------------
/scripts/model_efficiency/Traffic/iFlashTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=Flashformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/traffic/ \
8 | --data_path traffic.csv \
9 | --model_id traffic_96_96 \
10 | --model $model_name \
11 | --data custom \
12 | --features M \
13 | --seq_len 96 \
14 | --label_len 48 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --d_layers 1 \
18 | --factor 3 \
19 | --enc_in 862 \
20 | --dec_in 862 \
21 | --c_out 862 \
22 | --des 'Exp' \
23 | --itr 1 \
24 | --train_epochs 3
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/traffic/ \
29 | --data_path traffic.csv \
30 | --model_id traffic_96_192 \
31 | --model $model_name \
32 | --data custom \
33 | --features M \
34 | --seq_len 96 \
35 | --label_len 48 \
36 | --pred_len 192 \
37 | --e_layers 2 \
38 | --d_layers 1 \
39 | --factor 3 \
40 | --enc_in 862 \
41 | --dec_in 862 \
42 | --c_out 862 \
43 | --des 'Exp' \
44 | --itr 1 \
45 | --train_epochs 3
46 |
47 | python -u run.py \
48 | --is_training 1 \
49 | --root_path ./dataset/traffic/ \
50 | --data_path traffic.csv \
51 | --model_id traffic_96_336 \
52 | --model $model_name \
53 | --data custom \
54 | --features M \
55 | --seq_len 96 \
56 | --label_len 48 \
57 | --pred_len 336 \
58 | --e_layers 2 \
59 | --d_layers 1 \
60 | --factor 3 \
61 | --enc_in 862 \
62 | --dec_in 862 \
63 | --c_out 862 \
64 | --des 'Exp' \
65 | --itr 1 \
66 | --train_epochs 3
67 |
68 | python -u run.py \
69 | --is_training 1 \
70 | --root_path ./dataset/traffic/ \
71 | --data_path traffic.csv \
72 | --model_id traffic_96_720 \
73 | --model $model_name \
74 | --data custom \
75 | --features M \
76 | --seq_len 96 \
77 | --label_len 48 \
78 | --pred_len 720 \
79 | --e_layers 2 \
80 | --d_layers 1 \
81 | --factor 3 \
82 | --enc_in 862 \
83 | --dec_in 862 \
84 | --c_out 862 \
85 | --des 'Exp' \
86 | --itr 1 \
87 | --train_epochs 3
88 |
89 | model_name=iFlashformer
90 |
91 | python -u run.py \
92 | --is_training 1 \
93 | --root_path ./dataset/traffic/ \
94 | --data_path traffic.csv \
95 | --model_id traffic_96_96 \
96 | --model $model_name \
97 | --data custom \
98 | --features M \
99 | --seq_len 96 \
100 | --label_len 48 \
101 | --pred_len 96 \
102 | --e_layers 2 \
103 | --d_layers 1 \
104 | --factor 3 \
105 | --enc_in 862 \
106 | --dec_in 862 \
107 | --c_out 862 \
108 | --des 'Exp' \
109 | --itr 1 \
110 | --train_epochs 3
111 |
112 | python -u run.py \
113 | --is_training 1 \
114 | --root_path ./dataset/traffic/ \
115 | --data_path traffic.csv \
116 | --model_id traffic_96_192 \
117 | --model $model_name \
118 | --data custom \
119 | --features M \
120 | --seq_len 96 \
121 | --label_len 48 \
122 | --pred_len 192 \
123 | --e_layers 2 \
124 | --d_layers 1 \
125 | --factor 3 \
126 | --enc_in 862 \
127 | --dec_in 862 \
128 | --c_out 862 \
129 | --des 'Exp' \
130 | --itr 1 \
131 | --train_epochs 3
132 |
133 | python -u run.py \
134 | --is_training 1 \
135 | --root_path ./dataset/traffic/ \
136 | --data_path traffic.csv \
137 | --model_id traffic_96_336 \
138 | --model $model_name \
139 | --data custom \
140 | --features M \
141 | --seq_len 96 \
142 | --label_len 48 \
143 | --pred_len 336 \
144 | --e_layers 2 \
145 | --d_layers 1 \
146 | --factor 3 \
147 | --enc_in 862 \
148 | --dec_in 862 \
149 | --c_out 862 \
150 | --des 'Exp' \
151 | --itr 1 \
152 | --train_epochs 3
153 |
154 | python -u run.py \
155 | --is_training 1 \
156 | --root_path ./dataset/traffic/ \
157 | --data_path traffic.csv \
158 | --model_id traffic_96_720 \
159 | --model $model_name \
160 | --data custom \
161 | --features M \
162 | --seq_len 96 \
163 | --label_len 48 \
164 | --pred_len 720 \
165 | --e_layers 2 \
166 | --d_layers 1 \
167 | --factor 3 \
168 | --enc_in 862 \
169 | --dec_in 862 \
170 | --c_out 862 \
171 | --des 'Exp' \
172 | --itr 1 \
173 | --train_epochs 3
174 |
--------------------------------------------------------------------------------
/scripts/model_efficiency/Weather/iFlashTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=Flashformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/weather/ \
8 | --data_path weather.csv \
9 | --model_id weather_96_96 \
10 | --model $model_name \
11 | --data custom \
12 | --features M \
13 | --seq_len 96 \
14 | --label_len 48 \
15 | --pred_len 96 \
16 | --e_layers 2 \
17 | --d_layers 1 \
18 | --factor 3 \
19 | --enc_in 21 \
20 | --dec_in 21 \
21 | --c_out 21 \
22 | --des 'Exp' \
23 | --itr 1 \
24 | --batch_size 128 \
25 | --train_epochs 3
26 |
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/weather/ \
30 | --data_path weather.csv \
31 | --model_id weather_96_192 \
32 | --model $model_name \
33 | --data custom \
34 | --features M \
35 | --seq_len 96 \
36 | --label_len 48 \
37 | --pred_len 192 \
38 | --e_layers 2 \
39 | --d_layers 1 \
40 | --factor 3 \
41 | --enc_in 21 \
42 | --dec_in 21 \
43 | --c_out 21 \
44 | --des 'Exp' \
45 | --batch_size 128 \
46 | --itr 1
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/weather/ \
51 | --data_path weather.csv \
52 | --model_id weather_96_336 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 96 \
57 | --label_len 48 \
58 | --pred_len 336 \
59 | --e_layers 2 \
60 | --d_layers 1 \
61 | --factor 3 \
62 | --enc_in 21 \
63 | --dec_in 21 \
64 | --c_out 21 \
65 | --des 'Exp' \
66 | --batch_size 128 \
67 | --itr 1
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/weather/ \
72 | --data_path weather.csv \
73 | --model_id weather_96_720 \
74 | --model $model_name \
75 | --data custom \
76 | --features M \
77 | --seq_len 96 \
78 | --label_len 48 \
79 | --pred_len 720 \
80 | --e_layers 2 \
81 | --d_layers 1 \
82 | --factor 3 \
83 | --enc_in 21 \
84 | --dec_in 21 \
85 | --c_out 21 \
86 | --des 'Exp' \
87 | --batch_size 128 \
88 | --itr 1
89 |
90 | model_name=iFlashformer
91 |
92 | python -u run.py \
93 | --is_training 1 \
94 | --root_path ./dataset/weather/ \
95 | --data_path weather.csv \
96 | --model_id weather_96_96 \
97 | --model $model_name \
98 | --data custom \
99 | --features M \
100 | --seq_len 96 \
101 | --label_len 48 \
102 | --pred_len 96 \
103 | --e_layers 2 \
104 | --d_layers 1 \
105 | --factor 3 \
106 | --enc_in 21 \
107 | --dec_in 21 \
108 | --c_out 21 \
109 | --des 'Exp' \
110 | --batch_size 128 \
111 | --itr 1 \
112 | --train_epochs 3
113 |
114 | python -u run.py \
115 | --is_training 1 \
116 | --root_path ./dataset/weather/ \
117 | --data_path weather.csv \
118 | --model_id weather_96_192 \
119 | --model $model_name \
120 | --data custom \
121 | --features M \
122 | --seq_len 96 \
123 | --label_len 48 \
124 | --pred_len 192 \
125 | --e_layers 2 \
126 | --d_layers 1 \
127 | --factor 3 \
128 | --enc_in 21 \
129 | --dec_in 21 \
130 | --c_out 21 \
131 | --des 'Exp' \
132 | --batch_size 128 \
133 | --itr 1
134 |
135 | python -u run.py \
136 | --is_training 1 \
137 | --root_path ./dataset/weather/ \
138 | --data_path weather.csv \
139 | --model_id weather_96_336 \
140 | --model $model_name \
141 | --data custom \
142 | --features M \
143 | --seq_len 96 \
144 | --label_len 48 \
145 | --pred_len 336 \
146 | --e_layers 2 \
147 | --d_layers 1 \
148 | --factor 3 \
149 | --enc_in 21 \
150 | --dec_in 21 \
151 | --c_out 21 \
152 | --des 'Exp' \
153 | --batch_size 128 \
154 | --itr 1
155 |
156 | python -u run.py \
157 | --is_training 1 \
158 | --root_path ./dataset/weather/ \
159 | --data_path weather.csv \
160 | --model_id weather_96_720 \
161 | --model $model_name \
162 | --data custom \
163 | --features M \
164 | --seq_len 96 \
165 | --label_len 48 \
166 | --pred_len 720 \
167 | --e_layers 2 \
168 | --d_layers 1 \
169 | --factor 3 \
170 | --enc_in 21 \
171 | --dec_in 21 \
172 | --c_out 21 \
173 | --des 'Exp' \
174 | --batch_size 128 \
175 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ECL/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/electricity/ \
8 | --data_path electricity.csv \
9 | --model_id ECL_96_96 \
10 | --model $model_name \
11 | --data custom \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 3 \
16 | --enc_in 321 \
17 | --dec_in 321 \
18 | --c_out 321 \
19 | --des 'Exp' \
20 | --d_model 512 \
21 | --d_ff 512 \
22 | --batch_size 16 \
23 | --learning_rate 0.0005 \
24 | --itr 1
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/electricity/ \
29 | --data_path electricity.csv \
30 | --model_id ECL_96_192 \
31 | --model $model_name \
32 | --data custom \
33 | --features M \
34 | --seq_len 96 \
35 | --pred_len 192 \
36 | --e_layers 3 \
37 | --enc_in 321 \
38 | --dec_in 321 \
39 | --c_out 321 \
40 | --des 'Exp' \
41 | --d_model 512 \
42 | --d_ff 512 \
43 | --batch_size 16 \
44 | --learning_rate 0.0005 \
45 | --itr 1
46 |
47 |
48 | python -u run.py \
49 | --is_training 1 \
50 | --root_path ./dataset/electricity/ \
51 | --data_path electricity.csv \
52 | --model_id ECL_96_336 \
53 | --model $model_name \
54 | --data custom \
55 | --features M \
56 | --seq_len 96 \
57 | --pred_len 336 \
58 | --e_layers 3 \
59 | --enc_in 321 \
60 | --dec_in 321 \
61 | --c_out 321 \
62 | --des 'Exp' \
63 | --d_model 512 \
64 | --d_ff 512 \
65 | --batch_size 16 \
66 | --learning_rate 0.0005 \
67 | --itr 1
68 |
69 |
70 | python -u run.py \
71 | --is_training 1 \
72 | --root_path ./dataset/electricity/ \
73 | --data_path electricity.csv \
74 | --model_id ECL_96_720 \
75 | --model $model_name \
76 | --data custom \
77 | --features M \
78 | --seq_len 96 \
79 | --pred_len 720 \
80 | --e_layers 3 \
81 | --enc_in 321 \
82 | --dec_in 321 \
83 | --c_out 321 \
84 | --des 'Exp' \
85 | --d_model 512 \
86 | --d_ff 512 \
87 | --batch_size 16 \
88 | --learning_rate 0.0005 \
89 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/iTransformer_ETTh1.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/ETT-small/ \
8 | --data_path ETTh1.csv \
9 | --model_id ETTh1_96_96 \
10 | --model $model_name \
11 | --data ETTh1 \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 2 \
16 | --enc_in 7 \
17 | --dec_in 7 \
18 | --c_out 7 \
19 | --des 'Exp' \
20 | --d_model 256 \
21 | --d_ff 256 \
22 | --itr 1
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/ETT-small/ \
27 | --data_path ETTh1.csv \
28 | --model_id ETTh1_96_192 \
29 | --model $model_name \
30 | --data ETTh1 \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 7 \
36 | --dec_in 7 \
37 | --c_out 7 \
38 | --des 'Exp' \
39 | --d_model 256 \
40 | --d_ff 256 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/ETT-small/ \
46 | --data_path ETTh1.csv \
47 | --model_id ETTh1_96_336 \
48 | --model $model_name \
49 | --data ETTh1 \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des 'Exp' \
58 | --d_model 512 \
59 | --d_ff 512 \
60 | --itr 1
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --root_path ./dataset/ETT-small/ \
65 | --data_path ETTh1.csv \
66 | --model_id ETTh1_96_720 \
67 | --model $model_name \
68 | --data ETTh1 \
69 | --features M \
70 | --seq_len 96 \
71 | --pred_len 720 \
72 | --e_layers 2 \
73 | --enc_in 7 \
74 | --dec_in 7 \
75 | --c_out 7 \
76 | --des 'Exp' \
77 | --d_model 512 \
78 | --d_ff 512 \
79 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/iTransformer_ETTh2.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/ETT-small/ \
8 | --data_path ETTh2.csv \
9 | --model_id ETTh2_96_96 \
10 | --model $model_name \
11 | --data ETTh2 \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 2 \
16 | --enc_in 7 \
17 | --dec_in 7 \
18 | --c_out 7 \
19 | --des 'Exp' \
20 | --d_model 128 \
21 | --d_ff 128 \
22 | --itr 1
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/ETT-small/ \
27 | --data_path ETTh2.csv \
28 | --model_id ETTh2_96_192 \
29 | --model $model_name \
30 | --data ETTh2 \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 7 \
36 | --dec_in 7 \
37 | --c_out 7 \
38 | --des 'Exp' \
39 | --d_model 128 \
40 | --d_ff 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/ETT-small/ \
46 | --data_path ETTh2.csv \
47 | --model_id ETTh2_96_336 \
48 | --model $model_name \
49 | --data ETTh2 \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des 'Exp' \
58 | --d_model 128 \
59 | --d_ff 128 \
60 | --itr 1
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --root_path ./dataset/ETT-small/ \
65 | --data_path ETTh2.csv \
66 | --model_id ETTh2_96_720 \
67 | --model $model_name \
68 | --data ETTh2 \
69 | --features M \
70 | --seq_len 96 \
71 | --pred_len 720 \
72 | --e_layers 2 \
73 | --enc_in 7 \
74 | --dec_in 7 \
75 | --c_out 7 \
76 | --des 'Exp' \
77 | --d_model 128 \
78 | --d_ff 128 \
79 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/iTransformer_ETTm1.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/ETT-small/ \
8 | --data_path ETTm1.csv \
9 | --model_id ETTm1_96_96 \
10 | --model $model_name \
11 | --data ETTm1 \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 2 \
16 | --enc_in 7 \
17 | --dec_in 7 \
18 | --c_out 7 \
19 | --des 'Exp' \
20 | --d_model 128 \
21 | --d_ff 128 \
22 | --itr 1
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/ETT-small/ \
27 | --data_path ETTm1.csv \
28 | --model_id ETTm1_96_192 \
29 | --model $model_name \
30 | --data ETTm1 \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 7 \
36 | --dec_in 7 \
37 | --c_out 7 \
38 | --des 'Exp' \
39 | --d_model 128 \
40 | --d_ff 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/ETT-small/ \
46 | --data_path ETTm1.csv \
47 | --model_id ETTm1_96_336 \
48 | --model $model_name \
49 | --data ETTm1 \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des 'Exp' \
58 | --d_model 128 \
59 | --d_ff 128 \
60 | --itr 1
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --root_path ./dataset/ETT-small/ \
65 | --data_path ETTm1.csv \
66 | --model_id ETTm1_96_720 \
67 | --model $model_name \
68 | --data ETTm1 \
69 | --features M \
70 | --seq_len 96 \
71 | --pred_len 720 \
72 | --e_layers 2 \
73 | --enc_in 7 \
74 | --dec_in 7 \
75 | --c_out 7 \
76 | --des 'Exp' \
77 | --d_model 128 \
78 | --d_ff 128 \
79 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/ETT/iTransformer_ETTm2.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/ETT-small/ \
8 | --data_path ETTm2.csv \
9 | --model_id ETTm2_96_96 \
10 | --model $model_name \
11 | --data ETTm2 \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 2 \
16 | --enc_in 7 \
17 | --dec_in 7 \
18 | --c_out 7 \
19 | --des 'Exp' \
20 | --d_model 128 \
21 | --d_ff 128 \
22 | --itr 1
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/ETT-small/ \
27 | --data_path ETTm2.csv \
28 | --model_id ETTm2_96_192 \
29 | --model $model_name \
30 | --data ETTm2 \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 7 \
36 | --dec_in 7 \
37 | --c_out 7 \
38 | --des 'Exp' \
39 | --d_model 128 \
40 | --d_ff 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/ETT-small/ \
46 | --data_path ETTm2.csv \
47 | --model_id ETTm2_96_336 \
48 | --model $model_name \
49 | --data ETTm2 \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 7 \
55 | --dec_in 7 \
56 | --c_out 7 \
57 | --des 'Exp' \
58 | --d_model 128 \
59 | --d_ff 128 \
60 | --itr 1
61 |
62 | python -u run.py \
63 | --is_training 1 \
64 | --root_path ./dataset/ETT-small/ \
65 | --data_path ETTm2.csv \
66 | --model_id ETTm2_96_720 \
67 | --model $model_name \
68 | --data ETTm2 \
69 | --features M \
70 | --seq_len 96 \
71 | --pred_len 720 \
72 | --e_layers 2 \
73 | --enc_in 7 \
74 | --dec_in 7 \
75 | --c_out 7 \
76 | --des 'Exp' \
77 | --d_model 128 \
78 | --d_ff 128 \
79 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/Exchange/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/exchange_rate/ \
8 | --data_path exchange_rate.csv \
9 | --model_id Exchange_96_96 \
10 | --model $model_name \
11 | --data custom \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 2 \
16 | --enc_in 8 \
17 | --dec_in 8 \
18 | --c_out 8 \
19 | --des 'Exp' \
20 | --d_model 128 \
21 | --d_ff 128 \
22 | --itr 1
23 |
24 | python -u run.py \
25 | --is_training 1 \
26 | --root_path ./dataset/exchange_rate/ \
27 | --data_path exchange_rate.csv \
28 | --model_id Exchange_96_192 \
29 | --model $model_name \
30 | --data custom \
31 | --features M \
32 | --seq_len 96 \
33 | --pred_len 192 \
34 | --e_layers 2 \
35 | --enc_in 8 \
36 | --dec_in 8 \
37 | --c_out 8 \
38 | --des 'Exp' \
39 | --d_model 128 \
40 | --d_ff 128 \
41 | --itr 1
42 |
43 | python -u run.py \
44 | --is_training 1 \
45 | --root_path ./dataset/exchange_rate/ \
46 | --data_path exchange_rate.csv \
47 | --model_id Exchange_96_336 \
48 | --model $model_name \
49 | --data custom \
50 | --features M \
51 | --seq_len 96 \
52 | --pred_len 336 \
53 | --e_layers 2 \
54 | --enc_in 8 \
55 | --dec_in 8 \
56 | --c_out 8 \
57 | --des 'Exp' \
58 | --itr 1 \
59 | --d_model 128 \
60 | --d_ff 128 \
61 | --train_epochs 1
62 |
63 | python -u run.py \
64 | --is_training 1 \
65 | --root_path ./dataset/exchange_rate/ \
66 | --data_path exchange_rate.csv \
67 | --model_id Exchange_96_720 \
68 | --model $model_name \
69 | --data custom \
70 | --features M \
71 | --seq_len 96 \
72 | --pred_len 720 \
73 | --e_layers 2 \
74 | --enc_in 8 \
75 | --dec_in 8 \
76 | --c_out 8 \
77 | --des 'Exp' \
78 | --d_model 128 \
79 | --d_ff 128 \
80 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/iTransformer_03.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/PEMS/ \
8 | --data_path PEMS03.npz \
9 | --model_id PEMS03_96_12 \
10 | --model $model_name \
11 | --data PEMS \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 12 \
15 | --e_layers 4 \
16 | --enc_in 358 \
17 | --dec_in 358 \
18 | --c_out 358 \
19 | --des 'Exp' \
20 | --d_model 512 \
21 | --d_ff 512 \
22 | --learning_rate 0.001 \
23 | --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/PEMS/ \
28 | --data_path PEMS03.npz \
29 | --model_id PEMS03_96_24 \
30 | --model $model_name \
31 | --data PEMS \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 24 \
35 | --e_layers 4 \
36 | --enc_in 358 \
37 | --dec_in 358 \
38 | --c_out 358 \
39 | --des 'Exp' \
40 | --d_model 512 \
41 | --d_ff 512 \
42 | --learning_rate 0.001 \
43 | --itr 1
44 |
45 |
46 | python -u run.py \
47 | --is_training 1 \
48 | --root_path ./dataset/PEMS/ \
49 | --data_path PEMS03.npz \
50 | --model_id PEMS03_96_48 \
51 | --model $model_name \
52 | --data PEMS \
53 | --features M \
54 | --seq_len 96 \
55 | --pred_len 48 \
56 | --e_layers 4 \
57 | --enc_in 358 \
58 | --dec_in 358 \
59 | --c_out 358 \
60 | --des 'Exp' \
61 | --d_model 512 \
62 | --d_ff 512 \
63 | --learning_rate 0.001 \
64 | --itr 1
65 |
66 |
67 | python -u run.py \
68 | --is_training 1 \
69 | --root_path ./dataset/PEMS/ \
70 | --data_path PEMS03.npz \
71 | --model_id PEMS03_96_96 \
72 | --model $model_name \
73 | --data PEMS \
74 | --features M \
75 | --seq_len 96 \
76 | --pred_len 96 \
77 | --e_layers 4 \
78 | --enc_in 358 \
79 | --dec_in 358 \
80 | --c_out 358 \
81 | --des 'Exp' \
82 | --d_model 512 \
83 | --d_ff 512 \
84 | --learning_rate 0.001 \
85 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/iTransformer_04.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/PEMS/ \
8 | --data_path PEMS04.npz \
9 | --model_id PEMS04_96_12 \
10 | --model $model_name \
11 | --data PEMS \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 12 \
15 | --e_layers 4 \
16 | --enc_in 307 \
17 | --dec_in 307 \
18 | --c_out 307 \
19 | --des 'Exp' \
20 | --d_model 1024 \
21 | --d_ff 1024 \
22 | --learning_rate 0.0005 \
23 | --itr 1 \
24 | --use_norm 0
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/PEMS/ \
29 | --data_path PEMS04.npz \
30 | --model_id PEMS04_96_24 \
31 | --model $model_name \
32 | --data PEMS \
33 | --features M \
34 | --seq_len 96 \
35 | --pred_len 24 \
36 | --e_layers 4 \
37 | --enc_in 307 \
38 | --dec_in 307 \
39 | --c_out 307 \
40 | --des 'Exp' \
41 | --d_model 1024 \
42 | --d_ff 1024 \
43 | --learning_rate 0.0005 \
44 | --itr 1 \
45 | --use_norm 0
46 |
47 | python -u run.py \
48 | --is_training 1 \
49 | --root_path ./dataset/PEMS/ \
50 | --data_path PEMS04.npz \
51 | --model_id PEMS04_96_48 \
52 | --model $model_name \
53 | --data PEMS \
54 | --features M \
55 | --seq_len 96 \
56 | --pred_len 48 \
57 | --e_layers 4 \
58 | --enc_in 307 \
59 | --dec_in 307 \
60 | --c_out 307 \
61 | --des 'Exp' \
62 | --d_model 1024 \
63 | --d_ff 1024 \
64 | --learning_rate 0.0005 \
65 | --itr 1 \
66 | --use_norm 0
67 |
68 | python -u run.py \
69 | --is_training 1 \
70 | --root_path ./dataset/PEMS/ \
71 | --data_path PEMS04.npz \
72 | --model_id PEMS04_96_96 \
73 | --model $model_name \
74 | --data PEMS \
75 | --features M \
76 | --seq_len 96 \
77 | --pred_len 96 \
78 | --e_layers 4 \
79 | --enc_in 307 \
80 | --dec_in 307 \
81 | --c_out 307 \
82 | --des 'Exp' \
83 | --d_model 1024 \
84 | --d_ff 1024 \
85 | --learning_rate 0.0005 \
86 | --itr 1 \
87 | --use_norm 0
88 |
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/iTransformer_07.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/PEMS/ \
8 | --data_path PEMS07.npz \
9 | --model_id PEMS07_96_12 \
10 | --model $model_name \
11 | --data PEMS \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 12 \
15 | --e_layers 2 \
16 | --enc_in 883 \
17 | --dec_in 883 \
18 | --c_out 883 \
19 | --des 'Exp' \
20 | --d_model 512 \
21 | --d_ff 512 \
22 | --learning_rate 0.001 \
23 | --itr 1 \
24 | --use_norm 0
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/PEMS/ \
29 | --data_path PEMS07.npz \
30 | --model_id PEMS07_96_24 \
31 | --model $model_name \
32 | --data PEMS \
33 | --features M \
34 | --seq_len 96 \
35 | --pred_len 24 \
36 | --e_layers 2 \
37 | --enc_in 883 \
38 | --dec_in 883 \
39 | --c_out 883 \
40 | --des 'Exp' \
41 | --d_model 512 \
42 | --d_ff 512 \
43 | --learning_rate 0.001 \
44 | --itr 1 \
45 | --use_norm 0
46 |
47 | python -u run.py \
48 | --is_training 1 \
49 | --root_path ./dataset/PEMS/ \
50 | --data_path PEMS07.npz \
51 | --model_id PEMS07_96_48 \
52 | --model $model_name \
53 | --data PEMS \
54 | --features M \
55 | --seq_len 96 \
56 | --pred_len 48 \
57 | --e_layers 4 \
58 | --enc_in 883 \
59 | --dec_in 883 \
60 | --c_out 883 \
61 | --des 'Exp' \
62 | --d_model 512 \
63 | --d_ff 512 \
64 |   --batch_size 16 \
65 | --learning_rate 0.001 \
66 | --itr 1 \
67 | --use_norm 0
68 |
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path ./dataset/PEMS/ \
72 | --data_path PEMS07.npz \
73 | --model_id PEMS07_96_96 \
74 | --model $model_name \
75 | --data PEMS \
76 | --features M \
77 | --seq_len 96 \
78 | --pred_len 96 \
79 | --e_layers 4 \
80 | --enc_in 883 \
81 | --dec_in 883 \
82 | --c_out 883 \
83 | --des 'Exp' \
84 | --d_model 512 \
85 | --d_ff 512 \
86 |   --batch_size 16 \
87 | --learning_rate 0.001 \
88 | --itr 1 \
89 | --use_norm 0
90 |
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/PEMS/iTransformer_08.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/PEMS/ \
8 | --data_path PEMS08.npz \
9 | --model_id PEMS08_96_12 \
10 | --model $model_name \
11 | --data PEMS \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 12 \
15 | --e_layers 2 \
16 | --enc_in 170 \
17 | --dec_in 170 \
18 | --c_out 170 \
19 | --des 'Exp' \
20 | --d_model 512 \
21 | --d_ff 512 \
22 | --itr 1 \
23 | --use_norm 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/PEMS/ \
28 | --data_path PEMS08.npz \
29 | --model_id PEMS08_96_24 \
30 | --model $model_name \
31 | --data PEMS \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 24 \
35 | --e_layers 2 \
36 | --enc_in 170 \
37 | --dec_in 170 \
38 | --c_out 170 \
39 | --des 'Exp' \
40 | --d_model 512 \
41 | --d_ff 512 \
42 | --itr 1 \
43 | --use_norm 1
44 |
45 | python -u run.py \
46 | --is_training 1 \
47 | --root_path ./dataset/PEMS/ \
48 | --data_path PEMS08.npz \
49 | --model_id PEMS08_96_48 \
50 | --model $model_name \
51 | --data PEMS \
52 | --features M \
53 | --seq_len 96 \
54 | --pred_len 48 \
55 | --e_layers 4 \
56 | --enc_in 170 \
57 | --dec_in 170 \
58 | --c_out 170 \
59 | --des 'Exp' \
60 | --d_model 512 \
61 | --d_ff 512 \
62 |   --batch_size 16 \
63 | --learning_rate 0.001 \
64 | --itr 1 \
65 | --use_norm 0
66 |
67 | python -u run.py \
68 | --is_training 1 \
69 | --root_path ./dataset/PEMS/ \
70 | --data_path PEMS08.npz \
71 | --model_id PEMS08_96_96 \
72 | --model $model_name \
73 | --data PEMS \
74 | --features M \
75 | --seq_len 96 \
76 | --pred_len 96 \
77 | --e_layers 4 \
78 | --enc_in 170 \
79 | --dec_in 170 \
80 | --c_out 170 \
81 | --des 'Exp' \
82 | --d_model 512 \
83 | --d_ff 512 \
84 |   --batch_size 16 \
85 | --learning_rate 0.001 \
86 | --itr 1 \
87 | --use_norm 0
88 |
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/README.md:
--------------------------------------------------------------------------------
1 | # iTransformer for Multivariate Time Series Forecasting
2 |
3 | This folder contains the reproductions of the iTransformers for Multivariate Time Series Forecasting (MTSF).
4 |
5 | ## Dataset
6 |
7 | We evaluate on an extensive set of challenging multivariate forecasting benchmarks. The datasets can be downloaded from [Google Drive](https://drive.google.com/file/d/1l51QsKvQPcqILT3DwfjCgx8Dsg2rpjot/view?usp=drive_link) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/2ea5ca3d621e4e5ba36a/).
8 |
9 |
10 |
11 |
12 |
13 | ## Scripts
14 |
15 | In each folder named after the dataset, we provide the iTransformer experiments under four different prediction lengths as shown in the table above.
16 |
17 | ```
18 | # iTransformer on the Traffic Dataset
19 |
20 | bash ./scripts/multivariate_forecasting/Traffic/iTransformer.sh
21 | ```
22 |
23 | To evaluate the model under other input/prediction lengths, feel free to change the ```seq_len``` and ```pred_len``` arguments:
24 |
25 | ```
26 | # iTransformer on the Electricity Dataset, where 180 time steps are used as observations and the task is to predict the future 60 steps
27 |
28 | python -u run.py \
29 | --is_training 1 \
30 | --root_path ./dataset/electricity/ \
31 | --data_path electricity.csv \
32 | --model_id ECL_180_60 \
33 | --model $model_name \
34 | --data custom \
35 | --features M \
36 | --seq_len 180 \
37 | --pred_len 60 \
38 | --e_layers 3 \
39 | --enc_in 321 \
40 | --dec_in 321 \
41 | --c_out 321 \
42 | --des 'Exp' \
43 | --d_model 512 \
44 | --d_ff 512 \
45 | --batch_size 16 \
46 | --learning_rate 0.0005 \
47 | --itr 1
48 | ```
49 |
50 |
51 | ## Training on Custom Dataset
52 |
53 | To train with your own time series dataset, you can try out the following steps:
54 |
55 | 1. Read through the ```Dataset_Custom``` class in ```data_provider/data_loader.py```, which provides the functionality to load and process time series files.
56 | 2. The file should be in ```csv``` format, with the first column containing the timestamp and the following columns containing the variates of the time series.
57 | 3. In the training script, set ```--data custom``` and modify the ```enc_in```, ```dec_in```, and ```c_out``` arguments to match your number of variates (a sketch for preparing such a file is given below).
58 |
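59 | As a concrete illustration of steps 2 and 3, the sketch below builds a toy ```csv``` file in the expected layout (a leading ```date``` timestamp column followed by one column per variate) and prints the matching channel arguments. The file name, directory, variate names, and sizes are hypothetical.
60 | 
61 | ```
62 | import os
63 | 
64 | import numpy as np
65 | import pandas as pd
66 | 
67 | os.makedirs("./dataset/my_data", exist_ok=True)
68 | 
69 | # Toy dataset: hourly timestamps and 4 variates with random values.
70 | n_steps, n_vars = 1000, 4
71 | df = pd.DataFrame(
72 |     np.random.randn(n_steps, n_vars),
73 |     columns=[f"sensor_{i}" for i in range(n_vars)],
74 | )
75 | df.insert(0, "date", pd.date_range("2020-01-01", periods=n_steps, freq="h"))
76 | df.to_csv("./dataset/my_data/my_data.csv", index=False)
77 | 
78 | # A matching training command would then use (hypothetical paths):
79 | #   --data custom --root_path ./dataset/my_data/ --data_path my_data.csv
80 | #   --enc_in 4 --dec_in 4 --c_out 4
81 | print(f"enc_in = dec_in = c_out = {n_vars}")
82 | ```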
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/SolarEnergy/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/Solar/ \
8 | --data_path solar_AL.txt \
9 | --model_id solar_96_96 \
10 | --model $model_name \
11 | --data Solar \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 2 \
16 | --enc_in 137 \
17 | --dec_in 137 \
18 | --c_out 137 \
19 | --des 'Exp' \
20 | --d_model 512 \
21 | --d_ff 512 \
22 | --learning_rate 0.0005 \
23 | --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/Solar/ \
28 | --data_path solar_AL.txt \
29 | --model_id solar_96_192 \
30 | --model $model_name \
31 | --data Solar \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 192 \
35 | --e_layers 2 \
36 | --enc_in 137 \
37 | --dec_in 137 \
38 | --c_out 137 \
39 | --des 'Exp' \
40 | --d_model 512 \
41 | --d_ff 512 \
42 | --learning_rate 0.0005 \
43 | --itr 1
44 |
45 | python -u run.py \
46 | --is_training 1 \
47 | --root_path ./dataset/Solar/ \
48 | --data_path solar_AL.txt \
49 | --model_id solar_96_336 \
50 | --model $model_name \
51 | --data Solar \
52 | --features M \
53 | --seq_len 96 \
54 | --pred_len 336 \
55 | --e_layers 2 \
56 | --enc_in 137 \
57 | --dec_in 137 \
58 | --c_out 137 \
59 | --des 'Exp' \
60 | --d_model 512 \
61 | --d_ff 512 \
62 | --learning_rate 0.0005 \
63 | --itr 1
64 |
65 | python -u run.py \
66 | --is_training 1 \
67 | --root_path ./dataset/Solar/ \
68 | --data_path solar_AL.txt \
69 | --model_id solar_96_720 \
70 | --model $model_name \
71 | --data Solar \
72 | --features M \
73 | --seq_len 96 \
74 | --pred_len 720 \
75 | --e_layers 2 \
76 | --enc_in 137 \
77 | --dec_in 137 \
78 | --c_out 137 \
79 | --des 'Exp' \
80 | --d_model 512 \
81 | --d_ff 512 \
82 | --learning_rate 0.0005 \
83 | --itr 1
84 |
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/Traffic/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/traffic/ \
8 | --data_path traffic.csv \
9 | --model_id traffic_96_96 \
10 | --model $model_name \
11 | --data custom \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 4 \
16 | --enc_in 862 \
17 | --dec_in 862 \
18 | --c_out 862 \
19 | --des 'Exp' \
20 |   --d_model 512 \
21 | --d_ff 512 \
22 | --batch_size 16 \
23 | --learning_rate 0.001 \
24 | --itr 1
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/traffic/ \
29 | --data_path traffic.csv \
30 | --model_id traffic_96_192 \
31 | --model $model_name \
32 | --data custom \
33 | --features M \
34 | --seq_len 96 \
35 | --pred_len 192 \
36 | --e_layers 4 \
37 | --enc_in 862 \
38 | --dec_in 862 \
39 | --c_out 862 \
40 | --des 'Exp' \
41 | --d_model 512 \
42 | --d_ff 512 \
43 | --batch_size 16 \
44 | --learning_rate 0.001 \
45 | --itr 1
46 |
47 | python -u run.py \
48 | --is_training 1 \
49 | --root_path ./dataset/traffic/ \
50 | --data_path traffic.csv \
51 | --model_id traffic_96_336 \
52 | --model $model_name \
53 | --data custom \
54 | --features M \
55 | --seq_len 96 \
56 | --pred_len 336 \
57 | --e_layers 4 \
58 | --enc_in 862 \
59 | --dec_in 862 \
60 | --c_out 862 \
61 | --des 'Exp' \
62 |   --d_model 512 \
63 | --d_ff 512 \
64 | --batch_size 16 \
65 | --learning_rate 0.001 \
66 | --itr 1
67 |
68 | python -u run.py \
69 | --is_training 1 \
70 | --root_path ./dataset/traffic/ \
71 | --data_path traffic.csv \
72 | --model_id traffic_96_720 \
73 | --model $model_name \
74 | --data custom \
75 | --features M \
76 | --seq_len 96 \
77 | --pred_len 720 \
78 | --e_layers 4 \
79 | --enc_in 862 \
80 | --dec_in 862 \
81 | --c_out 862 \
82 | --des 'Exp' \
83 | --d_model 512 \
84 | --d_ff 512 \
85 | --batch_size 16 \
86 |   --learning_rate 0.001 \
87 | --itr 1
--------------------------------------------------------------------------------
/scripts/multivariate_forecasting/Weather/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=iTransformer
4 |
5 | python -u run.py \
6 | --is_training 1 \
7 | --root_path ./dataset/weather/ \
8 | --data_path weather.csv \
9 | --model_id weather_96_96 \
10 | --model $model_name \
11 | --data custom \
12 | --features M \
13 | --seq_len 96 \
14 | --pred_len 96 \
15 | --e_layers 3 \
16 | --enc_in 21 \
17 | --dec_in 21 \
18 | --c_out 21 \
19 | --des 'Exp' \
20 |   --d_model 512 \
21 |   --d_ff 512 \
22 | --itr 1
23 |
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/weather/ \
28 | --data_path weather.csv \
29 | --model_id weather_96_192 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --pred_len 192 \
35 | --e_layers 3 \
36 | --enc_in 21 \
37 | --dec_in 21 \
38 | --c_out 21 \
39 | --des 'Exp' \
40 |   --d_model 512 \
41 |   --d_ff 512 \
42 | --itr 1
43 |
44 |
45 | python -u run.py \
46 | --is_training 1 \
47 | --root_path ./dataset/weather/ \
48 | --data_path weather.csv \
49 | --model_id weather_96_336 \
50 | --model $model_name \
51 | --data custom \
52 | --features M \
53 | --seq_len 96 \
54 | --pred_len 336 \
55 | --e_layers 3 \
56 | --enc_in 21 \
57 | --dec_in 21 \
58 | --c_out 21 \
59 | --des 'Exp' \
60 |   --d_model 512 \
61 |   --d_ff 512 \
62 | --itr 1
63 |
64 |
65 | python -u run.py \
66 | --is_training 1 \
67 | --root_path ./dataset/weather/ \
68 | --data_path weather.csv \
69 | --model_id weather_96_720 \
70 | --model $model_name \
71 | --data custom \
72 | --features M \
73 | --seq_len 96 \
74 | --pred_len 720 \
75 | --e_layers 3 \
76 | --enc_in 21 \
77 | --dec_in 21 \
78 | --c_out 21 \
79 | --des 'Exp' \
80 |   --d_model 512 \
81 |   --d_ff 512 \
82 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/ECL/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=Flowformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/electricity/ \
8 | # --data_path electricity.csv \
9 | # --model_id ECL_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 321 \
20 | # --dec_in 321 \
21 | # --c_out 321 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/electricity/ \
28 | --data_path electricity.csv \
29 | --model_id ECL_96_96 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --label_len 48 \
35 | --pred_len 96 \
36 | --e_layers 2 \
37 | --d_layers 1 \
38 | --factor 3 \
39 | --enc_in 64 \
40 | --dec_in 64 \
41 | --c_out 64 \
42 | --des 'Exp' \
43 | --channel_independence true \
44 | --exp_name partial_train \
45 | --batch_size 8 \
46 | --d_model 32 \
47 | --d_ff 64 \
48 | --itr 1
49 |
50 | model_name=iFlowformer
51 |
52 | #python -u run.py \
53 | ## --is_training 1 \
54 | # --root_path ./dataset/electricity/ \
55 | # --data_path electricity.csv \
56 | # --model_id ECL_96_96 \
57 | # --model $model_name \
58 | # --data custom \
59 | # --features M \
60 | # --seq_len 96 \
61 | # --label_len 48 \
62 | # --pred_len 96 \
63 | # --e_layers 2 \
64 | # --d_layers 1 \
65 | # --factor 3 \
66 | # --enc_in 321 \
67 | # --dec_in 321 \
68 | # --c_out 321 \
69 | # --des 'Exp' \
70 | # --itr 1
71 |
72 | python -u run.py \
73 | --is_training 1 \
74 | --root_path ./dataset/electricity/ \
75 | --data_path electricity.csv \
76 | --model_id ECL_96_96 \
77 | --model $model_name \
78 | --data custom \
79 | --features M \
80 | --seq_len 96 \
81 | --label_len 48 \
82 | --pred_len 96 \
83 | --e_layers 2 \
84 | --d_layers 1 \
85 | --factor 3 \
86 | --enc_in 64 \
87 | --dec_in 64 \
88 | --c_out 64 \
89 | --des 'Exp' \
90 | --exp_name partial_train \
91 | --itr 1
92 |
--------------------------------------------------------------------------------
/scripts/variate_generalization/ECL/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=Informer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/electricity/ \
8 | # --data_path electricity.csv \
9 | # --model_id ECL_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 321 \
20 | # --dec_in 321 \
21 | # --c_out 321 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/electricity/ \
28 | --data_path electricity.csv \
29 | --model_id ECL_96_96 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --label_len 48 \
35 | --pred_len 96 \
36 | --e_layers 2 \
37 | --d_layers 1 \
38 | --factor 3 \
39 | --enc_in 64 \
40 | --dec_in 64 \
41 | --c_out 64 \
42 | --des 'Exp' \
43 | --channel_independence true \
44 | --exp_name partial_train \
45 | --batch_size 8 \
46 | --d_model 32 \
47 | --d_ff 64 \
48 | --itr 1
49 |
50 | model_name=iInformer
51 |
52 | #python -u run.py \
53 | ## --is_training 1 \
54 | # --root_path ./dataset/electricity/ \
55 | # --data_path electricity.csv \
56 | # --model_id ECL_96_96 \
57 | # --model $model_name \
58 | # --data custom \
59 | # --features M \
60 | # --seq_len 96 \
61 | # --label_len 48 \
62 | # --pred_len 96 \
63 | # --e_layers 2 \
64 | # --d_layers 1 \
65 | # --factor 3 \
66 | # --enc_in 321 \
67 | # --dec_in 321 \
68 | # --c_out 321 \
69 | # --des 'Exp' \
70 | # --itr 1
71 |
72 | python -u run.py \
73 | --is_training 1 \
74 | --root_path ./dataset/electricity/ \
75 | --data_path electricity.csv \
76 | --model_id ECL_96_96 \
77 | --model $model_name \
78 | --data custom \
79 | --features M \
80 | --seq_len 96 \
81 | --label_len 48 \
82 | --pred_len 96 \
83 | --e_layers 2 \
84 | --d_layers 1 \
85 | --factor 3 \
86 | --enc_in 64 \
87 | --dec_in 64 \
88 | --c_out 64 \
89 | --des 'Exp' \
90 | --exp_name partial_train \
91 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/ECL/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=Reformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/electricity/ \
8 | # --data_path electricity.csv \
9 | # --model_id ECL_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 321 \
20 | # --dec_in 321 \
21 | # --c_out 321 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/electricity/ \
28 | --data_path electricity.csv \
29 | --model_id ECL_96_96 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --label_len 48 \
35 | --pred_len 96 \
36 | --e_layers 2 \
37 | --d_layers 1 \
38 | --factor 3 \
39 | --enc_in 64 \
40 | --dec_in 64 \
41 | --c_out 64 \
42 | --des 'Exp' \
43 | --channel_independence true \
44 | --exp_name partial_train \
45 | --batch_size 8 \
46 | --d_model 32 \
47 | --d_ff 64 \
48 | --itr 1
49 |
50 | model_name=iReformer
51 |
52 | #python -u run.py \
53 | ## --is_training 1 \
54 | # --root_path ./dataset/electricity/ \
55 | # --data_path electricity.csv \
56 | # --model_id ECL_96_96 \
57 | # --model $model_name \
58 | # --data custom \
59 | # --features M \
60 | # --seq_len 96 \
61 | # --label_len 48 \
62 | # --pred_len 96 \
63 | # --e_layers 2 \
64 | # --d_layers 1 \
65 | # --factor 3 \
66 | # --enc_in 321 \
67 | # --dec_in 321 \
68 | # --c_out 321 \
69 | # --des 'Exp' \
70 | # --itr 1
71 |
72 | python -u run.py \
73 | --is_training 1 \
74 | --root_path ./dataset/electricity/ \
75 | --data_path electricity.csv \
76 | --model_id ECL_96_96 \
77 | --model $model_name \
78 | --data custom \
79 | --features M \
80 | --seq_len 96 \
81 | --label_len 48 \
82 | --pred_len 96 \
83 | --e_layers 2 \
84 | --d_layers 1 \
85 | --factor 3 \
86 | --enc_in 64 \
87 | --dec_in 64 \
88 | --c_out 64 \
89 | --des 'Exp' \
90 | --exp_name partial_train \
91 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/ECL/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 |
3 | model_name=Transformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/electricity/ \
8 | # --data_path electricity.csv \
9 | # --model_id ECL_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 321 \
20 | # --dec_in 321 \
21 | # --c_out 321 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | # 20% partial variates, enc_in: 64 = 321 // 5
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/electricity/ \
29 | --data_path electricity.csv \
30 | --model_id ECL_96_96 \
31 | --model $model_name \
32 | --data custom \
33 | --features M \
34 | --seq_len 96 \
35 | --label_len 48 \
36 | --pred_len 96 \
37 | --e_layers 2 \
38 | --d_layers 1 \
39 | --factor 3 \
40 | --enc_in 64 \
41 | --dec_in 64 \
42 | --c_out 64 \
43 | --des 'Exp' \
44 | --channel_independence true \
45 | --exp_name partial_train \
46 | --batch_size 8 \
47 | --d_model 32 \
48 | --d_ff 64 \
49 | --itr 1
50 |
51 | model_name=iTransformer
52 |
53 | #python -u run.py \
54 | ## --is_training 1 \
55 | # --root_path ./dataset/electricity/ \
56 | # --data_path electricity.csv \
57 | # --model_id ECL_96_96 \
58 | # --model $model_name \
59 | # --data custom \
60 | # --features M \
61 | # --seq_len 96 \
62 | # --label_len 48 \
63 | # --pred_len 96 \
64 | # --e_layers 2 \
65 | # --d_layers 1 \
66 | # --factor 3 \
67 | # --enc_in 321 \
68 | # --dec_in 321 \
69 | # --c_out 321 \
70 | # --des 'Exp' \
71 | # --itr 1
72 |
73 | python -u run.py \
74 | --is_training 1 \
75 | --root_path ./dataset/electricity/ \
76 | --data_path electricity.csv \
77 | --model_id ECL_96_96 \
78 | --model $model_name \
79 | --data custom \
80 | --features M \
81 | --seq_len 96 \
82 | --label_len 48 \
83 | --pred_len 96 \
84 | --e_layers 2 \
85 | --d_layers 1 \
86 | --factor 3 \
87 | --enc_in 64 \
88 | --dec_in 64 \
89 | --c_out 64 \
90 | --des 'Exp' \
91 | --exp_name partial_train \
92 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/README.md:
--------------------------------------------------------------------------------
1 | # iTransformer for Variate Generalization
2 |
3 | This folder contains the implementation of the iTransformer for generalization on unseen variates. If you are new to this repo, we recommend reading this [README](../multivariate_forecasting/README.md) first.
4 |
5 | By inverting vanilla Transformers, the model is empowered with the capability to generalize on unseen variates. First, benefiting from the flexibility of the number of input tokens, the number of variate channels is no longer restricted and can therefore differ between training and inference. Second, the feed-forward network is applied identically to independent variate tokens, learning transferable representations of time series.
6 |
7 | ## Scripts
8 |
9 | ```
10 | # Train models with only 20% of variates from Traffic and test the model on all variates without finetuning
11 |
12 | bash ./scripts/variate_generalization/Traffic/iTransformer.sh
13 | ```
14 |
15 | > During Training
16 |
17 |
18 |
19 |
20 | > During Inference
21 |
22 |
23 |
24 |
25 | In each folder named after a dataset, we provide two strategies to enable Transformers to generalize on unseen variates.
26 |
27 | * **CI-Transformers**: Channel Independence treats each variate of a time series as an independent channel and uses a shared backbone to forecast all variates. The model can therefore predict variates one by one, but training and inference can be time-consuming.
28 |
29 | * **iTransformers**: benefiting from the flexibility of attention, where the number of input tokens can change dynamically, the number of variate tokens is no longer restricted, which even allows the model to be trained on an arbitrary subset of variates (an invocation sketch is given at the end of this README).
30 |
31 | ## Results
32 |
33 |
34 |
35 |
36 |
37 | iTransformer can naturally be trained on only 20% of the variates and still forecast all variates, owing to its ability to learn transferable representations.
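38 |
39 | As a concrete sketch of the partial-training setting (mirroring ```./scripts/variate_generalization/Traffic/iTransformer.sh```, where 172 = 862 // 5 variates are seen during training), the iTransformer run only needs ```--exp_name partial_train``` with a reduced ```enc_in```; the CI-Transformer counterpart (e.g. ```--model Transformer```) additionally sets ```--channel_independence true``` with a smaller ```--d_model``` and ```--d_ff```:
40 |
41 | ```
42 | # iTransformer trained on 20% of the Traffic variates, then tested on all 862 without finetuning
43 | python -u run.py \
44 |   --is_training 1 \
45 |   --root_path ./dataset/traffic/ \
46 |   --data_path traffic.csv \
47 |   --model_id traffic_96_96 \
48 |   --model iTransformer \
49 |   --data custom \
50 |   --features M \
51 |   --seq_len 96 \
52 |   --label_len 48 \
53 |   --pred_len 96 \
54 |   --e_layers 2 \
55 |   --d_layers 1 \
56 |   --factor 3 \
57 |   --enc_in 172 \
58 |   --dec_in 172 \
59 |   --c_out 172 \
60 |   --des 'Exp' \
61 |   --exp_name partial_train \
62 |   --itr 1
63 | ```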
--------------------------------------------------------------------------------
/scripts/variate_generalization/SolarEnergy/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=Flowformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/Solar/ \
8 | # --data_path solar_AL.txt \
9 | # --model_id solar_96_96 \
10 | # --model $model_name \
11 | # --data Solar \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 137 \
20 | # --dec_in 137 \
21 | # --c_out 137 \
22 | # --des 'Exp' \
23 | # --learning_rate 0.0005 \
24 | # --itr 1
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/Solar/ \
29 | --data_path solar_AL.txt \
30 | --model_id solar_96_96 \
31 | --model $model_name \
32 | --data Solar \
33 | --features M \
34 | --seq_len 96 \
35 | --label_len 48 \
36 | --pred_len 96 \
37 | --e_layers 2 \
38 | --d_layers 1 \
39 | --factor 3 \
40 | --enc_in 27 \
41 | --dec_in 27 \
42 | --c_out 27 \
43 | --des 'Exp' \
44 | --d_model 32 \
45 | --d_ff 64 \
46 | --learning_rate 0.0005 \
47 | --channel_independence true \
48 | --exp_name partial_train \
49 | --batch_size 8 \
50 | --itr 1
51 |
52 | model_name=iFlowformer
53 |
54 | #python -u run.py \
55 | ## --is_training 1 \
56 | # --root_path ./dataset/Solar/ \
57 | # --data_path solar_AL.txt \
58 | # --model_id solar_96_96 \
59 | # --model $model_name \
60 | # --data Solar \
61 | # --features M \
62 | # --seq_len 96 \
63 | # --label_len 48 \
64 | # --pred_len 96 \
65 | # --e_layers 2 \
66 | # --d_layers 1 \
67 | # --factor 3 \
68 | # --enc_in 137 \
69 | # --dec_in 137 \
70 | # --c_out 137 \
71 | # --des 'Exp' \
72 | # --learning_rate 0.0005 \
73 | # --itr 1
74 |
75 | python -u run.py \
76 | --is_training 1 \
77 | --root_path ./dataset/Solar/ \
78 | --data_path solar_AL.txt \
79 | --model_id solar_96_96 \
80 | --model $model_name \
81 | --data Solar \
82 | --features M \
83 | --seq_len 96 \
84 | --label_len 48 \
85 | --pred_len 96 \
86 | --e_layers 2 \
87 | --d_layers 1 \
88 | --factor 3 \
89 | --enc_in 27 \
90 | --dec_in 27 \
91 | --c_out 27 \
92 | --des 'Exp' \
93 | --learning_rate 0.0005 \
94 | --exp_name partial_train \
95 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/SolarEnergy/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=Informer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/Solar/ \
8 | # --data_path solar_AL.txt \
9 | # --model_id solar_96_96 \
10 | # --model $model_name \
11 | # --data Solar \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 137 \
20 | # --dec_in 137 \
21 | # --c_out 137 \
22 | # --des 'Exp' \
23 | # --learning_rate 0.0005 \
24 | # --itr 1
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/Solar/ \
29 | --data_path solar_AL.txt \
30 | --model_id solar_96_96 \
31 | --model $model_name \
32 | --data Solar \
33 | --features M \
34 | --seq_len 96 \
35 | --label_len 48 \
36 | --pred_len 96 \
37 | --e_layers 2 \
38 | --d_layers 1 \
39 | --factor 3 \
40 | --enc_in 27 \
41 | --dec_in 27 \
42 | --c_out 27 \
43 | --des 'Exp' \
44 | --d_model 32 \
45 | --d_ff 64 \
46 | --learning_rate 0.0005 \
47 | --channel_independence true \
48 | --exp_name partial_train \
49 | --batch_size 8 \
50 | --itr 1
51 |
52 | model_name=iInformer
53 |
54 | #python -u run.py \
55 | ## --is_training 1 \
56 | # --root_path ./dataset/Solar/ \
57 | # --data_path solar_AL.txt \
58 | # --model_id solar_96_96 \
59 | # --model $model_name \
60 | # --data Solar \
61 | # --features M \
62 | # --seq_len 96 \
63 | # --label_len 48 \
64 | # --pred_len 96 \
65 | # --e_layers 2 \
66 | # --d_layers 1 \
67 | # --factor 3 \
68 | # --enc_in 137 \
69 | # --dec_in 137 \
70 | # --c_out 137 \
71 | # --des 'Exp' \
72 | # --learning_rate 0.0005 \
73 | # --itr 1
74 |
75 | python -u run.py \
76 | --is_training 1 \
77 | --root_path ./dataset/Solar/ \
78 | --data_path solar_AL.txt \
79 | --model_id solar_96_96 \
80 | --model $model_name \
81 | --data Solar \
82 | --features M \
83 | --seq_len 96 \
84 | --label_len 48 \
85 | --pred_len 96 \
86 | --e_layers 2 \
87 | --d_layers 1 \
88 | --factor 3 \
89 | --enc_in 27 \
90 | --dec_in 27 \
91 | --c_out 27 \
92 | --des 'Exp' \
93 | --learning_rate 0.0005 \
94 | --exp_name partial_train \
95 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/SolarEnergy/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=Reformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/Solar/ \
8 | # --data_path solar_AL.txt \
9 | # --model_id solar_96_96 \
10 | # --model $model_name \
11 | # --data Solar \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 137 \
20 | # --dec_in 137 \
21 | # --c_out 137 \
22 | # --des 'Exp' \
23 | # --learning_rate 0.0005 \
24 | # --itr 1
25 |
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/Solar/ \
29 | --data_path solar_AL.txt \
30 | --model_id solar_96_96 \
31 | --model $model_name \
32 | --data Solar \
33 | --features M \
34 | --seq_len 96 \
35 | --label_len 48 \
36 | --pred_len 96 \
37 | --e_layers 2 \
38 | --d_layers 1 \
39 | --factor 3 \
40 | --enc_in 27 \
41 | --dec_in 27 \
42 | --c_out 27 \
43 | --des 'Exp' \
44 | --d_model 32 \
45 | --d_ff 64 \
46 | --learning_rate 0.0005 \
47 | --channel_independence true \
48 | --exp_name partial_train \
49 | --batch_size 8 \
50 | --itr 1
51 |
52 | model_name=iReformer
53 |
54 | #python -u run.py \
55 | ## --is_training 1 \
56 | # --root_path ./dataset/Solar/ \
57 | # --data_path solar_AL.txt \
58 | # --model_id solar_96_96 \
59 | # --model $model_name \
60 | # --data Solar \
61 | # --features M \
62 | # --seq_len 96 \
63 | # --label_len 48 \
64 | # --pred_len 96 \
65 | # --e_layers 2 \
66 | # --d_layers 1 \
67 | # --factor 3 \
68 | # --enc_in 137 \
69 | # --dec_in 137 \
70 | # --c_out 137 \
71 | # --des 'Exp' \
72 | # --learning_rate 0.0005 \
73 | # --itr 1
74 |
75 | python -u run.py \
76 | --is_training 1 \
77 | --root_path ./dataset/Solar/ \
78 | --data_path solar_AL.txt \
79 | --model_id solar_96_96 \
80 | --model $model_name \
81 | --data Solar \
82 | --features M \
83 | --seq_len 96 \
84 | --label_len 48 \
85 | --pred_len 96 \
86 | --e_layers 2 \
87 | --d_layers 1 \
88 | --factor 3 \
89 | --enc_in 27 \
90 | --dec_in 27 \
91 | --c_out 27 \
92 | --des 'Exp' \
93 | --learning_rate 0.0005 \
94 | --exp_name partial_train \
95 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/SolarEnergy/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 |
3 | model_name=Transformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/Solar/ \
8 | # --data_path solar_AL.txt \
9 | # --model_id solar_96_96 \
10 | # --model $model_name \
11 | # --data Solar \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 137 \
20 | # --dec_in 137 \
21 | # --c_out 137 \
22 | # --des 'Exp' \
23 | # --learning_rate 0.0005 \
24 | # --itr 1
25 |
26 | # 20% partial variates: 27 = 137 // 5
27 | python -u run.py \
28 | --is_training 1 \
29 | --root_path ./dataset/Solar/ \
30 | --data_path solar_AL.txt \
31 | --model_id solar_96_96 \
32 | --model $model_name \
33 | --data Solar \
34 | --features M \
35 | --seq_len 96 \
36 | --label_len 48 \
37 | --pred_len 96 \
38 | --e_layers 2 \
39 | --d_layers 1 \
40 | --factor 3 \
41 | --enc_in 27 \
42 | --dec_in 27 \
43 | --c_out 27 \
44 | --des 'Exp' \
45 | --d_model 32 \
46 | --d_ff 64 \
47 | --learning_rate 0.0005 \
48 | --channel_independence true \
49 | --exp_name partial_train \
50 | --batch_size 8 \
51 | --itr 1
52 |
53 | model_name=iTransformer
54 |
55 | #python -u run.py \
56 | ## --is_training 1 \
57 | # --root_path ./dataset/Solar/ \
58 | # --data_path solar_AL.txt \
59 | # --model_id solar_96_96 \
60 | # --model $model_name \
61 | # --data Solar \
62 | # --features M \
63 | # --seq_len 96 \
64 | # --label_len 48 \
65 | # --pred_len 96 \
66 | # --e_layers 2 \
67 | # --d_layers 1 \
68 | # --factor 3 \
69 | # --enc_in 137 \
70 | # --dec_in 137 \
71 | # --c_out 137 \
72 | # --des 'Exp' \
73 | # --learning_rate 0.0005 \
74 | # --itr 1
75 |
76 | python -u run.py \
77 | --is_training 1 \
78 | --root_path ./dataset/Solar/ \
79 | --data_path solar_AL.txt \
80 | --model_id solar_96_96 \
81 | --model $model_name \
82 | --data Solar \
83 | --features M \
84 | --seq_len 96 \
85 | --label_len 48 \
86 | --pred_len 96 \
87 | --e_layers 2 \
88 | --d_layers 1 \
89 | --factor 3 \
90 | --enc_in 27 \
91 | --dec_in 27 \
92 | --c_out 27 \
93 | --des 'Exp' \
94 | --learning_rate 0.0005 \
95 | --exp_name partial_train \
96 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/Traffic/iFlowformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0
2 |
3 | model_name=Flowformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/traffic/ \
8 | # --data_path traffic.csv \
9 | # --model_id traffic_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 862 \
20 | # --dec_in 862 \
21 | # --c_out 862 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/traffic/ \
28 | --data_path traffic.csv \
29 | --model_id traffic_96_96 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --label_len 48 \
35 | --pred_len 96 \
36 | --e_layers 2 \
37 | --d_layers 1 \
38 | --factor 3 \
39 | --enc_in 172 \
40 | --dec_in 172 \
41 | --c_out 172 \
42 | --des 'Exp' \
43 | --channel_independence true \
44 | --exp_name partial_train \
45 | --batch_size 4 \
46 | --d_model 32 \
47 | --d_ff 64 \
48 | --itr 1
49 |
50 | model_name=iFlowformer
51 |
52 | #python -u run.py \
53 | ## --is_training 1 \
54 | # --root_path ./dataset/traffic/ \
55 | # --data_path traffic.csv \
56 | # --model_id traffic_96_96 \
57 | # --model $model_name \
58 | # --data custom \
59 | # --features M \
60 | # --seq_len 96 \
61 | # --label_len 48 \
62 | # --pred_len 96 \
63 | # --e_layers 2 \
64 | # --d_layers 1 \
65 | # --factor 3 \
66 | # --enc_in 862 \
67 | # --dec_in 862 \
68 | # --c_out 862 \
69 | # --des 'Exp' \
70 | # --itr 1
71 |
72 | python -u run.py \
73 | --is_training 1 \
74 | --root_path ./dataset/traffic/ \
75 | --data_path traffic.csv \
76 | --model_id traffic_96_96 \
77 | --model $model_name \
78 | --data custom \
79 | --features M \
80 | --seq_len 96 \
81 | --label_len 48 \
82 | --pred_len 96 \
83 | --e_layers 2 \
84 | --d_layers 1 \
85 | --factor 3 \
86 | --enc_in 172 \
87 | --dec_in 172 \
88 | --c_out 172 \
89 | --des 'Exp' \
90 | --exp_name partial_train \
91 | --itr 1
92 |
--------------------------------------------------------------------------------
/scripts/variate_generalization/Traffic/iInformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=1
2 |
3 | model_name=Informer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/traffic/ \
8 | # --data_path traffic.csv \
9 | # --model_id traffic_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 862 \
20 | # --dec_in 862 \
21 | # --c_out 862 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/traffic/ \
28 | --data_path traffic.csv \
29 | --model_id traffic_96_96 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --label_len 48 \
35 | --pred_len 96 \
36 | --e_layers 2 \
37 | --d_layers 1 \
38 | --factor 3 \
39 | --enc_in 172 \
40 | --dec_in 172 \
41 | --c_out 172 \
42 | --des 'Exp' \
43 | --channel_independence true \
44 | --exp_name partial_train \
45 | --batch_size 4 \
46 | --d_model 32 \
47 | --d_ff 64 \
48 | --itr 1
49 |
50 | model_name=iInformer
51 |
52 | #python -u run.py \
53 | ## --is_training 1 \
54 | # --root_path ./dataset/traffic/ \
55 | # --data_path traffic.csv \
56 | # --model_id traffic_96_96 \
57 | # --model $model_name \
58 | # --data custom \
59 | # --features M \
60 | # --seq_len 96 \
61 | # --label_len 48 \
62 | # --pred_len 96 \
63 | # --e_layers 2 \
64 | # --d_layers 1 \
65 | # --factor 3 \
66 | # --enc_in 862 \
67 | # --dec_in 862 \
68 | # --c_out 862 \
69 | # --des 'Exp' \
70 | # --itr 1
71 |
72 | python -u run.py \
73 | --is_training 1 \
74 | --root_path ./dataset/traffic/ \
75 | --data_path traffic.csv \
76 | --model_id traffic_96_96 \
77 | --model $model_name \
78 | --data custom \
79 | --features M \
80 | --seq_len 96 \
81 | --label_len 48 \
82 | --pred_len 96 \
83 | --e_layers 2 \
84 | --d_layers 1 \
85 | --factor 3 \
86 | --enc_in 172 \
87 | --dec_in 172 \
88 | --c_out 172 \
89 | --des 'Exp' \
90 | --exp_name partial_train \
91 | --itr 1
92 |
--------------------------------------------------------------------------------
/scripts/variate_generalization/Traffic/iReformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=2
2 |
3 | model_name=Reformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/traffic/ \
8 | # --data_path traffic.csv \
9 | # --model_id traffic_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 862 \
20 | # --dec_in 862 \
21 | # --c_out 862 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | python -u run.py \
26 | --is_training 1 \
27 | --root_path ./dataset/traffic/ \
28 | --data_path traffic.csv \
29 | --model_id traffic_96_96 \
30 | --model $model_name \
31 | --data custom \
32 | --features M \
33 | --seq_len 96 \
34 | --label_len 48 \
35 | --pred_len 96 \
36 | --e_layers 2 \
37 | --d_layers 1 \
38 | --factor 3 \
39 | --enc_in 172 \
40 | --dec_in 172 \
41 | --c_out 172 \
42 | --des 'Exp' \
43 | --channel_independence true \
44 | --exp_name partial_train \
45 | --batch_size 4 \
46 | --d_model 32 \
47 | --d_ff 64 \
48 | --itr 1
49 |
50 | model_name=iReformer
51 |
52 | #python -u run.py \
53 | ## --is_training 1 \
54 | # --root_path ./dataset/traffic/ \
55 | # --data_path traffic.csv \
56 | # --model_id traffic_96_96 \
57 | # --model $model_name \
58 | # --data custom \
59 | # --features M \
60 | # --seq_len 96 \
61 | # --label_len 48 \
62 | # --pred_len 96 \
63 | # --e_layers 2 \
64 | # --d_layers 1 \
65 | # --factor 3 \
66 | # --enc_in 862 \
67 | # --dec_in 862 \
68 | # --c_out 862 \
69 | # --des 'Exp' \
70 | # --itr 1
71 |
72 | python -u run.py \
73 | --is_training 1 \
74 | --root_path ./dataset/traffic/ \
75 | --data_path traffic.csv \
76 | --model_id traffic_96_96 \
77 | --model $model_name \
78 | --data custom \
79 | --features M \
80 | --seq_len 96 \
81 | --label_len 48 \
82 | --pred_len 96 \
83 | --e_layers 2 \
84 | --d_layers 1 \
85 | --factor 3 \
86 | --enc_in 172 \
87 | --dec_in 172 \
88 | --c_out 172 \
89 | --des 'Exp' \
90 | --exp_name partial_train \
91 | --itr 1
--------------------------------------------------------------------------------
/scripts/variate_generalization/Traffic/iTransformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 |
3 | model_name=Transformer
4 |
5 | #python -u run.py \
6 | ## --is_training 1 \
7 | # --root_path ./dataset/traffic/ \
8 | # --data_path traffic.csv \
9 | # --model_id traffic_96_96 \
10 | # --model $model_name \
11 | # --data custom \
12 | # --features M \
13 | # --seq_len 96 \
14 | # --label_len 48 \
15 | # --pred_len 96 \
16 | # --e_layers 2 \
17 | # --d_layers 1 \
18 | # --factor 3 \
19 | # --enc_in 862 \
20 | # --dec_in 862 \
21 | # --c_out 862 \
22 | # --des 'Exp' \
23 | # --itr 1
24 |
25 | # 20% partial variates, enc_in: 172 = 862 // 5
26 | python -u run.py \
27 | --is_training 1 \
28 | --root_path ./dataset/traffic/ \
29 | --data_path traffic.csv \
30 | --model_id traffic_96_96 \
31 | --model $model_name \
32 | --data custom \
33 | --features M \
34 | --seq_len 96 \
35 | --label_len 48 \
36 | --pred_len 96 \
37 | --e_layers 2 \
38 | --d_layers 1 \
39 | --factor 3 \
40 | --enc_in 172 \
41 | --dec_in 172 \
42 | --c_out 172 \
43 | --des 'Exp' \
44 | --channel_independence true \
45 | --exp_name partial_train \
46 | --batch_size 8 \
47 | --d_model 32 \
48 | --d_ff 64 \
49 | --itr 1
50 |
51 | model_name=iTransformer
52 |
53 | #python -u run.py \
54 | ## --is_training 1 \
55 | # --root_path ./dataset/traffic/ \
56 | # --data_path traffic.csv \
57 | # --model_id traffic_96_96 \
58 | # --model $model_name \
59 | # --data custom \
60 | # --features M \
61 | # --seq_len 96 \
62 | # --label_len 48 \
63 | # --pred_len 96 \
64 | # --e_layers 2 \
65 | # --d_layers 1 \
66 | # --factor 3 \
67 | # --enc_in 862 \
68 | # --dec_in 862 \
69 | # --c_out 862 \
70 | # --des 'Exp' \
71 | # --itr 1
72 |
73 | python -u run.py \
74 | --is_training 1 \
75 | --root_path ./dataset/traffic/ \
76 | --data_path traffic.csv \
77 | --model_id traffic_96_96 \
78 | --model $model_name \
79 | --data custom \
80 | --features M \
81 | --seq_len 96 \
82 | --label_len 48 \
83 | --pred_len 96 \
84 | --e_layers 2 \
85 | --d_layers 1 \
86 | --factor 3 \
87 | --enc_in 172 \
88 | --dec_in 172 \
89 | --c_out 172 \
90 | --des 'Exp' \
91 | --exp_name partial_train \
92 | --itr 1
93 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thuml/iTransformer/f8c144c075e4b9aa8a2c82dc5d51031dee0fef1f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TriangularCausalMask():
5 | def __init__(self, B, L, device="cpu"):
6 | mask_shape = [B, 1, L, L]
7 | with torch.no_grad():
8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
9 |
10 | @property
11 | def mask(self):
12 | return self._mask
13 |
14 |
15 | class ProbMask():
16 | def __init__(self, B, H, L, index, scores, device="cpu"):
17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
19 | indicator = _mask_ex[torch.arange(B)[:, None, None],
20 | torch.arange(H)[None, :, None],
21 | index, :].to(device)
22 | self._mask = indicator.view(scores.shape).to(device)
23 |
24 | @property
25 | def mask(self):
26 | return self._mask
27 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def RSE(pred, true):
5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
6 |
7 |
8 | def CORR(pred, true):
9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
11 | return (u / d).mean(-1)
12 |
13 |
14 | def MAE(pred, true):
15 | return np.mean(np.abs(pred - true))
16 |
17 |
18 | def MSE(pred, true):
19 | return np.mean((pred - true) ** 2)
20 |
21 |
22 | def RMSE(pred, true):
23 | return np.sqrt(MSE(pred, true))
24 |
25 |
26 | def MAPE(pred, true):
27 | return np.mean(np.abs((pred - true) / true))
28 |
29 |
30 | def MSPE(pred, true):
31 | return np.mean(np.square((pred - true) / true))
32 |
33 |
34 | def metric(pred, true):
35 | mae = MAE(pred, true)
36 | mse = MSE(pred, true)
37 | rmse = RMSE(pred, true)
38 | mape = MAPE(pred, true)
39 | mspe = MSPE(pred, true)
40 |
41 | return mae, mse, rmse, mape, mspe
42 |
--------------------------------------------------------------------------------
/utils/timefeatures.py:
--------------------------------------------------------------------------------
1 | # From: gluonts/src/gluonts/time_feature/_base.py
2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License").
5 | # You may not use this file except in compliance with the License.
6 | # A copy of the License is located at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # or in the "license" file accompanying this file. This file is distributed
11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12 | # express or implied. See the License for the specific language governing
13 | # permissions and limitations under the License.
14 |
15 | from typing import List
16 |
17 | import numpy as np
18 | import pandas as pd
19 | from pandas.tseries import offsets
20 | from pandas.tseries.frequencies import to_offset
21 |
22 |
23 | class TimeFeature:
24 | def __init__(self):
25 | pass
26 |
27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
28 | pass
29 |
30 | def __repr__(self):
31 | return self.__class__.__name__ + "()"
32 |
33 |
34 | class SecondOfMinute(TimeFeature):
35 |     """Second of minute encoded as value between [-0.5, 0.5]"""
36 |
37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
38 | return index.second / 59.0 - 0.5
39 |
40 |
41 | class MinuteOfHour(TimeFeature):
42 | """Minute of hour encoded as value between [-0.5, 0.5]"""
43 |
44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
45 | return index.minute / 59.0 - 0.5
46 |
47 |
48 | class HourOfDay(TimeFeature):
49 | """Hour of day encoded as value between [-0.5, 0.5]"""
50 |
51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
52 | return index.hour / 23.0 - 0.5
53 |
54 |
55 | class DayOfWeek(TimeFeature):
56 |     """Day of week encoded as value between [-0.5, 0.5]"""
57 |
58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
59 | return index.dayofweek / 6.0 - 0.5
60 |
61 |
62 | class DayOfMonth(TimeFeature):
63 | """Day of month encoded as value between [-0.5, 0.5]"""
64 |
65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
66 | return (index.day - 1) / 30.0 - 0.5
67 |
68 |
69 | class DayOfYear(TimeFeature):
70 | """Day of year encoded as value between [-0.5, 0.5]"""
71 |
72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
73 | return (index.dayofyear - 1) / 365.0 - 0.5
74 |
75 |
76 | class MonthOfYear(TimeFeature):
77 | """Month of year encoded as value between [-0.5, 0.5]"""
78 |
79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
80 | return (index.month - 1) / 11.0 - 0.5
81 |
82 |
83 | class WeekOfYear(TimeFeature):
84 | """Week of year encoded as value between [-0.5, 0.5]"""
85 |
86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
87 | return (index.isocalendar().week - 1) / 52.0 - 0.5
88 |
89 |
90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
91 | """
92 | Returns a list of time features that will be appropriate for the given frequency string.
93 | Parameters
94 | ----------
95 | freq_str
96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
97 | """
98 |
99 | features_by_offsets = {
100 | offsets.YearEnd: [],
101 | offsets.QuarterEnd: [MonthOfYear],
102 | offsets.MonthEnd: [MonthOfYear],
103 | offsets.Week: [DayOfMonth, WeekOfYear],
104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
107 | offsets.Minute: [
108 | MinuteOfHour,
109 | HourOfDay,
110 | DayOfWeek,
111 | DayOfMonth,
112 | DayOfYear,
113 | ],
114 | offsets.Second: [
115 | SecondOfMinute,
116 | MinuteOfHour,
117 | HourOfDay,
118 | DayOfWeek,
119 | DayOfMonth,
120 | DayOfYear,
121 | ],
122 | }
123 |
124 | offset = to_offset(freq_str)
125 |
126 | for offset_type, feature_classes in features_by_offsets.items():
127 | if isinstance(offset, offset_type):
128 | return [cls() for cls in feature_classes]
129 |
130 | supported_freq_msg = f"""
131 | Unsupported frequency {freq_str}
132 | The following frequencies are supported:
133 | Y - yearly
134 | alias: A
135 | M - monthly
136 | W - weekly
137 | D - daily
138 | B - business days
139 | H - hourly
140 | T - minutely
141 | alias: min
142 | S - secondly
143 | """
144 | raise RuntimeError(supported_freq_msg)
145 |
146 |
147 | def time_features(dates, freq='h'):
148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
149 |
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 |
8 | plt.switch_backend('agg')
9 |
10 |
11 | def adjust_learning_rate(optimizer, epoch, args):
12 |     # lr = args.learning_rate * (0.2 ** (epoch // 2))
13 |     lr_adjust = {}  # default: unrecognized lradj types leave the learning rate unchanged
14 |     if args.lradj == 'type1':
15 |         lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
16 |     elif args.lradj == 'type2':
17 |         lr_adjust = {
18 |             2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
19 |             10: 5e-7, 15: 1e-7, 20: 5e-8}
20 | if epoch in lr_adjust.keys():
21 | lr = lr_adjust[epoch]
22 | for param_group in optimizer.param_groups:
23 | param_group['lr'] = lr
24 | print('Updating learning rate to {}'.format(lr))
25 |
26 |
27 | class EarlyStopping:
28 | def __init__(self, patience=7, verbose=False, delta=0):
29 | self.patience = patience
30 | self.verbose = verbose
31 | self.counter = 0
32 | self.best_score = None
33 | self.early_stop = False
34 |         self.val_loss_min = np.inf
35 | self.delta = delta
36 |
37 | def __call__(self, val_loss, model, path):
38 | score = -val_loss
39 | if self.best_score is None:
40 | self.best_score = score
41 | self.save_checkpoint(val_loss, model, path)
42 | elif score < self.best_score + self.delta:
43 | self.counter += 1
44 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
45 | if self.counter >= self.patience:
46 | self.early_stop = True
47 | else:
48 | self.best_score = score
49 | self.save_checkpoint(val_loss, model, path)
50 | self.counter = 0
51 |
52 | def save_checkpoint(self, val_loss, model, path):
53 | if self.verbose:
54 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
55 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
56 | self.val_loss_min = val_loss
57 |
58 |
59 | class dotdict(dict):
60 | """dot.notation access to dictionary attributes"""
61 | __getattr__ = dict.get
62 | __setattr__ = dict.__setitem__
63 | __delattr__ = dict.__delitem__
64 |
65 |
66 | class StandardScaler():
67 | def __init__(self, mean, std):
68 | self.mean = mean
69 | self.std = std
70 |
71 | def transform(self, data):
72 | return (data - self.mean) / self.std
73 |
74 | def inverse_transform(self, data):
75 | return (data * self.std) + self.mean
76 |
77 |
78 | def visual(true, preds=None, name='./pic/test.pdf'):
79 | """
80 | Results visualization
81 | """
82 | plt.figure()
83 | plt.plot(true, label='GroundTruth', linewidth=2)
84 | if preds is not None:
85 | plt.plot(preds, label='Prediction', linewidth=2)
86 | plt.legend()
87 | plt.savefig(name, bbox_inches='tight')
88 |
89 |
90 | def adjustment(gt, pred):
91 | anomaly_state = False
92 | for i in range(len(gt)):
93 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
94 | anomaly_state = True
95 | for j in range(i, 0, -1):
96 | if gt[j] == 0:
97 | break
98 | else:
99 | if pred[j] == 0:
100 | pred[j] = 1
101 | for j in range(i, len(gt)):
102 | if gt[j] == 0:
103 | break
104 | else:
105 | if pred[j] == 0:
106 | pred[j] = 1
107 | elif gt[i] == 0:
108 | anomaly_state = False
109 | if anomaly_state:
110 | pred[i] = 1
111 | return gt, pred
112 |
113 |
114 | def cal_accuracy(y_pred, y_true):
115 | return np.mean(y_pred == y_true)
116 |
--------------------------------------------------------------------------------