├── .gitattributes
├── .gitignore
├── Figures
│   ├── Figure2.png
│   ├── Figure3.png
│   ├── Figure4.png
│   ├── Figure6.png
│   ├── Table2.png
│   ├── Table4.png
│   ├── Table5.png
│   └── Table6.png
├── LICENSE
├── README.md
├── acf_plot.ipynb
├── checkpoints
│   ├── ETTm1_96_96_CycleNet_ETTm1_ftM_sl96_pl96_cycle96_linear_seed2024
│   │   └── checkpoint.pth
│   ├── Electricity_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024
│   │   └── checkpoint.pth
│   ├── Solar_96_96_CycleNet_Solar_ftM_sl96_pl96_cycle144_linear_seed2024
│   │   └── checkpoint.pth
│   ├── traffic_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024
│   │   └── checkpoint.pth
│   └── weather_96_96_CycleNet_custom_ftM_sl96_pl96_cycle144_linear_seed2024
│       └── checkpoint.pth
├── data_provider
│   ├── data_factory.py
│   └── data_loader.py
├── exp
│   ├── exp_basic.py
│   └── exp_main.py
├── layers
│   ├── AutoCorrelation.py
│   ├── Autoformer_EncDec.py
│   ├── Embed.py
│   ├── PatchTST_backbone.py
│   ├── PatchTST_layers.py
│   ├── RevIN.py
│   ├── SelfAttention_Family.py
│   └── Transformer_EncDec.py
├── models
│   ├── Autoformer.py
│   ├── CycleNet.py
│   ├── CycleiTransformer.py
│   ├── DLinear.py
│   ├── Informer.py
│   ├── LDLinear.py
│   ├── Linear.py
│   ├── NLinear.py
│   ├── PatchTST.py
│   ├── RLinear.py
│   ├── RMLP.py
│   ├── SegRNN.py
│   ├── SparseTSF.py
│   └── Transformer.py
├── requirements.txt
├── result.txt
├── run.py
├── run_main.sh
├── run_pems.sh
├── run_std.sh
├── scripts
│   ├── CycleNet
│   │   ├── Linear-Input-336
│   │   │   ├── electricity.sh
│   │   │   ├── etth1.sh
│   │   │   ├── etth2.sh
│   │   │   ├── ettm1.sh
│   │   │   ├── ettm2.sh
│   │   │   ├── solar.sh
│   │   │   ├── traffic.sh
│   │   │   └── weather.sh
│   │   ├── Linear-Input-720
│   │   │   ├── electricity.sh
│   │   │   ├── etth1.sh
│   │   │   ├── etth2.sh
│   │   │   ├── ettm1.sh
│   │   │   ├── ettm2.sh
│   │   │   ├── solar.sh
│   │   │   ├── traffic.sh
│   │   │   └── weather.sh
│   │   ├── Linear-Input-96
│   │   │   ├── electricity.sh
│   │   │   ├── etth1.sh
│   │   │   ├── etth2.sh
│   │   │   ├── ettm1.sh
│   │   │   ├── ettm2.sh
│   │   │   ├── solar.sh
│   │   │   ├── traffic.sh
│   │   │   └── weather.sh
│   │   ├── MLP-Input-336
│   │   │   ├── electricity.sh
│   │   │   ├── etth1.sh
│   │   │   ├── etth2.sh
│   │   │   ├── ettm1.sh
│   │   │   ├── ettm2.sh
│   │   │   ├── solar.sh
│   │   │   ├── traffic.sh
│   │   │   └── weather.sh
│   │   ├── MLP-Input-720
│   │   │   ├── electricity.sh
│   │   │   ├── etth1.sh
│   │   │   ├── etth2.sh
│   │   │   ├── ettm1.sh
│   │   │   ├── ettm2.sh
│   │   │   ├── solar.sh
│   │   │   ├── traffic.sh
│   │   │   └── weather.sh
│   │   ├── MLP-Input-96
│   │   │   ├── electricity.sh
│   │   │   ├── etth1.sh
│   │   │   ├── etth2.sh
│   │   │   ├── ettm1.sh
│   │   │   ├── ettm2.sh
│   │   │   ├── solar.sh
│   │   │   ├── traffic.sh
│   │   │   └── weather.sh
│   │   ├── PEMS
│   │   │   ├── pems03.sh
│   │   │   ├── pems04.sh
│   │   │   ├── pems07.sh
│   │   │   └── pems08.sh
│   │   └── STD
│   │       ├── CycleNet.sh
│   │       ├── DLinear.sh
│   │       ├── LDLinear.sh
│   │       ├── Linear.sh
│   │       └── SparseTSF.sh
│   └── iTransformer
│       ├── electricity.sh
│       └── traffic.sh
├── utils
│   ├── masking.py
│   ├── metrics.py
│   ├── timefeatures.py
│   └── tools.py
└── visualization.ipynb
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | *.pyc
3 | #*.pth
4 | dataset/
5 | test_results/
6 | results/
7 | logs/
8 | .DS_Store
--------------------------------------------------------------------------------
/Figures/Figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure2.png
--------------------------------------------------------------------------------
/Figures/Figure3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure3.png
--------------------------------------------------------------------------------
/Figures/Figure4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure4.png
--------------------------------------------------------------------------------
/Figures/Figure6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure6.png
--------------------------------------------------------------------------------
/Figures/Table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table2.png
--------------------------------------------------------------------------------
/Figures/Table4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table4.png
--------------------------------------------------------------------------------
/Figures/Table5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table5.png
--------------------------------------------------------------------------------
/Figures/Table6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table6.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CycleNet
2 |
3 | Welcome to the official repository of the CycleNet paper: "[CycleNet: Enhancing Time Series Forecasting through Modeling Periodic Patterns](https://arxiv.org/pdf/2409.18479)".
4 |
5 | [[Poster]](https://drive.google.com/file/d/1dBnmrjtTab4M5L9qfrdAp2Y_xr53hBZ1/view?usp=drive_link) -
6 | [[Slides]](https://drive.google.com/file/d/1QcCCxRFtnFYPtaXmZiF4Zby5r-nl5tR5/view?usp=drive_link) -
7 | [[Explainer (in Chinese)]](https://zhuanlan.zhihu.com/p/984766136)
8 |
9 | ## Updates
10 | 🚩 **News** (2025.05): Our latest work, [**TQNet**](https://github.com/ACAT-SCUT/TQNet), has been accepted to **ICML 2025**. TQNet is a powerful successor to CycleNet, addressing its limitation in *modeling inter-variable correlations* effectively.
11 |
12 | 🚩 **News** (2024.10): Thanks to the contribution of [wayhoww](https://github.com/wayhoww), CycleNet has been updated with a [new implementation](https://github.com/ACAT-SCUT/CycleNet/blob/f1deb5e1329970bf0c97b8fa593bb02c6d894587/models/CycleNet.py#L17) for generating cyclic components, achieving a 2x to 3x acceleration.
13 |
14 | 🚩 **News** (2024.09): CycleNet has been accepted as **NeurIPS 2024 _Spotlight_** (_average rating 7.25_, Top 1%).
15 |
16 | **Please note that** [**SparseTSF**](https://github.com/lss-1138/SparseTSF), [**CycleNet**](https://github.com/ACAT-SCUT/CycleNet), and [**TQNet**](https://github.com/ACAT-SCUT/TQNet) represent our continued exploration of **leveraging periodicity** for long-term time series forecasting (LTSF).
17 | The differences and connections among them are as follows:
18 |
19 | | Model | Use of Periodicity | Technique | Effect | Efficiency | Strengths | Limitation |
20 | | :----------------------------------------------------------: | :-------------------------------: | :------------------------------: | :----------------------------------------: | :------------------: | :-------------------------------------------: | :---------------------------------------------------: |
21 | | [**SparseTSF**](https://github.com/lss-1138/SparseTSF)<br>**(ICML 2024 Oral)** | Indirectly via downsampling | Cross-Period Sparse Forecasting | Ultra-light design | < 1k parameters | Extremely lightweight, near SOTA | Fails to cover multi-periods **(solved by CycleNet)** |
22 | | [**CycleNet**](https://github.com/ACAT-SCUT/CycleNet)<br>**(NeurIPS 2024 Spotlight)** | Explicit via learnable parameters | Residual Cycle Forecasting (RCF) | Better use of periodicity | 100k ~ 1M parameters | Strong performance on periodic data | Fails in multivariate modeling **(solved by TQNet)** |
23 | | [**TQNet**](https://github.com/ACAT-SCUT/TQNet)<br>**(ICML 2025)** | Serves as global correlations | Temporal Query in attention mechanism | Robust inter-variable correlation modeling | ~1M parameters | Enhanced multivariate forecasting performance | Hard to scale to ultra-long look-back inputs |
24 |
25 | ## Introduction
26 |
27 | CycleNet pioneers the **explicit modeling of periodicity** to enhance model performance in long-term time series forecasting (LTSF) tasks. Specifically, we introduce the Residual Cycle Forecasting (RCF) technique, which uses **learnable recurrent cycles** to capture inherent periodic patterns in sequences and then makes predictions on the *residual components* of the modeled cycles.
28 |
29 | 
30 |
31 | The RCF technique comprises two steps: the first step involves modeling the periodic patterns of sequences through globally **learnable recurrent cycles** within independent channels, and the second step entails predicting the *residual components* of the modeled cycles.
32 |
33 | 
34 |
35 | The learnable recurrent cycles _**Q**_ are initialized to ***zeros*** and then trained jointly with the backbone module via gradient backpropagation, yielding *learned representations* (different from the initial zeros) that uncover the cyclic patterns embedded within the sequence. We provide a notebook [[visualization.ipynb](https://github.com/ACAT-SCUT/CycleNet/blob/main/visualization.ipynb)] to visualize the learned periodic patterns; a minimal sketch of the full RCF computation also follows the figure below.
36 |
37 | 
38 |
39 | RCF can be regarded as a novel approach to achieve Seasonal-Trend Decomposition (STD). Compared to other existing methods, such as moving averages (used in DLinear, FEDformer, and Autoformer), it offers significant advantages.
40 |
41 | 
42 |
43 | As a result, CycleNet achieves current state-of-the-art performance using only a simple Linear layer or a dual-layer MLP as its backbone, while also offering substantial computational efficiency.
44 |
45 | 
46 |
47 | In addition to simple models like Linear and MLP, RCF can also improve the performance of more advanced algorithms.
48 |
49 | 
50 |
51 | Finally, as an explicit periodic modeling technique, RCF requires that the period length of the data is identified beforehand.
52 | For RCF to be effective, the length ***W*** of the learnable recurrent cycles _**Q**_ must accurately match the intrinsic period length of the data.
53 |
54 | 
55 |
56 | To identify the data's inherent period length, a straightforward method is manual inference. For instance, in the Electricity dataset, we know there is a weekly periodic pattern and that the sampling granularity is hourly, so the period length can be deduced to be 24 × 7 = 168.
57 |
58 | Moreover, a more scientific approach is to use the Autocorrelation Function (ACF), for which we provide an example ([ACF_ETTh1.ipynb](https://github.com/lss-1138/SparseTSF/blob/main/ACF_ETTh1.ipynb)) in the [SparseTSF](https://github.com/lss-1138/SparseTSF) repository. The hyperparameter ***W*** should be set to the lag corresponding to the observed maximum peak.
59 |
60 | 
61 |
62 | ## Model Implementation
63 |
64 | The key technique of CycleNet (or RCF) is to use learnable recurrent cycles to explicitly model the periodic patterns within the data, and then to model the residual components of the modeled cycles using either a single-layer Linear or a dual-layer MLP. The core implementation code of CycleNet (or RCF) is available at:
65 |
66 | ```
67 | models/CycleNet.py
68 | ```
69 |
70 | To identify the relative position of each sample within the recurrent cycles, we additionally need to generate a cycle index (i.e., the **_t_ mod _W_** mentioned in the paper) for each data sample. The code for this part is available at:
71 |
72 | ```
73 | data_provider/data_loader.py
74 | ```
75 | The specific implementation of the cycle index includes:
76 | ```python
77 | def __read_data__(self):
78 | ...
79 | self.cycle_index = (np.arange(len(data)) % self.cycle)[border1:border2]
80 |
81 | def __getitem__(self, index):
82 | ...
83 | cycle_index = torch.tensor(self.cycle_index[s_end])
84 | return ..., cycle_index
85 | ```
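
For intuition, with hourly data and `cycle = 168`, the expression above simply tags each row with its phase within the week:

```python
import numpy as np

cycle = 168                           # weekly cycle for hourly data
cycle_index = np.arange(400) % cycle  # same expression as in data_loader.py
print(cycle_index[:4])                # [0 1 2 3]
print(cycle_index[166:171])           # [166 167   0   1   2]
```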
86 |
87 | This simple implementation requires ensuring that the sequences have **no missing values**. In practical usage, a more elegant approach can be employed to generate the cycle index, such as mapping real-time timestamps to indices, for example:
88 | ```python
89 | def getCycleIndex(timestamp, frequency, cycle_len):
90 |     return (timestamp // frequency) % cycle_len  # integer division keeps the index integral
91 | ```
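
For example, assuming Unix-second timestamps, hourly sampling (`frequency = 3600`), and a weekly cycle (`cycle_len = 168`), the mapping stays consistent even across gaps in the data (the timestamp below is hypothetical):

```python
t0 = 1_700_000_000                          # an arbitrary Unix timestamp
print(getCycleIndex(t0, 3600, 168))         # 142: phase of this hour within the week
print(getCycleIndex(t0 + 3600, 3600, 168))  # 143: one hour later, phase advances by 1
```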
92 |
93 |
94 |
95 | ## Getting Started
96 |
97 | ### Environment Requirements
98 |
99 | To get started, ensure you have Conda installed on your system and follow these steps to set up the environment:
100 |
101 | ```
102 | conda create -n CycleNet python=3.8
103 | conda activate CycleNet
104 | pip install -r requirements.txt
105 | ```
106 |
107 | ### Data Preparation
108 |
109 | All the datasets needed for CycleNet can be obtained from the [[Google Drive]](https://drive.google.com/file/d/1bNbw1y8VYp-8pkRTqbjoW-TA-G8T0EQf/view) that was introduced in previous works such as Autoformer and SCINet.
110 | Create a separate folder named ```./dataset``` and place all the CSV files in this directory.
111 | **Note**: Place the CSV files directly into this directory, e.g., "./dataset/ETTh1.csv".
112 |
113 | ### Training Example
114 |
115 | You can easily reproduce the results from the paper by running the provided script command. For instance, to reproduce the main results, execute the following command:
116 |
117 | ```
118 | sh run_main.sh
119 | ```
120 |
121 | **For your convenience**, we have provided the execution results of "sh run_main.sh" in [**[result.txt](https://github.com/ACAT-SCUT/CycleNet/blob/main/result.txt)**], which contains the results of running CycleNet/Linear and CycleNet/MLP five times each with input lengths {96, 336, 720} and random seeds {2024, 2025, 2026, 2027, 2028}.
122 |
123 | You can also run the following command to reproduce the results of various STD techniques as well as the performance on the PEMS datasets (the PEMS datasets can be obtained from [SCINet](https://github.com/cure-lab/SCINet)):
124 |
125 | ```
126 | sh run_std.sh
127 | sh run_pems.sh
128 | ```
129 |
130 | Furthermore, you can specify separate scripts to run independent tasks, such as obtaining the results on ETTh1:
131 |
132 | ```
133 | sh scripts/CycleNet/Linear-Input-96/etth1.sh
134 | ```
135 |
136 |
137 |
138 |
139 |
140 | ## Citation
141 | If you find this repo useful, please cite our paper.
142 | ```
143 | @inproceedings{cyclenet,
144 | title={CycleNet: Enhancing Time Series Forecasting through Modeling Periodic Patterns},
145 | author={Lin, Shengsheng and Lin, Weiwei and Hu, Xinyi and Wu, Wentai and Mo, Ruichao and Zhong, Haocheng},
146 | booktitle={Thirty-eighth Conference on Neural Information Processing Systems},
147 | year={2024}
148 | }
149 | ```
150 |
151 |
152 |
153 |
154 |
155 | ## Contact
156 | If you have any questions or suggestions, feel free to contact:
157 | - Shengsheng Lin (cslinshengsheng@mail.scut.edu.cn)
158 | - Weiwei Lin (linww@scut.edu.cn)
159 | - Xinyi Hu (xyhu@cse.cuhk.edu.hk)
160 |
161 | ## Acknowledgement
162 |
163 | We extend our heartfelt appreciation to the following GitHub repositories for providing valuable code bases and datasets:
164 |
165 | https://github.com/lss-1138/SparseTSF
166 |
167 | https://github.com/thuml/iTransformer
168 |
169 | https://github.com/lss-1138/SegRNN
170 |
171 | https://github.com/yuqinie98/patchtst
172 |
173 | https://github.com/cure-lab/LTSF-Linear
174 |
175 | https://github.com/zhouhaoyi/Informer2020
176 |
177 | https://github.com/thuml/Autoformer
178 |
179 | https://github.com/MAZiqing/FEDformer
180 |
181 | https://github.com/alipay/Pyraformer
182 |
183 | https://github.com/ts-kim/RevIN
184 |
185 | https://github.com/timeseriesAI/tsai
186 |
187 |
--------------------------------------------------------------------------------
/checkpoints/ETTm1_96_96_CycleNet_ETTm1_ftM_sl96_pl96_cycle96_linear_seed2024/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/ETTm1_96_96_CycleNet_ETTm1_ftM_sl96_pl96_cycle96_linear_seed2024/checkpoint.pth
--------------------------------------------------------------------------------
/checkpoints/Electricity_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/Electricity_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth
--------------------------------------------------------------------------------
/checkpoints/Solar_96_96_CycleNet_Solar_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/Solar_96_96_CycleNet_Solar_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth
--------------------------------------------------------------------------------
/checkpoints/traffic_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/traffic_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth
--------------------------------------------------------------------------------
/checkpoints/weather_96_96_CycleNet_custom_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/weather_96_96_CycleNet_custom_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth
--------------------------------------------------------------------------------
/data_provider/data_factory.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred, Dataset_Solar, Dataset_PEMS
2 | from torch.utils.data import DataLoader
3 |
4 | data_dict = {
5 | 'ETTh1': Dataset_ETT_hour,
6 | 'ETTh2': Dataset_ETT_hour,
7 | 'ETTm1': Dataset_ETT_minute,
8 | 'ETTm2': Dataset_ETT_minute,
9 | 'custom': Dataset_Custom,
10 | 'Solar': Dataset_Solar,
11 | 'PEMS': Dataset_PEMS
12 | }
13 |
14 |
15 | def data_provider(args, flag):
16 | Data = data_dict[args.data]
17 | timeenc = 0 if args.embed != 'timeF' else 1
18 |
19 | if flag == 'test':
20 | shuffle_flag = False
21 | drop_last = False
22 | batch_size = args.batch_size
23 | freq = args.freq
24 | elif flag == 'pred':
25 | shuffle_flag = False
26 | drop_last = False
27 | batch_size = 1
28 | freq = args.freq
29 | Data = Dataset_Pred
30 | else:
31 | shuffle_flag = True
32 | drop_last = True
33 | batch_size = args.batch_size
34 | freq = args.freq
35 |
36 | data_set = Data(
37 | root_path=args.root_path,
38 | data_path=args.data_path,
39 | flag=flag,
40 | size=[args.seq_len, args.label_len, args.pred_len],
41 | features=args.features,
42 | target=args.target,
43 | timeenc=timeenc,
44 | freq=freq,
45 | cycle=args.cycle
46 | )
47 | print(flag, len(data_set))
48 | data_loader = DataLoader(
49 | data_set,
50 | batch_size=batch_size,
51 | shuffle=shuffle_flag,
52 | num_workers=args.num_workers,
53 | drop_last=drop_last)
54 | return data_set, data_loader
55 |
--------------------------------------------------------------------------------
/exp/exp_basic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 |
6 | class Exp_Basic(object):
7 | def __init__(self, args):
8 | self.args = args
9 | self.device = self._acquire_device()
10 | self.model = self._build_model().to(self.device)
11 |
12 | def _build_model(self):
13 |         raise NotImplementedError  # subclasses must override
14 |
15 |
16 | def _acquire_device(self):
17 | if self.args.use_gpu:
18 | os.environ["CUDA_VISIBLE_DEVICES"] = str(
19 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
20 | device = torch.device('cuda:{}'.format(self.args.gpu))
21 | print('Use GPU: cuda:{}'.format(self.args.gpu))
22 | else:
23 | device = torch.device('cpu')
24 | print('Use CPU')
25 | return device
26 |
27 | def _get_data(self):
28 | pass
29 |
30 | def vali(self):
31 | pass
32 |
33 | def train(self):
34 | pass
35 |
36 | def test(self):
37 | pass
38 |
--------------------------------------------------------------------------------
/layers/AutoCorrelation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import math
7 | from math import sqrt
8 | import os
9 |
10 |
11 | class AutoCorrelation(nn.Module):
12 | """
13 | AutoCorrelation Mechanism with the following two phases:
14 | (1) period-based dependencies discovery
15 | (2) time delay aggregation
16 | This block can replace the self-attention family mechanism seamlessly.
17 | """
18 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
19 | super(AutoCorrelation, self).__init__()
20 | self.factor = factor
21 | self.scale = scale
22 | self.mask_flag = mask_flag
23 | self.output_attention = output_attention
24 | self.dropout = nn.Dropout(attention_dropout)
25 |
26 | def time_delay_agg_training(self, values, corr):
27 | """
28 | SpeedUp version of Autocorrelation (a batch-normalization style design)
29 | This is for the training phase.
30 | """
31 | head = values.shape[1]
32 | channel = values.shape[2]
33 | length = values.shape[3]
34 | # find top k
35 | top_k = int(self.factor * math.log(length))
36 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
37 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
38 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
39 | # update corr
40 | tmp_corr = torch.softmax(weights, dim=-1)
41 | # aggregation
42 | tmp_values = values
43 | delays_agg = torch.zeros_like(values).float()
44 | for i in range(top_k):
45 | pattern = torch.roll(tmp_values, -int(index[i]), -1)
46 | delays_agg = delays_agg + pattern * \
47 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
48 | return delays_agg
49 |
50 | def time_delay_agg_inference(self, values, corr):
51 | """
52 | SpeedUp version of Autocorrelation (a batch-normalization style design)
53 | This is for the inference phase.
54 | """
55 | batch = values.shape[0]
56 | head = values.shape[1]
57 | channel = values.shape[2]
58 | length = values.shape[3]
59 | # index init
60 |         init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)  # device-agnostic (was hard-coded .cuda())
61 | # find top k
62 | top_k = int(self.factor * math.log(length))
63 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
64 | weights = torch.topk(mean_value, top_k, dim=-1)[0]
65 | delay = torch.topk(mean_value, top_k, dim=-1)[1]
66 | # update corr
67 | tmp_corr = torch.softmax(weights, dim=-1)
68 | # aggregation
69 | tmp_values = values.repeat(1, 1, 1, 2)
70 | delays_agg = torch.zeros_like(values).float()
71 | for i in range(top_k):
72 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
73 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
74 | delays_agg = delays_agg + pattern * \
75 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
76 | return delays_agg
77 |
78 | def time_delay_agg_full(self, values, corr):
79 | """
80 | Standard version of Autocorrelation
81 | """
82 | batch = values.shape[0]
83 | head = values.shape[1]
84 | channel = values.shape[2]
85 | length = values.shape[3]
86 | # index init
87 |         init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)  # device-agnostic (was hard-coded .cuda())
88 | # find top k
89 | top_k = int(self.factor * math.log(length))
90 | weights = torch.topk(corr, top_k, dim=-1)[0]
91 | delay = torch.topk(corr, top_k, dim=-1)[1]
92 | # update corr
93 | tmp_corr = torch.softmax(weights, dim=-1)
94 | # aggregation
95 | tmp_values = values.repeat(1, 1, 1, 2)
96 | delays_agg = torch.zeros_like(values).float()
97 | for i in range(top_k):
98 | tmp_delay = init_index + delay[..., i].unsqueeze(-1)
99 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
100 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
101 | return delays_agg
102 |
103 | def forward(self, queries, keys, values, attn_mask):
104 | B, L, H, E = queries.shape
105 | _, S, _, D = values.shape
106 | if L > S:
107 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
108 | values = torch.cat([values, zeros], dim=1)
109 | keys = torch.cat([keys, zeros], dim=1)
110 | else:
111 | values = values[:, :L, :, :]
112 | keys = keys[:, :L, :, :]
113 |
114 | # period-based dependencies
115 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
116 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
117 | res = q_fft * torch.conj(k_fft)
118 | corr = torch.fft.irfft(res, dim=-1)
119 |
120 | # time delay agg
121 | if self.training:
122 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
123 | else:
124 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
125 |
126 | if self.output_attention:
127 | return (V.contiguous(), corr.permute(0, 3, 1, 2))
128 | else:
129 | return (V.contiguous(), None)
130 |
131 |
132 | class AutoCorrelationLayer(nn.Module):
133 | def __init__(self, correlation, d_model, n_heads, d_keys=None,
134 | d_values=None):
135 | super(AutoCorrelationLayer, self).__init__()
136 |
137 | d_keys = d_keys or (d_model // n_heads)
138 | d_values = d_values or (d_model // n_heads)
139 |
140 | self.inner_correlation = correlation
141 | self.query_projection = nn.Linear(d_model, d_keys * n_heads)
142 | self.key_projection = nn.Linear(d_model, d_keys * n_heads)
143 | self.value_projection = nn.Linear(d_model, d_values * n_heads)
144 | self.out_projection = nn.Linear(d_values * n_heads, d_model)
145 | self.n_heads = n_heads
146 |
147 | def forward(self, queries, keys, values, attn_mask):
148 | B, L, _ = queries.shape
149 | _, S, _ = keys.shape
150 | H = self.n_heads
151 |
152 | queries = self.query_projection(queries).view(B, L, H, -1)
153 | keys = self.key_projection(keys).view(B, S, H, -1)
154 | values = self.value_projection(values).view(B, S, H, -1)
155 |
156 | out, attn = self.inner_correlation(
157 | queries,
158 | keys,
159 | values,
160 | attn_mask
161 | )
162 | out = out.view(B, L, -1)
163 |
164 | return self.out_projection(out), attn
165 |
--------------------------------------------------------------------------------
/layers/Autoformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class my_Layernorm(nn.Module):
7 | """
8 | Special designed layernorm for the seasonal part
9 | """
10 | def __init__(self, channels):
11 | super(my_Layernorm, self).__init__()
12 | self.layernorm = nn.LayerNorm(channels)
13 |
14 | def forward(self, x):
15 | x_hat = self.layernorm(x)
16 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1)
17 | return x_hat - bias
18 |
19 |
20 | class moving_avg(nn.Module):
21 | """
22 | Moving average block to highlight the trend of time series
23 | """
24 | def __init__(self, kernel_size, stride):
25 | super(moving_avg, self).__init__()
26 | self.kernel_size = kernel_size
27 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
28 |
29 | def forward(self, x):
30 | # padding on the both ends of time series
31 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
32 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
33 | x = torch.cat([front, x, end], dim=1)
34 | x = self.avg(x.permute(0, 2, 1))
35 | x = x.permute(0, 2, 1)
36 | return x
37 |
38 |
39 | class series_decomp(nn.Module):
40 | """
41 | Series decomposition block
42 | """
43 | def __init__(self, kernel_size):
44 | super(series_decomp, self).__init__()
45 | self.moving_avg = moving_avg(kernel_size, stride=1)
46 |
47 | def forward(self, x):
48 | moving_mean = self.moving_avg(x)
49 | res = x - moving_mean
50 | return res, moving_mean
51 |
52 |
53 | class EncoderLayer(nn.Module):
54 | """
55 | Autoformer encoder layer with the progressive decomposition architecture
56 | """
57 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
58 | super(EncoderLayer, self).__init__()
59 | d_ff = d_ff or 4 * d_model
60 | self.attention = attention
61 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
62 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
63 | self.decomp1 = series_decomp(moving_avg)
64 | self.decomp2 = series_decomp(moving_avg)
65 | self.dropout = nn.Dropout(dropout)
66 | self.activation = F.relu if activation == "relu" else F.gelu
67 |
68 | def forward(self, x, attn_mask=None):
69 | new_x, attn = self.attention(
70 | x, x, x,
71 | attn_mask=attn_mask
72 | )
73 | x = x + self.dropout(new_x)
74 | x, _ = self.decomp1(x)
75 | y = x
76 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
77 | y = self.dropout(self.conv2(y).transpose(-1, 1))
78 | res, _ = self.decomp2(x + y)
79 | return res, attn
80 |
81 |
82 | class Encoder(nn.Module):
83 | """
84 | Autoformer encoder
85 | """
86 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
87 | super(Encoder, self).__init__()
88 | self.attn_layers = nn.ModuleList(attn_layers)
89 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
90 | self.norm = norm_layer
91 |
92 | def forward(self, x, attn_mask=None):
93 | attns = []
94 | if self.conv_layers is not None:
95 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
96 | x, attn = attn_layer(x, attn_mask=attn_mask)
97 | x = conv_layer(x)
98 | attns.append(attn)
99 | x, attn = self.attn_layers[-1](x)
100 | attns.append(attn)
101 | else:
102 | for attn_layer in self.attn_layers:
103 | x, attn = attn_layer(x, attn_mask=attn_mask)
104 | attns.append(attn)
105 |
106 | if self.norm is not None:
107 | x = self.norm(x)
108 |
109 | return x, attns
110 |
111 |
112 | class DecoderLayer(nn.Module):
113 | """
114 | Autoformer decoder layer with the progressive decomposition architecture
115 | """
116 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
117 | moving_avg=25, dropout=0.1, activation="relu"):
118 | super(DecoderLayer, self).__init__()
119 | d_ff = d_ff or 4 * d_model
120 | self.self_attention = self_attention
121 | self.cross_attention = cross_attention
122 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
123 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
124 | self.decomp1 = series_decomp(moving_avg)
125 | self.decomp2 = series_decomp(moving_avg)
126 | self.decomp3 = series_decomp(moving_avg)
127 | self.dropout = nn.Dropout(dropout)
128 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
129 | padding_mode='circular', bias=False)
130 | self.activation = F.relu if activation == "relu" else F.gelu
131 |
132 | def forward(self, x, cross, x_mask=None, cross_mask=None):
133 | x = x + self.dropout(self.self_attention(
134 | x, x, x,
135 | attn_mask=x_mask
136 | )[0])
137 | x, trend1 = self.decomp1(x)
138 | x = x + self.dropout(self.cross_attention(
139 | x, cross, cross,
140 | attn_mask=cross_mask
141 | )[0])
142 | x, trend2 = self.decomp2(x)
143 | y = x
144 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
145 | y = self.dropout(self.conv2(y).transpose(-1, 1))
146 | x, trend3 = self.decomp3(x + y)
147 |
148 | residual_trend = trend1 + trend2 + trend3
149 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
150 | return x, residual_trend
151 |
152 |
153 | class Decoder(nn.Module):
154 | """
155 |     Autoformer decoder
156 | """
157 | def __init__(self, layers, norm_layer=None, projection=None):
158 | super(Decoder, self).__init__()
159 | self.layers = nn.ModuleList(layers)
160 | self.norm = norm_layer
161 | self.projection = projection
162 |
163 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
164 | for layer in self.layers:
165 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
166 | trend = trend + residual_trend
167 |
168 | if self.norm is not None:
169 | x = self.norm(x)
170 |
171 | if self.projection is not None:
172 | x = self.projection(x)
173 | return x, trend
174 |
--------------------------------------------------------------------------------
/layers/Embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm
5 | import math
6 |
7 |
8 | class PositionalEmbedding(nn.Module):
9 | def __init__(self, d_model, max_len=5000):
10 | super(PositionalEmbedding, self).__init__()
11 | # Compute the positional encodings once in log space.
12 | pe = torch.zeros(max_len, d_model).float()
13 |         pe.requires_grad = False
14 |
15 | position = torch.arange(0, max_len).float().unsqueeze(1)
16 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
17 |
18 | pe[:, 0::2] = torch.sin(position * div_term)
19 | pe[:, 1::2] = torch.cos(position * div_term)
20 |
21 | pe = pe.unsqueeze(0)
22 | self.register_buffer('pe', pe)
23 |
24 | def forward(self, x):
25 | return self.pe[:, :x.size(1)]
26 |
27 |
28 | class TokenEmbedding(nn.Module):
29 | def __init__(self, c_in, d_model):
30 | super(TokenEmbedding, self).__init__()
31 | padding = 1 if torch.__version__ >= '1.5.0' else 2
32 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
33 | kernel_size=3, padding=padding, padding_mode='circular', bias=False)
34 | for m in self.modules():
35 | if isinstance(m, nn.Conv1d):
36 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
37 |
38 | def forward(self, x):
39 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
40 | return x
41 |
42 |
43 | class FixedEmbedding(nn.Module):
44 | def __init__(self, c_in, d_model):
45 | super(FixedEmbedding, self).__init__()
46 |
47 | w = torch.zeros(c_in, d_model).float()
48 |         w.requires_grad = False
49 |
50 | position = torch.arange(0, c_in).float().unsqueeze(1)
51 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
52 |
53 | w[:, 0::2] = torch.sin(position * div_term)
54 | w[:, 1::2] = torch.cos(position * div_term)
55 |
56 | self.emb = nn.Embedding(c_in, d_model)
57 | self.emb.weight = nn.Parameter(w, requires_grad=False)
58 |
59 | def forward(self, x):
60 | return self.emb(x).detach()
61 |
62 |
63 | class TemporalEmbedding(nn.Module):
64 | def __init__(self, d_model, embed_type='fixed', freq='h'):
65 | super(TemporalEmbedding, self).__init__()
66 |
67 | minute_size = 4
68 | hour_size = 24
69 | weekday_size = 7
70 | day_size = 32
71 | month_size = 13
72 |
73 | Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
74 | if freq == 't':
75 | self.minute_embed = Embed(minute_size, d_model)
76 | self.hour_embed = Embed(hour_size, d_model)
77 | self.weekday_embed = Embed(weekday_size, d_model)
78 | self.day_embed = Embed(day_size, d_model)
79 | self.month_embed = Embed(month_size, d_model)
80 |
81 | def forward(self, x):
82 | x = x.long()
83 |
84 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0.
85 | hour_x = self.hour_embed(x[:, :, 3])
86 | weekday_x = self.weekday_embed(x[:, :, 2])
87 | day_x = self.day_embed(x[:, :, 1])
88 | month_x = self.month_embed(x[:, :, 0])
89 |
90 | return hour_x + weekday_x + day_x + month_x + minute_x
91 |
92 |
93 | class TimeFeatureEmbedding(nn.Module):
94 | def __init__(self, d_model, embed_type='timeF', freq='h'):
95 | super(TimeFeatureEmbedding, self).__init__()
96 |
97 | freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
98 | d_inp = freq_map[freq]
99 | self.embed = nn.Linear(d_inp, d_model, bias=False)
100 |
101 | def forward(self, x):
102 | return self.embed(x)
103 |
104 |
105 | class DataEmbedding(nn.Module):
106 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
107 | super(DataEmbedding, self).__init__()
108 |
109 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
110 | self.position_embedding = PositionalEmbedding(d_model=d_model)
111 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
112 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
113 | d_model=d_model, embed_type=embed_type, freq=freq)
114 | self.dropout = nn.Dropout(p=dropout)
115 |
116 | def forward(self, x, x_mark):
117 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
118 | return self.dropout(x)
119 |
120 |
121 | class DataEmbedding_wo_pos(nn.Module):
122 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
123 | super(DataEmbedding_wo_pos, self).__init__()
124 |
125 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
126 | self.position_embedding = PositionalEmbedding(d_model=d_model)
127 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
128 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
129 | d_model=d_model, embed_type=embed_type, freq=freq)
130 | self.dropout = nn.Dropout(p=dropout)
131 |
132 | def forward(self, x, x_mark):
133 | x = self.value_embedding(x) + self.temporal_embedding(x_mark)
134 | return self.dropout(x)
135 |
136 | class DataEmbedding_wo_pos_temp(nn.Module):
137 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
138 | super(DataEmbedding_wo_pos_temp, self).__init__()
139 |
140 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
141 | self.position_embedding = PositionalEmbedding(d_model=d_model)
142 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
143 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
144 | d_model=d_model, embed_type=embed_type, freq=freq)
145 | self.dropout = nn.Dropout(p=dropout)
146 |
147 | def forward(self, x, x_mark):
148 | x = self.value_embedding(x)
149 | return self.dropout(x)
150 |
151 | class DataEmbedding_wo_temp(nn.Module):
152 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
153 | super(DataEmbedding_wo_temp, self).__init__()
154 |
155 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
156 | self.position_embedding = PositionalEmbedding(d_model=d_model)
157 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type,
158 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding(
159 | d_model=d_model, embed_type=embed_type, freq=freq)
160 | self.dropout = nn.Dropout(p=dropout)
161 |
162 | def forward(self, x, x_mark):
163 | x = self.value_embedding(x) + self.position_embedding(x)
164 | return self.dropout(x)
165 |
166 |
167 | class DataEmbedding_inverted(nn.Module):
168 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
169 | super(DataEmbedding_inverted, self).__init__()
170 | self.value_embedding = nn.Linear(c_in, d_model)
171 | self.dropout = nn.Dropout(p=dropout)
172 |
173 | def forward(self, x, x_mark):
174 | x = x.permute(0, 2, 1)
175 | # x: [Batch Variate Time]
176 | if x_mark is None:
177 | x = self.value_embedding(x)
178 | else:
179 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
180 | # x: [Batch Variate d_model]
181 | return self.dropout(x)
182 |
--------------------------------------------------------------------------------
/layers/PatchTST_layers.py:
--------------------------------------------------------------------------------
1 | __all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding']
2 |
3 | import torch
4 | from torch import nn
5 | import math
6 |
7 | class Transpose(nn.Module):
8 | def __init__(self, *dims, contiguous=False):
9 | super().__init__()
10 | self.dims, self.contiguous = dims, contiguous
11 | def forward(self, x):
12 | if self.contiguous: return x.transpose(*self.dims).contiguous()
13 | else: return x.transpose(*self.dims)
14 |
15 |
16 | def get_activation_fn(activation):
17 | if callable(activation): return activation()
18 | elif activation.lower() == "relu": return nn.ReLU()
19 | elif activation.lower() == "gelu": return nn.GELU()
20 | raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')
21 |
22 |
23 | # decomposition
24 |
25 | class moving_avg(nn.Module):
26 | """
27 | Moving average block to highlight the trend of time series
28 | """
29 | def __init__(self, kernel_size, stride):
30 | super(moving_avg, self).__init__()
31 | self.kernel_size = kernel_size
32 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
33 |
34 | def forward(self, x):
35 | # padding on the both ends of time series
36 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
37 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
38 | x = torch.cat([front, x, end], dim=1)
39 | x = self.avg(x.permute(0, 2, 1))
40 | x = x.permute(0, 2, 1)
41 | return x
42 |
43 |
44 | class series_decomp(nn.Module):
45 | """
46 | Series decomposition block
47 | """
48 | def __init__(self, kernel_size):
49 | super(series_decomp, self).__init__()
50 | self.moving_avg = moving_avg(kernel_size, stride=1)
51 |
52 | def forward(self, x):
53 | moving_mean = self.moving_avg(x)
54 | res = x - moving_mean
55 | return res, moving_mean
56 |
57 |
58 |
59 | # pos_encoding
60 |
61 | def PositionalEncoding(q_len, d_model, normalize=True):
62 | pe = torch.zeros(q_len, d_model)
63 | position = torch.arange(0, q_len).unsqueeze(1)
64 | div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
65 | pe[:, 0::2] = torch.sin(position * div_term)
66 | pe[:, 1::2] = torch.cos(position * div_term)
67 | if normalize:
68 | pe = pe - pe.mean()
69 | pe = pe / (pe.std() * 10)
70 | return pe
71 |
72 | SinCosPosEncoding = PositionalEncoding
73 |
74 | def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
75 | x = .5 if exponential else 1
76 | i = 0
77 | for i in range(100):
78 | cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1
79 |         if verbose: print(f'{i:4.0f}  {x:5.3f}  {cpe.mean():+6.3f}')  # inline replacement for the undefined pv() helper from tsai
80 | if abs(cpe.mean()) <= eps: break
81 | elif cpe.mean() > eps: x += .001
82 | else: x -= .001
83 | i += 1
84 | if normalize:
85 | cpe = cpe - cpe.mean()
86 | cpe = cpe / (cpe.std() * 10)
87 | return cpe
88 |
89 | def Coord1dPosEncoding(q_len, exponential=False, normalize=True):
90 | cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1)
91 | if normalize:
92 | cpe = cpe - cpe.mean()
93 | cpe = cpe / (cpe.std() * 10)
94 | return cpe
95 |
96 | def positional_encoding(pe, learn_pe, q_len, d_model):
97 | # Positional encoding
98 | if pe == None:
99 | W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe
100 | nn.init.uniform_(W_pos, -0.02, 0.02)
101 | learn_pe = False
102 | elif pe == 'zero':
103 | W_pos = torch.empty((q_len, 1))
104 | nn.init.uniform_(W_pos, -0.02, 0.02)
105 | elif pe == 'zeros':
106 | W_pos = torch.empty((q_len, d_model))
107 | nn.init.uniform_(W_pos, -0.02, 0.02)
108 | elif pe == 'normal' or pe == 'gauss':
109 | W_pos = torch.zeros((q_len, 1))
110 | torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
111 | elif pe == 'uniform':
112 | W_pos = torch.zeros((q_len, 1))
113 | nn.init.uniform_(W_pos, a=0.0, b=0.1)
114 | elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
115 | elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
116 | elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
117 | elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
118 | elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True)
119 |     else: raise ValueError(f"{pe} is not a valid pe (positional encoder). Available types: 'gauss'=='normal', \
120 |         'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.")
121 | return nn.Parameter(W_pos, requires_grad=learn_pe)
--------------------------------------------------------------------------------
/layers/RevIN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class RevIN(nn.Module):
5 | def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
6 | """
7 | :param num_features: the number of features or channels
8 | :param eps: a value added for numerical stability
9 | :param affine: if True, RevIN has learnable affine parameters
10 | """
11 | super(RevIN, self).__init__()
12 | self.num_features = num_features
13 | self.eps = eps
14 | self.affine = affine
15 | self.subtract_last = subtract_last
16 | if self.affine:
17 | self._init_params()
18 |
19 | def forward(self, x, mode:str):
20 | if mode == 'norm':
21 | self._get_statistics(x)
22 | x = self._normalize(x)
23 | elif mode == 'denorm':
24 | x = self._denormalize(x)
25 | else: raise NotImplementedError
26 | return x
27 |
28 | def _init_params(self):
29 | # initialize RevIN params: (C,)
30 | self.affine_weight = nn.Parameter(torch.ones(self.num_features))
31 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
32 |
33 | def _get_statistics(self, x):
34 | dim2reduce = tuple(range(1, x.ndim-1))
35 | if self.subtract_last:
36 | self.last = x[:,-1,:].unsqueeze(1)
37 | else:
38 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
39 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
40 |
41 | def _normalize(self, x):
42 | if self.subtract_last:
43 | x = x - self.last
44 | else:
45 | x = x - self.mean
46 | x = x / self.stdev
47 | if self.affine:
48 | x = x * self.affine_weight
49 | x = x + self.affine_bias
50 | return x
51 |
52 | def _denormalize(self, x):
53 | if self.affine:
54 | x = x - self.affine_bias
55 | x = x / (self.affine_weight + self.eps*self.eps)
56 | x = x * self.stdev
57 | if self.subtract_last:
58 | x = x + self.last
59 | else:
60 | x = x + self.mean
61 | return x
--------------------------------------------------------------------------------
/layers/SelfAttention_Family.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import matplotlib.pyplot as plt
6 |
7 | import numpy as np
8 | import math
9 | from math import sqrt
10 | from utils.masking import TriangularCausalMask, ProbMask
11 | import os
12 |
13 |
14 | class FullAttention(nn.Module):
15 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
16 | super(FullAttention, self).__init__()
17 | self.scale = scale
18 | self.mask_flag = mask_flag
19 | self.output_attention = output_attention
20 | self.dropout = nn.Dropout(attention_dropout)
21 |
22 | def forward(self, queries, keys, values, attn_mask):
23 | B, L, H, E = queries.shape
24 | _, S, _, D = values.shape
25 | scale = self.scale or 1. / sqrt(E)
26 |
27 | scores = torch.einsum("blhe,bshe->bhls", queries, keys)
28 |
29 | if self.mask_flag:
30 | if attn_mask is None:
31 | attn_mask = TriangularCausalMask(B, L, device=queries.device)
32 |
33 | scores.masked_fill_(attn_mask.mask, -np.inf)
34 |
35 | A = self.dropout(torch.softmax(scale * scores, dim=-1))
36 | V = torch.einsum("bhls,bshd->blhd", A, values)
37 |
38 | if self.output_attention:
39 | return (V.contiguous(), A)
40 | else:
41 | return (V.contiguous(), None)
42 |
43 |
44 | class ProbAttention(nn.Module):
45 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
46 | super(ProbAttention, self).__init__()
47 | self.factor = factor
48 | self.scale = scale
49 | self.mask_flag = mask_flag
50 | self.output_attention = output_attention
51 | self.dropout = nn.Dropout(attention_dropout)
52 |
53 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
54 | # Q [B, H, L, D]
55 | B, H, L_K, E = K.shape
56 | _, _, L_Q, _ = Q.shape
57 |
58 | # calculate the sampled Q_K
59 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
60 | index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
61 | K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
62 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
63 |
64 | # find the Top_k query with sparisty measurement
65 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
66 | M_top = M.topk(n_top, sorted=False)[1]
67 |
68 | # use the reduced Q to calculate Q_K
69 | Q_reduce = Q[torch.arange(B)[:, None, None],
70 | torch.arange(H)[None, :, None],
71 | M_top, :] # factor*ln(L_q)
72 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
73 |
74 | return Q_K, M_top
75 |
76 | def _get_initial_context(self, V, L_Q):
77 | B, H, L_V, D = V.shape
78 | if not self.mask_flag:
79 | # V_sum = V.sum(dim=-2)
80 | V_sum = V.mean(dim=-2)
81 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
82 | else: # use mask
83 | assert (L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
84 | contex = V.cumsum(dim=-2)
85 | return contex
86 |
87 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
88 | B, H, L_V, D = V.shape
89 |
90 | if self.mask_flag:
91 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
92 | scores.masked_fill_(attn_mask.mask, -np.inf)
93 |
94 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)
95 |
96 | context_in[torch.arange(B)[:, None, None],
97 | torch.arange(H)[None, :, None],
98 | index, :] = torch.matmul(attn, V).type_as(context_in)
99 | if self.output_attention:
100 | attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
101 | attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
102 | return (context_in, attns)
103 | else:
104 | return (context_in, None)
105 |
106 | def forward(self, queries, keys, values, attn_mask):
107 | B, L_Q, H, D = queries.shape
108 | _, L_K, _, _ = keys.shape
109 |
110 | queries = queries.transpose(2, 1)
111 | keys = keys.transpose(2, 1)
112 | values = values.transpose(2, 1)
113 |
114 | U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
115 | u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)
116 |
117 | U_part = U_part if U_part < L_K else L_K
118 | u = u if u < L_Q else L_Q
119 |
120 | scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)
121 |
122 | # add scale factor
123 | scale = self.scale or 1. / sqrt(D)
124 | if scale is not None:
125 | scores_top = scores_top * scale
126 | # get the context
127 | context = self._get_initial_context(values, L_Q)
128 | # update the context with selected top_k queries
129 | context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
130 |
131 | return context.contiguous(), attn
132 |
133 |
134 | class AttentionLayer(nn.Module):
135 | def __init__(self, attention, d_model, n_heads, d_keys=None,
136 | d_values=None):
137 | super(AttentionLayer, self).__init__()
138 |
139 | d_keys = d_keys or (d_model // n_heads)
140 | d_values = d_values or (d_model // n_heads)
141 |
142 | self.inner_attention = attention
143 | self.query_projection = nn.Linear(d_model, d_keys * n_heads)
144 | self.key_projection = nn.Linear(d_model, d_keys * n_heads)
145 | self.value_projection = nn.Linear(d_model, d_values * n_heads)
146 | self.out_projection = nn.Linear(d_values * n_heads, d_model)
147 | self.n_heads = n_heads
148 |
149 | def forward(self, queries, keys, values, attn_mask):
150 | B, L, _ = queries.shape
151 | _, S, _ = keys.shape
152 | H = self.n_heads
153 |
154 | queries = self.query_projection(queries).view(B, L, H, -1)
155 | keys = self.key_projection(keys).view(B, S, H, -1)
156 | values = self.value_projection(values).view(B, S, H, -1)
157 |
158 | out, attn = self.inner_attention(
159 | queries,
160 | keys,
161 | values,
162 | attn_mask
163 | )
164 | out = out.view(B, L, -1)
165 |
166 | return self.out_projection(out), attn
167 |
--------------------------------------------------------------------------------
/layers/Transformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ConvLayer(nn.Module):
7 | def __init__(self, c_in):
8 | super(ConvLayer, self).__init__()
9 | self.downConv = nn.Conv1d(in_channels=c_in,
10 | out_channels=c_in,
11 | kernel_size=3,
12 | padding=2,
13 | padding_mode='circular')
14 | self.norm = nn.BatchNorm1d(c_in)
15 | self.activation = nn.ELU()
16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
17 |
18 | def forward(self, x):
19 | x = self.downConv(x.permute(0, 2, 1))
20 | x = self.norm(x)
21 | x = self.activation(x)
22 | x = self.maxPool(x)
23 | x = x.transpose(1, 2)
24 | return x
25 |
26 |
27 | class EncoderLayer(nn.Module):
28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
29 | super(EncoderLayer, self).__init__()
30 | d_ff = d_ff or 4 * d_model
31 | self.attention = attention
32 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
33 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
34 | self.norm1 = nn.LayerNorm(d_model)
35 | self.norm2 = nn.LayerNorm(d_model)
36 | self.dropout = nn.Dropout(dropout)
37 | self.activation = F.relu if activation == "relu" else F.gelu
38 |
39 | def forward(self, x, attn_mask=None):
40 | new_x, attn = self.attention(
41 | x, x, x,
42 | attn_mask=attn_mask
43 | )
44 | x = x + self.dropout(new_x)
45 |
46 | y = x = self.norm1(x)
47 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
48 | y = self.dropout(self.conv2(y).transpose(-1, 1))
49 |
50 | return self.norm2(x + y), attn
51 |
52 |
53 | class Encoder(nn.Module):
54 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
55 | super(Encoder, self).__init__()
56 | self.attn_layers = nn.ModuleList(attn_layers)
57 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
58 | self.norm = norm_layer
59 |
60 | def forward(self, x, attn_mask=None):
61 | # x [B, L, D]
62 | attns = []
63 | if self.conv_layers is not None:
64 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
65 | x, attn = attn_layer(x, attn_mask=attn_mask)
66 | x = conv_layer(x)
67 | attns.append(attn)
68 | x, attn = self.attn_layers[-1](x)
69 | attns.append(attn)
70 | else:
71 | for attn_layer in self.attn_layers:
72 | x, attn = attn_layer(x, attn_mask=attn_mask)
73 | attns.append(attn)
74 |
75 | if self.norm is not None:
76 | x = self.norm(x)
77 |
78 | return x, attns
79 |
80 |
81 | class DecoderLayer(nn.Module):
82 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
83 | dropout=0.1, activation="relu"):
84 | super(DecoderLayer, self).__init__()
85 | d_ff = d_ff or 4 * d_model
86 | self.self_attention = self_attention
87 | self.cross_attention = cross_attention
88 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
89 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
90 | self.norm1 = nn.LayerNorm(d_model)
91 | self.norm2 = nn.LayerNorm(d_model)
92 | self.norm3 = nn.LayerNorm(d_model)
93 | self.dropout = nn.Dropout(dropout)
94 | self.activation = F.relu if activation == "relu" else F.gelu
95 |
96 | def forward(self, x, cross, x_mask=None, cross_mask=None):
97 | x = x + self.dropout(self.self_attention(
98 | x, x, x,
99 | attn_mask=x_mask
100 | )[0])
101 | x = self.norm1(x)
102 |
103 | x = x + self.dropout(self.cross_attention(
104 | x, cross, cross,
105 | attn_mask=cross_mask
106 | )[0])
107 |
108 | y = x = self.norm2(x)
109 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
110 | y = self.dropout(self.conv2(y).transpose(-1, 1))
111 |
112 | return self.norm3(x + y)
113 |
114 |
115 | class Decoder(nn.Module):
116 | def __init__(self, layers, norm_layer=None, projection=None):
117 | super(Decoder, self).__init__()
118 | self.layers = nn.ModuleList(layers)
119 | self.norm = norm_layer
120 | self.projection = projection
121 |
122 | def forward(self, x, cross, x_mask=None, cross_mask=None):
123 | for layer in self.layers:
124 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
125 |
126 | if self.norm is not None:
127 | x = self.norm(x)
128 |
129 | if self.projection is not None:
130 | x = self.projection(x)
131 | return x
132 |
--------------------------------------------------------------------------------
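A quick shape check (a sketch, assuming the repo layout above) of the ConvLayer defined in this file: the circular 3-tap convolution keeps roughly the input length, and the strided max-pool then halves it, which is what Informer-style distilling between encoder layers relies on.

import torch
from layers.Transformer_EncDec import ConvLayer

layer = ConvLayer(c_in=512).eval()      # eval() so BatchNorm uses running stats
x = torch.randn(2, 96, 512)             # (batch, seq_len, d_model)
y = layer(x)
print(y.shape)                          # torch.Size([2, 49, 512]) -- ~seq_len/2
--------------------------------------------------------------------------------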
/models/Autoformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Embed import DataEmbedding, DataEmbedding_wo_pos,DataEmbedding_wo_pos_temp,DataEmbedding_wo_temp
5 | from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer
6 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
7 | import math
8 | import numpy as np
9 |
10 |
11 | class Model(nn.Module):
12 | """
13 | Autoformer is the first method to achieve the series-wise connection,
14 |     with inherent O(L log L) complexity
15 | """
16 | def __init__(self, configs):
17 | super(Model, self).__init__()
18 | self.seq_len = configs.seq_len
19 | self.label_len = configs.label_len
20 | self.pred_len = configs.pred_len
21 | self.output_attention = configs.output_attention
22 |
23 | # Decomp
24 | kernel_size = configs.moving_avg
25 | self.decomp = series_decomp(kernel_size)
26 |
27 | # Embedding
28 | # The series-wise connection inherently contains the sequential information.
29 | # Thus, we can discard the position embedding of transformers.
30 | if configs.embed_type == 0:
31 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
32 | configs.dropout)
33 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq,
34 | configs.dropout)
35 | elif configs.embed_type == 1:
36 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
37 | configs.dropout)
38 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
39 | configs.dropout)
40 | elif configs.embed_type == 2:
41 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
42 | configs.dropout)
43 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq,
44 | configs.dropout)
45 |
46 | elif configs.embed_type == 3:
47 | self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
48 | configs.dropout)
49 | self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
50 | configs.dropout)
51 | elif configs.embed_type == 4:
52 | self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
53 | configs.dropout)
54 | self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
55 | configs.dropout)
56 |
57 | # Encoder
58 | self.encoder = Encoder(
59 | [
60 | EncoderLayer(
61 | AutoCorrelationLayer(
62 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
63 | output_attention=configs.output_attention),
64 | configs.d_model, configs.n_heads),
65 | configs.d_model,
66 | configs.d_ff,
67 | moving_avg=configs.moving_avg,
68 | dropout=configs.dropout,
69 | activation=configs.activation
70 | ) for l in range(configs.e_layers)
71 | ],
72 | norm_layer=my_Layernorm(configs.d_model)
73 | )
74 | # Decoder
75 | self.decoder = Decoder(
76 | [
77 | DecoderLayer(
78 | AutoCorrelationLayer(
79 | AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout,
80 | output_attention=False),
81 | configs.d_model, configs.n_heads),
82 | AutoCorrelationLayer(
83 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
84 | output_attention=False),
85 | configs.d_model, configs.n_heads),
86 | configs.d_model,
87 | configs.c_out,
88 | configs.d_ff,
89 | moving_avg=configs.moving_avg,
90 | dropout=configs.dropout,
91 | activation=configs.activation,
92 | )
93 | for l in range(configs.d_layers)
94 | ],
95 | norm_layer=my_Layernorm(configs.d_model),
96 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
97 | )
98 |
99 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
100 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
101 | # decomp init
102 | mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
103 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device)
104 | seasonal_init, trend_init = self.decomp(x_enc)
105 | # decoder input
106 | trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
107 | seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)
108 | # enc
109 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
110 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
111 | # dec
112 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
113 | seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask,
114 | trend=trend_init)
115 | # final
116 | dec_out = trend_part + seasonal_part
117 |
118 | if self.output_attention:
119 | return dec_out[:, -self.pred_len:, :], attns
120 | else:
121 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
122 |
--------------------------------------------------------------------------------
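The "decomp init" at the top of forward seeds the decoder: the trend branch starts from the input's mean and the seasonal branch from zeros, each prefixed with the last label_len steps of the decomposed input. A toy trace with plain tensors (a sketch; in the model the trend comes from series_decomp, not a copy of the input):

import torch

B, seq_len, label_len, pred_len, D = 4, 96, 48, 96, 7
x_enc = torch.randn(B, seq_len, D)
trend = x_enc.clone()                        # stand-in for the moving-average trend
mean = x_enc.mean(dim=1, keepdim=True).repeat(1, pred_len, 1)
zeros = torch.zeros(B, pred_len, D)
trend_init = torch.cat([trend[:, -label_len:, :], mean], dim=1)
seasonal_init = torch.cat([(x_enc - trend)[:, -label_len:, :], zeros], dim=1)
print(trend_init.shape)                      # (4, 144, 7) = label_len + pred_len
--------------------------------------------------------------------------------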
/models/CycleNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class RecurrentCycle(torch.nn.Module):
 5 |     # Thanks to wayhoww for this contribution.
 6 |     # The new implementation uses index arithmetic with modulo to gather the cyclic data directly in a single operation,
 7 |     # while the original implementation manually rolled and repeated the data in a loop.
 8 |     # It achieves a significant speedup (2x ~ 3x acceleration).
9 | # See https://github.com/ACAT-SCUT/CycleNet/pull/4 for more details.
10 | def __init__(self, cycle_len, channel_size):
11 | super(RecurrentCycle, self).__init__()
12 | self.cycle_len = cycle_len
13 | self.channel_size = channel_size
14 | self.data = torch.nn.Parameter(torch.zeros(cycle_len, channel_size), requires_grad=True)
15 |
16 | def forward(self, index, length):
17 | gather_index = (index.view(-1, 1) + torch.arange(length, device=index.device).view(1, -1)) % self.cycle_len
18 | return self.data[gather_index]
19 |
20 |
21 | class Model(nn.Module):
22 | def __init__(self, configs):
23 | super(Model, self).__init__()
24 |
25 | self.seq_len = configs.seq_len
26 | self.pred_len = configs.pred_len
27 | self.enc_in = configs.enc_in
28 | self.cycle_len = configs.cycle
29 | self.model_type = configs.model_type
30 | self.d_model = configs.d_model
31 | self.use_revin = configs.use_revin
32 |
33 | self.cycleQueue = RecurrentCycle(cycle_len=self.cycle_len, channel_size=self.enc_in)
34 |
35 | assert self.model_type in ['linear', 'mlp']
36 | if self.model_type == 'linear':
37 | self.model = nn.Linear(self.seq_len, self.pred_len)
38 | elif self.model_type == 'mlp':
39 | self.model = nn.Sequential(
40 | nn.Linear(self.seq_len, self.d_model),
41 | nn.ReLU(),
42 | nn.Linear(self.d_model, self.pred_len)
43 | )
44 |
45 | def forward(self, x, cycle_index):
46 | # x: (batch_size, seq_len, enc_in), cycle_index: (batch_size,)
47 |
48 | # instance norm
49 | if self.use_revin:
50 | seq_mean = torch.mean(x, dim=1, keepdim=True)
51 | seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
52 | x = (x - seq_mean) / torch.sqrt(seq_var)
53 |
54 | # remove the cycle of the input data
55 | x = x - self.cycleQueue(cycle_index, self.seq_len)
56 |
57 |         # forecasting with channel independence (parameter sharing)
58 | y = self.model(x.permute(0, 2, 1)).permute(0, 2, 1)
59 |
60 | # add back the cycle of the output data
61 | y = y + self.cycleQueue((cycle_index + self.seq_len) % self.cycle_len, self.pred_len)
62 |
63 | # instance denorm
64 | if self.use_revin:
65 | y = y * torch.sqrt(seq_var) + seq_mean
66 |
67 | return y
68 |
--------------------------------------------------------------------------------
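A small equivalence check (a sketch, assuming the repo layout above) showing that the single modulo-gather in RecurrentCycle.forward matches the naive roll-and-stack formulation it replaced:

import torch
from models.CycleNet import RecurrentCycle

cycle_len, channels, length = 24, 3, 96
queue = RecurrentCycle(cycle_len, channels)
torch.nn.init.normal_(queue.data)
index = torch.tensor([0, 5, 23])             # per-sample cycle offsets

fast = queue(index, length)                  # one gather: (3, 96, 3)
naive = torch.stack([queue.data[(i + torch.arange(length)) % cycle_len]
                     for i in index])
assert torch.equal(fast, naive)
--------------------------------------------------------------------------------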
/models/CycleiTransformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Encoder, EncoderLayer
5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding_inverted
7 | import numpy as np
8 |
9 | ### Integration of Cycle
10 | class RecurrentCycle(torch.nn.Module):
11 | def __init__(self, cycle_len, channel_size):
12 | super(RecurrentCycle, self).__init__()
13 | self.cycle_len = cycle_len
14 | self.channel_size = channel_size
15 | self.data = torch.nn.Parameter(torch.zeros(cycle_len, channel_size), requires_grad=True)
16 |
17 | def forward(self, index, length):
18 | gather_index = (index.view(-1, 1) + torch.arange(length, device=index.device).view(1, -1)) % self.cycle_len
19 | return self.data[gather_index]
20 |
21 |
22 | class Model(nn.Module):
23 | def __init__(self, configs):
24 | super(Model, self).__init__()
25 | self.seq_len = configs.seq_len
26 | self.pred_len = configs.pred_len
27 | self.output_attention = configs.output_attention
28 | self.use_norm = configs.use_revin
29 | # Embedding
30 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
31 | configs.dropout)
32 | # Encoder-only architecture
33 | self.encoder = Encoder(
34 | [
35 | EncoderLayer(
36 | AttentionLayer(
37 | FullAttention(False, configs.factor, attention_dropout=configs.dropout,
38 | output_attention=configs.output_attention), configs.d_model, configs.n_heads),
39 | configs.d_model,
40 | configs.d_ff,
41 | dropout=configs.dropout,
42 | activation=configs.activation
43 | ) for l in range(configs.e_layers)
44 | ],
45 | norm_layer=torch.nn.LayerNorm(configs.d_model)
46 | )
47 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True)
48 |
49 |         ### Integration of Cycle
50 | self.cycle_len = configs.cycle
51 | self.enc_in = configs.enc_in
52 | self.cycleQueue = RecurrentCycle(cycle_len=self.cycle_len, channel_size=self.enc_in)
53 |
54 | def forecast(self, x_enc, cycle_index, x_mark_enc=None, x_dec=None, x_mark_dec=None):
55 |
56 | if self.use_norm:
57 | # Normalization from Non-stationary Transformer
58 | means = x_enc.mean(1, keepdim=True).detach()
59 | x_enc = x_enc - means
60 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
61 | x_enc /= stdev
62 |
63 |         ### Integration of Cycle
64 | # remove the cycle of the input data
65 | x_enc = x_enc - self.cycleQueue(cycle_index, self.seq_len)
66 |
67 | _, _, N = x_enc.shape # B L N
68 | # B: batch_size; E: d_model;
69 | # L: seq_len; S: pred_len;
70 |         # N: number of variates (tokens); can also include covariates
71 |
72 | # Embedding
73 | # B L N -> B N E (B L N -> B L E in the vanilla Transformer)
74 |         enc_out = self.enc_embedding(x_enc, x_mark_enc) # covariates (e.g. timestamps) can also be embedded as tokens
75 |
76 | # B N E -> B N E (B L E -> B L E in the vanilla Transformer)
77 |         # the dimensions of the embedded time series have been inverted, and are then processed by native attn, layernorm, and ffn modules
78 | enc_out, attns = self.encoder(enc_out, attn_mask=None)
79 |
80 | # B N E -> B N S -> B S N
81 | dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N] # filter the covariates
82 |
83 | ### Intergration of Cycle
84 | # add back the cycle of the output data
85 | dec_out = dec_out + self.cycleQueue((cycle_index + self.seq_len) % self.cycle_len, self.pred_len)
86 |
87 | if self.use_norm:
88 | # De-Normalization from Non-stationary Transformer
89 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
90 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
91 |
92 |
93 | return dec_out
94 |
95 | def forward(self, x_enc, cycle_index, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
96 | dec_out = self.forecast(x_enc, cycle_index, x_mark_enc, x_dec, x_mark_dec)
97 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
--------------------------------------------------------------------------------
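The "inverted" embedding used above treats each variate's whole length-L history as a single token. A stand-alone sketch of that idea (not the repo's DataEmbedding_inverted, just the core mapping):

import torch
import torch.nn as nn

B, L, N, E = 2, 96, 7, 64                    # batch, seq_len, variates, d_model
x = torch.randn(B, L, N)
embed = nn.Linear(L, E)                      # maps a length-L series to one token
tokens = embed(x.permute(0, 2, 1))           # B L N -> B N E
print(tokens.shape)                          # torch.Size([2, 7, 64])
--------------------------------------------------------------------------------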
/models/DLinear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class moving_avg(nn.Module):
7 | """
8 | Moving average block to highlight the trend of time series
9 | """
10 | def __init__(self, kernel_size, stride):
11 | super(moving_avg, self).__init__()
12 | self.kernel_size = kernel_size
13 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
14 |
15 | def forward(self, x):
16 | # padding on the both ends of time series
17 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
18 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
19 | x = torch.cat([front, x, end], dim=1)
20 | x = self.avg(x.permute(0, 2, 1))
21 | x = x.permute(0, 2, 1)
22 | return x
23 |
24 |
25 | class series_decomp(nn.Module):
26 | """
27 | Series decomposition block
28 | """
29 | def __init__(self, kernel_size):
30 | super(series_decomp, self).__init__()
31 | self.moving_avg = moving_avg(kernel_size, stride=1)
32 |
33 | def forward(self, x):
34 | moving_mean = self.moving_avg(x)
35 | res = x - moving_mean
36 | return res, moving_mean
37 |
38 | class Model(nn.Module):
39 | """
40 | Decomposition-Linear
41 | """
42 | def __init__(self, configs):
43 | super(Model, self).__init__()
44 | self.seq_len = configs.seq_len
45 | self.pred_len = configs.pred_len
46 |
47 |         # Decomposition kernel size
48 | kernel_size = 25
49 | self.decompsition = series_decomp(kernel_size)
50 | self.individual = configs.individual
51 | self.channels = configs.enc_in
52 |
53 | if self.individual:
54 | self.Linear_Seasonal = nn.ModuleList()
55 | self.Linear_Trend = nn.ModuleList()
56 |
57 | for i in range(self.channels):
58 | self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))
59 | self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))
60 |
61 |             # Use these two lines if you want to visualize the weights
62 | # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
63 | # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
64 | else:
65 | self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)
66 | self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)
67 |
68 |         # Use these two lines if you want to visualize the weights
69 | # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
70 | # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
71 |
72 | def forward(self, x):
73 | # x: [Batch, Input length, Channel]
74 | seasonal_init, trend_init = self.decompsition(x)
75 | seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)
76 | if self.individual:
77 | seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)
78 | trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)
79 | for i in range(self.channels):
80 | seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])
81 | trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])
82 | else:
83 | seasonal_output = self.Linear_Seasonal(seasonal_init)
84 | trend_output = self.Linear_Trend(trend_init)
85 |
86 | x = seasonal_output + trend_output
87 |
88 |
89 | return x.permute(0,2,1) # to [Batch, Output length, Channel]
90 |
--------------------------------------------------------------------------------
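A quick sanity check (a sketch, assuming the repo layout above) that the decomposition is lossless: since res = x - moving_mean, the two parts recompose the input exactly:

import torch
from models.DLinear import series_decomp

x = torch.randn(2, 96, 7)                    # (batch, seq_len, channels)
res, trend = series_decomp(kernel_size=25)(x)
assert torch.allclose(res + trend, x, atol=1e-6)
print(res.shape, trend.shape)                # both torch.Size([2, 96, 7])
--------------------------------------------------------------------------------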
/models/Informer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.masking import TriangularCausalMask, ProbMask
5 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
6 | from layers.SelfAttention_Family import FullAttention, ProbAttention, AttentionLayer
7 | from layers.Embed import DataEmbedding,DataEmbedding_wo_pos,DataEmbedding_wo_temp,DataEmbedding_wo_pos_temp
8 | import numpy as np
9 |
10 |
11 | class Model(nn.Module):
12 | """
13 |     Informer with ProbSparse attention in O(L log L) complexity
14 | """
15 | def __init__(self, configs):
16 | super(Model, self).__init__()
17 | self.pred_len = configs.pred_len
18 | self.output_attention = configs.output_attention
19 |
20 | # Embedding
21 | if configs.embed_type == 0:
22 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
23 | configs.dropout)
24 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
25 | configs.dropout)
26 | elif configs.embed_type == 1:
27 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
28 | configs.dropout)
29 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
30 | configs.dropout)
31 | elif configs.embed_type == 2:
32 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
33 | configs.dropout)
34 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq,
35 | configs.dropout)
36 |
37 | elif configs.embed_type == 3:
38 | self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
39 | configs.dropout)
40 | self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
41 | configs.dropout)
42 | elif configs.embed_type == 4:
43 | self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
44 | configs.dropout)
45 | self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
46 | configs.dropout)
47 | # Encoder
48 | self.encoder = Encoder(
49 | [
50 | EncoderLayer(
51 | AttentionLayer(
52 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
53 | output_attention=configs.output_attention),
54 | configs.d_model, configs.n_heads),
55 | configs.d_model,
56 | configs.d_ff,
57 | dropout=configs.dropout,
58 | activation=configs.activation
59 | ) for l in range(configs.e_layers)
60 | ],
61 | [
62 | ConvLayer(
63 | configs.d_model
64 | ) for l in range(configs.e_layers - 1)
65 | ] if configs.distil else None,
66 | norm_layer=torch.nn.LayerNorm(configs.d_model)
67 | )
68 | # Decoder
69 | self.decoder = Decoder(
70 | [
71 | DecoderLayer(
72 | AttentionLayer(
73 | ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False),
74 | configs.d_model, configs.n_heads),
75 | AttentionLayer(
76 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
77 | configs.d_model, configs.n_heads),
78 | configs.d_model,
79 | configs.d_ff,
80 | dropout=configs.dropout,
81 | activation=configs.activation,
82 | )
83 | for l in range(configs.d_layers)
84 | ],
85 | norm_layer=torch.nn.LayerNorm(configs.d_model),
86 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
87 | )
88 |
89 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
90 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
91 |
92 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
93 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
94 |
95 | dec_out = self.dec_embedding(x_dec, x_mark_dec)
96 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
97 |
98 | if self.output_attention:
99 | return dec_out[:, -self.pred_len:, :], attns
100 | else:
101 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
102 |
--------------------------------------------------------------------------------
/models/LDLinear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from torch import Tensor
5 | import torch.nn.functional as F
6 | from typing import Optional
7 |
8 | '''
9 | LDLinear: replaces the moving average kernel (MOV) of DLinear with the Learnable Decomposition module (LD) proposed in this paper:
10 | https://openreview.net/forum?id=87CYNyCGOo
11 | '''
12 | class LD(nn.Module):
13 | def __init__(self, kernel_size=25):
14 | super(LD, self).__init__()
15 |         # Define a shared convolution layer for all channels
16 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, stride=1, padding=int(kernel_size // 2),
17 | padding_mode='replicate', bias=True)
18 | # Define the parameters for Gaussian initialization
19 | kernel_size_half = kernel_size // 2
20 |         sigma = 1.0  # width of the Gaussian used for initialization
21 | weights = torch.zeros(1, 1, kernel_size)
22 | for i in range(kernel_size):
23 | weights[0, 0, i] = math.exp(-((i - kernel_size_half) / (2 * sigma)) ** 2)
24 |
25 | # Set the weights of the convolution layer
26 | self.conv.weight.data = F.softmax(weights, dim=-1)
27 | self.conv.bias.data.fill_(0.0)
28 |
29 | def forward(self, inp):
30 | # Permute the input tensor to match the expected shape for 1D convolution (B, N, T)
31 | inp = inp.permute(0, 2, 1)
32 | # Split the input tensor into separate channels
33 | input_channels = torch.split(inp, 1, dim=1)
34 |
35 | # Apply convolution to each channel
36 | conv_outputs = [self.conv(input_channel) for input_channel in input_channels]
37 |
38 | # Concatenate the channel outputs
39 | out = torch.cat(conv_outputs, dim=1)
40 | out = out.permute(0, 2, 1)
41 | return out
42 |
43 |
44 | class Model(nn.Module):
45 |
46 | def __init__(self, configs):
47 | super(Model, self).__init__()
48 | self.seq_len = configs.seq_len
49 | self.pred_len = configs.pred_len
50 |
51 |         # Decomposition kernel size
52 | kernel_size = 25
53 | self.LD = LD(kernel_size=kernel_size)
54 | # self.decompsition = series_decomp(kernel_size)
55 | self.channels = configs.enc_in
56 |
57 | self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
58 | self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
59 |
60 |
61 | def forward(self, x):
62 | # x: [Batch, Input length, Channel]
63 | trend_init = self.LD(x)
64 | seasonal_init = x - trend_init
65 | seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
66 |
67 | seasonal_output = self.Linear_Seasonal(seasonal_init)
68 | trend_output = self.Linear_Trend(trend_init)
69 |
70 | x = seasonal_output + trend_output
71 |
72 | return x.permute(0, 2, 1) # to [Batch, Output length, Channel]
--------------------------------------------------------------------------------
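The Gaussian-softmax initialization above, reproduced in isolation (a sketch, not part of the repo's API): with sigma = 1, the weights peak at the center tap and sum to 1, so LD starts out as a smooth moving average that training can then reshape.

import math
import torch
import torch.nn.functional as F

kernel_size = 25
half = kernel_size // 2
w = torch.tensor([math.exp(-((i - half) / 2.0) ** 2) for i in range(kernel_size)])
w = F.softmax(w, dim=-1)
print(round(w.sum().item(), 6))   # 1.0 -- a proper weighted average
print(w.argmax().item())          # 12 -- the center tap dominates
--------------------------------------------------------------------------------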
/models/Linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class Model(nn.Module):
7 | """
8 | Just one Linear layer
9 | """
10 | def __init__(self, configs):
11 | super(Model, self).__init__()
12 | self.seq_len = configs.seq_len
13 | self.pred_len = configs.pred_len
14 | self.Linear = nn.Linear(self.seq_len, self.pred_len)
15 | # Use this line if you want to visualize the weights
16 | # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
17 |
18 | def forward(self, x):
19 | # x: [Batch, Input length, Channel]
20 | x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
21 | return x # [Batch, Output length, Channel]
--------------------------------------------------------------------------------
/models/NLinear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class Model(nn.Module):
7 | """
8 | Normalization-Linear
9 | """
10 | def __init__(self, configs):
11 | super(Model, self).__init__()
12 | self.seq_len = configs.seq_len
13 | self.pred_len = configs.pred_len
14 | self.Linear = nn.Linear(self.seq_len, self.pred_len)
15 | # Use this line if you want to visualize the weights
16 | # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
17 |
18 | def forward(self, x):
19 | # x: [Batch, Input length, Channel]
20 | seq_last = x[:,-1:,:].detach()
21 | x = x - seq_last
22 | x = self.Linear(x.permute(0,2,1)).permute(0,2,1)
23 | x = x + seq_last
24 | return x # [Batch, Output length, Channel]
--------------------------------------------------------------------------------
/models/PatchTST.py:
--------------------------------------------------------------------------------
1 | __all__ = ['PatchTST']
2 |
3 | # Cell
4 | from typing import Callable, Optional
5 | import torch
6 | from torch import nn
7 | from torch import Tensor
8 | import torch.nn.functional as F
9 | import numpy as np
10 |
11 | from layers.PatchTST_backbone import PatchTST_backbone
12 | from layers.PatchTST_layers import series_decomp
13 |
14 |
15 | class Model(nn.Module):
16 | def __init__(self, configs, max_seq_len:Optional[int]=1024, d_k:Optional[int]=None, d_v:Optional[int]=None, norm:str='BatchNorm', attn_dropout:float=0.,
17 | act:str="gelu", key_padding_mask:bool='auto',padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True,
18 | pre_norm:bool=False, store_attn:bool=False, pe:str='zeros', learn_pe:bool=True, pretrain_head:bool=False, head_type = 'flatten', verbose:bool=False, **kwargs):
19 |
20 | super().__init__()
21 |
22 | # load parameters
23 | c_in = configs.enc_in
24 | context_window = configs.seq_len
25 | target_window = configs.pred_len
26 |
27 | n_layers = configs.e_layers
28 | n_heads = configs.n_heads
29 | d_model = configs.d_model
30 | d_ff = configs.d_ff
31 | dropout = configs.dropout
32 | fc_dropout = configs.fc_dropout
33 | head_dropout = configs.head_dropout
34 |
35 | individual = configs.individual
36 |
37 | patch_len = configs.patch_len
38 | stride = configs.stride
39 | padding_patch = configs.padding_patch
40 |
41 | revin = configs.revin
42 | affine = configs.affine
43 | subtract_last = configs.subtract_last
44 |
45 | decomposition = configs.decomposition
46 | kernel_size = configs.kernel_size
47 |
48 |
49 | # model
50 | self.decomposition = decomposition
51 | if self.decomposition:
52 | self.decomp_module = series_decomp(kernel_size)
53 | self.model_trend = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride,
54 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model,
55 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
56 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
57 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
58 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch,
59 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine,
60 | subtract_last=subtract_last, verbose=verbose, **kwargs)
61 | self.model_res = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride,
62 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model,
63 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
64 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
65 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
66 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch,
67 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine,
68 | subtract_last=subtract_last, verbose=verbose, **kwargs)
69 | else:
70 | self.model = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride,
71 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model,
72 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout,
73 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var,
74 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn,
75 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch,
76 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine,
77 | subtract_last=subtract_last, verbose=verbose, **kwargs)
78 |
79 |
80 | def forward(self, x): # x: [Batch, Input length, Channel]
81 | if self.decomposition:
82 | res_init, trend_init = self.decomp_module(x)
83 | res_init, trend_init = res_init.permute(0,2,1), trend_init.permute(0,2,1) # x: [Batch, Channel, Input length]
84 | res = self.model_res(res_init)
85 | trend = self.model_trend(trend_init)
86 | x = res + trend
87 | x = x.permute(0,2,1) # x: [Batch, Input length, Channel]
88 | else:
89 | x = x.permute(0,2,1) # x: [Batch, Channel, Input length]
90 | x = self.model(x)
91 | x = x.permute(0,2,1) # x: [Batch, Input length, Channel]
92 | return x
--------------------------------------------------------------------------------
/models/RLinear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class Model(nn.Module):
7 | def __init__(self, configs):
8 | super(Model, self).__init__()
9 | self.seq_len = configs.seq_len
10 | self.pred_len = configs.pred_len
11 | self.Linear = nn.Linear(self.seq_len, self.pred_len)
12 |
13 | def forward(self, x):
14 | # x: [Batch, Input length, Channel]
15 | seq_mean = torch.mean(x, dim=1, keepdim=True)
16 | seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
17 | x = (x - seq_mean) / torch.sqrt(seq_var)
18 |
19 | y = self.Linear(x.permute(0,2,1)).permute(0,2,1)
20 |
21 | y = y * torch.sqrt(seq_var) + seq_mean
22 | return y # [Batch, Output length, Channel]
--------------------------------------------------------------------------------
/models/RMLP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class Model(nn.Module):
7 | def __init__(self, configs):
8 | super(Model, self).__init__()
9 | self.seq_len = configs.seq_len
10 | self.pred_len = configs.pred_len
11 | self.d_model = configs.d_model
12 | self.MLP = nn.Sequential(
13 | nn.Linear(self.seq_len, self.d_model),
14 | nn.ReLU(),
15 | nn.Linear(self.d_model, self.pred_len)
16 | )
17 |
18 | def forward(self, x):
19 | # x: [Batch, Input length, Channel]
20 | seq_mean = torch.mean(x, dim=1, keepdim=True)
21 | seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
22 | x = (x - seq_mean) / torch.sqrt(seq_var)
23 |
24 | y = self.MLP(x.permute(0,2,1)).permute(0,2,1)
25 |
26 | y = y * torch.sqrt(seq_var) + seq_mean
27 | return y # [Batch, Output length, Channel]
--------------------------------------------------------------------------------
/models/SegRNN.py:
--------------------------------------------------------------------------------
1 | '''
2 | A complete implementation containing all code (including ablation components)
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from layers.RevIN import RevIN
8 |
9 | class Model(nn.Module):
10 | def __init__(self, configs):
11 | super(Model, self).__init__()
12 |
13 | # get parameters
14 | self.seq_len = configs.seq_len
15 | self.pred_len = configs.pred_len
16 | self.enc_in = configs.enc_in
17 | self.d_model = configs.d_model
18 | self.dropout = configs.dropout
19 |
20 | self.rnn_type = configs.rnn_type
21 | self.dec_way = configs.dec_way
22 | self.seg_len = configs.seg_len
23 | self.channel_id = configs.channel_id
24 | self.revin = configs.revin
25 |
26 | assert self.rnn_type in ['rnn', 'gru', 'lstm']
27 | assert self.dec_way in ['rmf', 'pmf']
28 |
29 | self.seg_num_x = self.seq_len//self.seg_len
30 |
31 | # build model
32 | self.valueEmbedding = nn.Sequential(
33 | nn.Linear(self.seg_len, self.d_model),
34 | nn.ReLU()
35 | )
36 |
37 | if self.rnn_type == "rnn":
38 | self.rnn = nn.RNN(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
39 | batch_first=True, bidirectional=False)
40 | elif self.rnn_type == "gru":
41 | self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
42 | batch_first=True, bidirectional=False)
43 | elif self.rnn_type == "lstm":
44 | self.rnn = nn.LSTM(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
45 | batch_first=True, bidirectional=False)
46 |
47 | if self.dec_way == "rmf":
48 | self.seg_num_y = self.pred_len // self.seg_len
49 | self.predict = nn.Sequential(
50 | nn.Dropout(self.dropout),
51 | nn.Linear(self.d_model, self.seg_len)
52 | )
53 | elif self.dec_way == "pmf":
54 | self.seg_num_y = self.pred_len // self.seg_len
55 |
56 | if self.channel_id:
57 | self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
58 | self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))
59 | else:
60 | self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model))
61 |
62 | self.predict = nn.Sequential(
63 | nn.Dropout(self.dropout),
64 | nn.Linear(self.d_model, self.seg_len)
65 | )
66 | if self.revin:
67 | self.revinLayer = RevIN(self.enc_in, affine=False, subtract_last=False)
68 |
69 |
70 | def forward(self, x):
71 |
72 |         # b:batch_size c:channel_size s:seq_len
73 | # d:d_model w:seg_len n:seg_num_x m:seg_num_y
74 | batch_size = x.size(0)
75 |
76 | # normalization and permute b,s,c -> b,c,s
77 | if self.revin:
78 | x = self.revinLayer(x, 'norm').permute(0, 2, 1)
79 | else:
80 | seq_last = x[:, -1:, :].detach()
81 | x = (x - seq_last).permute(0, 2, 1) # b,c,s
82 |
83 | # segment and embedding b,c,s -> bc,n,w -> bc,n,d
84 | x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len))
85 |
86 | # encoding
87 | if self.rnn_type == "lstm":
88 | _, (hn, cn) = self.rnn(x)
89 | else:
90 | _, hn = self.rnn(x) # bc,n,d 1,bc,d
91 |
92 | # decoding
93 | if self.dec_way == "rmf":
94 | y = []
95 | for i in range(self.seg_num_y):
96 | yy = self.predict(hn) # 1,bc,l
97 | yy = yy.permute(1,0,2) # bc,1,l
98 | y.append(yy)
99 | yy = self.valueEmbedding(yy)
100 | if self.rnn_type == "lstm":
101 | _, (hn, cn) = self.rnn(yy, (hn, cn))
102 | else:
103 | _, hn = self.rnn(yy, hn)
104 | y = torch.stack(y, dim=1).squeeze(2).reshape(-1, self.enc_in, self.pred_len) # b,c,s
105 | elif self.dec_way == "pmf":
106 | if self.channel_id:
107 | # m,d//2 -> 1,m,d//2 -> c,m,d//2
108 | # c,d//2 -> c,1,d//2 -> c,m,d//2
109 | # c,m,d -> cm,1,d -> bcm, 1, d
110 | pos_emb = torch.cat([
111 | self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
112 | self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
113 | ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)
114 | else:
115 | # m,d -> bcm,d -> bcm, 1, d
116 | pos_emb = self.pos_emb.repeat(batch_size * self.enc_in, 1).unsqueeze(1)
117 |
118 | # pos_emb: m,d -> bcm,d -> bcm,1,d
119 | # hn, cn: 1,bc,d -> 1,bc,md -> 1,bcm,d
120 | if self.rnn_type == "lstm":
121 | _, (hy, cy) = self.rnn(pos_emb,
122 | (hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model),
123 | cn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model)))
124 | else:
125 | _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model))
126 | # 1,bcm,d -> 1,bcm,w -> b,c,s
127 | y = self.predict(hy).view(-1, self.enc_in, self.pred_len)
128 |
129 | # permute and denorm
130 | if self.revin:
131 | y = self.revinLayer(y.permute(0, 2, 1), 'denorm')
132 | else:
133 | y = y.permute(0, 2, 1) + seq_last
134 |
135 | return y
136 |
137 | '''
138 | A concise implementation that includes only the necessary code
139 | '''
140 | # import torch
141 | # import torch.nn as nn
142 | #
143 | # class Model(nn.Module):
144 | # def __init__(self, configs):
145 | # super(Model, self).__init__()
146 | #
147 | # # get parameters
148 | # self.seq_len = configs.seq_len
149 | # self.pred_len = configs.pred_len
150 | # self.enc_in = configs.enc_in
151 | # self.d_model = configs.d_model
152 | # self.dropout = configs.dropout
153 | #
154 | # self.seg_len = configs.seg_len
155 | # self.seg_num_x = self.seq_len//self.seg_len
156 | # self.seg_num_y = self.pred_len // self.seg_len
157 | #
158 | #
159 | # self.valueEmbedding = nn.Sequential(
160 | # nn.Linear(self.seg_len, self.d_model),
161 | # nn.ReLU()
162 | # )
163 | # self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
164 | # batch_first=True, bidirectional=False)
165 | # self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
166 | # self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))
167 | # self.predict = nn.Sequential(
168 | # nn.Dropout(self.dropout),
169 | # nn.Linear(self.d_model, self.seg_len)
170 | # )
171 | #
172 | # def forward(self, x):
173 | #         # b:batch_size c:channel_size s:seq_len
174 | # # d:d_model w:seg_len n:seg_num_x m:seg_num_y
175 | # batch_size = x.size(0)
176 | #
177 | # # normalization and permute b,s,c -> b,c,s
178 | # seq_last = x[:, -1:, :].detach()
179 | # x = (x - seq_last).permute(0, 2, 1) # b,c,s
180 | #
181 | # # segment and embedding b,c,s -> bc,n,w -> bc,n,d
182 | # x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len))
183 | #
184 | # # encoding
185 | # _, hn = self.rnn(x) # bc,n,d 1,bc,d
186 | #
187 | # # m,d//2 -> 1,m,d//2 -> c,m,d//2
188 | # # c,d//2 -> c,1,d//2 -> c,m,d//2
189 | # # c,m,d -> cm,1,d -> bcm, 1, d
190 | # pos_emb = torch.cat([
191 | # self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
192 | # self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
193 | # ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)
194 | #
195 | # _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model)) # bcm,1,d 1,bcm,d
196 | #
197 | # # 1,bcm,d -> 1,bcm,w -> b,c,s
198 | # y = self.predict(hy).view(-1, self.enc_in, self.pred_len)
199 | #
200 | # # permute and denorm
201 | # y = y.permute(0, 2, 1) + seq_last
202 | #
203 | # return y
--------------------------------------------------------------------------------
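A worked shape trace of the "pmf" token assembly above (toy sizes; a sketch): every (channel, output-segment) pair gets one d_model-dim query token built from half positional, half channel embedding.

import torch

b, c, m, d = 2, 7, 4, 16                        # batch, channels, seg_num_y, d_model
pos_emb = torch.randn(m, d // 2)
channel_emb = torch.randn(c, d // 2)
tokens = torch.cat([
    pos_emb.unsqueeze(0).repeat(c, 1, 1),       # (c, m, d//2)
    channel_emb.unsqueeze(1).repeat(1, m, 1),   # (c, m, d//2)
], dim=-1).view(-1, 1, d).repeat(b, 1, 1)       # (b*c*m, 1, d)
print(tokens.shape)                             # torch.Size([56, 1, 16])
--------------------------------------------------------------------------------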
/models/SparseTSF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Model(nn.Module):
6 | def __init__(self, configs):
7 | super(Model, self).__init__()
8 |
9 | # get parameters
10 | self.seq_len = configs.seq_len
11 | self.pred_len = configs.pred_len
12 | self.enc_in = configs.enc_in
13 |
14 | self.use_revin = configs.use_revin
15 |
16 | self.period_len = configs.period_len
17 |
18 | self.seg_num_x = self.seq_len // self.period_len
19 | self.seg_num_y = self.pred_len // self.period_len
20 |
21 | self.conv1d = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1 + 2 * self.period_len // 2,
22 | stride=1, padding=self.period_len // 2, padding_mode="zeros", bias=False)
23 |
24 | self.linear = nn.Linear(self.seg_num_x, self.seg_num_y, bias=False)
25 |
26 |
27 | def forward(self, x):
28 | batch_size = x.shape[0]
29 |
30 | # normalization and permute b,s,c -> b,c,s
31 | if self.use_revin:
32 | seq_mean = torch.mean(x, dim=1).unsqueeze(1)
33 | x = (x - seq_mean).permute(0, 2, 1)
34 | else:
35 | x = x.permute(0, 2, 1)
36 |
37 | x = self.conv1d(x.reshape(-1, 1, self.seq_len)).reshape(-1, self.enc_in, self.seq_len) + x
38 |
39 | # b,c,s -> bc,n,w -> bc,w,n
40 | x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)
41 |
42 | y = self.linear(x) # bc,w,m
43 |
44 | # bc,w,m -> bc,m,w -> b,c,s
45 | y = y.permute(0, 2, 1).reshape(batch_size, self.enc_in, self.pred_len)
46 |
47 | if self.use_revin:
48 | y = y.permute(0, 2, 1) + seq_mean
49 | else:
50 | y = y.permute(0, 2, 1)
51 |
52 | return y
--------------------------------------------------------------------------------
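A shape trace of the cross-period mapping above (toy sizes; a sketch): after the reshape and permute, the single linear layer mixes only values that share the same phase within the period, mapping seg_num_x past periods to seg_num_y future ones.

import torch
import torch.nn as nn

b, c, seq_len, pred_len, w = 2, 7, 96, 48, 24   # w = period_len
n, m = seq_len // w, pred_len // w              # seg_num_x = 4, seg_num_y = 2
x = torch.randn(b, c, seq_len)
x = x.reshape(-1, n, w).permute(0, 2, 1)        # (b*c, w, n): one row per phase
y = nn.Linear(n, m, bias=False)(x)              # (b*c, w, m)
y = y.permute(0, 2, 1).reshape(b, c, pred_len)  # back to (b, c, pred_len)
print(y.shape)                                  # torch.Size([2, 7, 48])
--------------------------------------------------------------------------------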
/models/Transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer
6 | from layers.Embed import DataEmbedding,DataEmbedding_wo_pos,DataEmbedding_wo_temp,DataEmbedding_wo_pos_temp
7 | import numpy as np
8 |
9 |
10 | class Model(nn.Module):
11 | """
12 | Vanilla Transformer with O(L^2) complexity
13 | """
14 | def __init__(self, configs):
15 | super(Model, self).__init__()
16 | self.pred_len = configs.pred_len
17 | self.output_attention = configs.output_attention
18 |
19 | # Embedding
20 | if configs.embed_type == 0:
21 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
22 | configs.dropout)
23 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
24 | configs.dropout)
25 | elif configs.embed_type == 1:
26 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
27 | configs.dropout)
28 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
29 | configs.dropout)
30 | elif configs.embed_type == 2:
31 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
32 | configs.dropout)
33 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq,
34 | configs.dropout)
35 |
36 | elif configs.embed_type == 3:
37 | self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
38 | configs.dropout)
39 | self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
40 | configs.dropout)
41 | elif configs.embed_type == 4:
42 | self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
43 | configs.dropout)
44 | self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
45 | configs.dropout)
46 | # Encoder
47 | self.encoder = Encoder(
48 | [
49 | EncoderLayer(
50 | AttentionLayer(
51 | FullAttention(False, configs.factor, attention_dropout=configs.dropout,
52 | output_attention=configs.output_attention), configs.d_model, configs.n_heads),
53 | configs.d_model,
54 | configs.d_ff,
55 | dropout=configs.dropout,
56 | activation=configs.activation
57 | ) for l in range(configs.e_layers)
58 | ],
59 | norm_layer=torch.nn.LayerNorm(configs.d_model)
60 | )
61 | # Decoder
62 | self.decoder = Decoder(
63 | [
64 | DecoderLayer(
65 | AttentionLayer(
66 | FullAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False),
67 | configs.d_model, configs.n_heads),
68 | AttentionLayer(
69 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
70 | configs.d_model, configs.n_heads),
71 | configs.d_model,
72 | configs.d_ff,
73 | dropout=configs.dropout,
74 | activation=configs.activation,
75 | )
76 | for l in range(configs.d_layers)
77 | ],
78 | norm_layer=torch.nn.LayerNorm(configs.d_model),
79 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
80 | )
81 |
82 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
83 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
84 |
85 | enc_out = self.enc_embedding(x_enc, x_mark_enc)
86 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
87 |
88 | dec_out = self.dec_embedding(x_dec, x_mark_dec)
89 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
90 |
91 | if self.output_attention:
92 | return dec_out[:, -self.pred_len:, :], attns
93 | else:
94 | return dec_out[:, -self.pred_len:, :] # [B, L, D]
95 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | matplotlib
3 | pandas
4 | scikit-learn
5 | torch
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from exp.exp_main import Exp_Main
5 | import random
6 | import numpy as np
7 |
8 | parser = argparse.ArgumentParser(description='Model family for Time Series Forecasting')
9 |
10 | # random seed
11 | parser.add_argument('--random_seed', type=int, default=2024, help='random seed')
12 |
13 | # basic config
14 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
15 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
16 | parser.add_argument('--model', type=str, required=True, default='Autoformer',
17 | help='model name, options: [Autoformer, Informer, Transformer]')
18 |
19 | # data loader
20 | parser.add_argument('--data', type=str, required=True, default='ETTh1', help='dataset type')
21 | parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
22 | parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
23 | parser.add_argument('--features', type=str, default='M',
24 | help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
25 | parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
26 | parser.add_argument('--freq', type=str, default='h',
27 | help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
28 | parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
29 |
30 | # forecasting task
31 | parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
32 | parser.add_argument('--label_len', type=int, default=0, help='start token length') #fixed
33 | parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
34 |
35 | # CycleNet.
36 | parser.add_argument('--cycle', type=int, default=24, help='cycle length')
37 | parser.add_argument('--model_type', type=str, default='mlp', help='model type, options: [linear, mlp]')
38 | parser.add_argument('--use_revin', type=int, default=1, help='1: use revin or 0: no revin')
39 |
40 | # DLinear
41 | #parser.add_argument('--individual', action='store_true', default=False, help='DLinear: a linear layer for each variate(channel) individually')
42 |
43 | # PatchTST
44 | parser.add_argument('--fc_dropout', type=float, default=0.05, help='fully connected dropout')
45 | parser.add_argument('--head_dropout', type=float, default=0.0, help='head dropout')
46 | parser.add_argument('--patch_len', type=int, default=16, help='patch length')
47 | parser.add_argument('--stride', type=int, default=8, help='stride')
48 | parser.add_argument('--padding_patch', default='end', help='None: None; end: padding on the end')
49 | parser.add_argument('--revin', type=int, default=0, help='RevIN; True 1 False 0')
50 | parser.add_argument('--affine', type=int, default=0, help='RevIN-affine; True 1 False 0')
51 | parser.add_argument('--subtract_last', type=int, default=0, help='0: subtract mean; 1: subtract last')
52 | parser.add_argument('--decomposition', type=int, default=0, help='decomposition; True 1 False 0')
53 | parser.add_argument('--kernel_size', type=int, default=25, help='decomposition-kernel')
54 | parser.add_argument('--individual', type=int, default=0, help='individual head; True 1 False 0')
55 |
56 | # SegRNN
57 | parser.add_argument('--rnn_type', default='gru', help='rnn_type')
58 | parser.add_argument('--dec_way', default='pmf', help='decode way')
59 | parser.add_argument('--seg_len', type=int, default=48, help='segment length')
60 | parser.add_argument('--channel_id', type=int, default=1, help='Whether to enable channel position encoding')
61 |
62 | # SparseTSF
63 | parser.add_argument('--period_len', type=int, default=24, help='period_len')
64 |
65 | # Formers
66 | parser.add_argument('--embed_type', type=int, default=0, help='0: default 1: value embedding + temporal embedding + positional embedding 2: value embedding + temporal embedding 3: value embedding + positional embedding 4: value embedding')
67 | parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') # DLinear with --individual, use this hyperparameter as the number of channels
68 | parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
69 | parser.add_argument('--c_out', type=int, default=7, help='output size')
70 | parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
71 | parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
72 | parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
73 | parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
74 | parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
75 | parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
76 | parser.add_argument('--factor', type=int, default=1, help='attn factor')
77 | parser.add_argument('--distil', action='store_false',
78 | help='whether to use distilling in encoder, using this argument means not using distilling',
79 | default=True)
80 | parser.add_argument('--dropout', type=float, default=0, help='dropout')
81 | parser.add_argument('--embed', type=str, default='timeF',
82 | help='time features encoding, options:[timeF, fixed, learned]')
83 | parser.add_argument('--activation', type=str, default='gelu', help='activation')
84 | parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
85 | parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
86 |
87 | # optimization
88 | parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
89 | parser.add_argument('--itr', type=int, default=1, help='experiments times')
90 | parser.add_argument('--train_epochs', type=int, default=30, help='train epochs')
91 | parser.add_argument('--batch_size', type=int, default=128, help='batch size of train input data')
92 | parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
93 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
94 | parser.add_argument('--des', type=str, default='test', help='exp description')
95 | parser.add_argument('--loss', type=str, default='mse', help='loss function')
96 | parser.add_argument('--lradj', type=str, default='type3', help='adjust learning rate')
97 | parser.add_argument('--pct_start', type=float, default=0.3, help='pct_start')
98 | parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
99 |
100 | # GPU
101 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
102 | parser.add_argument('--gpu', type=int, default=0, help='gpu')
103 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
104 | parser.add_argument('--devices', type=str, default='0,1', help='device ids of multiple gpus')
105 | parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')
106 |
107 | args = parser.parse_args()
108 |
109 | # random seed
110 | fix_seed = args.random_seed
111 | random.seed(fix_seed)
112 | torch.manual_seed(fix_seed)
113 | np.random.seed(fix_seed)
114 |
115 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
116 |
117 | if args.use_gpu and args.use_multi_gpu:
118 | args.devices = args.devices.replace(' ', '')
119 | device_ids = args.devices.split(',')
120 | args.device_ids = [int(id_) for id_ in device_ids]
121 | args.gpu = args.device_ids[0]
122 |
123 | print('Args in experiment:')
124 | print(args)
125 |
126 | Exp = Exp_Main
127 |
128 | if args.is_training:
129 | for ii in range(args.itr):
130 |
131 | # setting record of experiments
132 | setting = '{}_{}_{}_ft{}_sl{}_pl{}_cycle{}_{}_seed{}'.format(
133 | args.model_id,
134 | args.model,
135 | args.data,
136 | args.features,
137 | args.seq_len,
138 | args.pred_len,
139 | args.cycle,
140 | args.model_type,
141 | fix_seed)
142 |
143 | exp = Exp(args) # set experiments
144 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
145 | exp.train(setting)
146 |
147 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
148 | exp.test(setting)
149 |
150 | if args.do_predict:
151 | print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
152 | exp.predict(setting, True)
153 |
154 | torch.cuda.empty_cache()
155 | else:
156 | ii = 0
157 | setting = '{}_{}_{}_ft{}_sl{}_pl{}_cycle{}_{}_seed{}'.format(
158 | args.model_id,
159 | args.model,
160 | args.data,
161 | args.features,
162 | args.seq_len,
163 | args.pred_len,
164 | args.cycle,
165 | args.model_type,
166 | fix_seed)
167 |
168 | exp = Exp(args) # set experiments
169 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
170 | exp.test(setting, test=1)
171 | torch.cuda.empty_cache()
172 |
--------------------------------------------------------------------------------
/run_main.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sh scripts/CycleNet/Linear-Input-96/etth1.sh;
4 | sh scripts/CycleNet/Linear-Input-96/etth2.sh;
5 | sh scripts/CycleNet/Linear-Input-96/ettm1.sh;
6 | sh scripts/CycleNet/Linear-Input-96/ettm2.sh;
7 | sh scripts/CycleNet/Linear-Input-96/weather.sh;
8 | sh scripts/CycleNet/Linear-Input-96/electricity.sh;
9 | sh scripts/CycleNet/Linear-Input-96/traffic.sh;
10 | sh scripts/CycleNet/Linear-Input-96/solar.sh;
11 |
12 | sh scripts/CycleNet/Linear-Input-336/etth1.sh;
13 | sh scripts/CycleNet/Linear-Input-336/etth2.sh;
14 | sh scripts/CycleNet/Linear-Input-336/ettm1.sh;
15 | sh scripts/CycleNet/Linear-Input-336/ettm2.sh;
16 | sh scripts/CycleNet/Linear-Input-336/weather.sh;
17 | sh scripts/CycleNet/Linear-Input-336/electricity.sh;
18 | sh scripts/CycleNet/Linear-Input-336/traffic.sh;
19 | sh scripts/CycleNet/Linear-Input-336/solar.sh;
20 |
21 | sh scripts/CycleNet/Linear-Input-720/etth1.sh;
22 | sh scripts/CycleNet/Linear-Input-720/etth2.sh;
23 | sh scripts/CycleNet/Linear-Input-720/ettm1.sh;
24 | sh scripts/CycleNet/Linear-Input-720/ettm2.sh;
25 | sh scripts/CycleNet/Linear-Input-720/weather.sh;
26 | sh scripts/CycleNet/Linear-Input-720/electricity.sh;
27 | sh scripts/CycleNet/Linear-Input-720/traffic.sh;
28 | sh scripts/CycleNet/Linear-Input-720/solar.sh;
29 |
30 | sh scripts/CycleNet/MLP-Input-96/etth1.sh;
31 | sh scripts/CycleNet/MLP-Input-96/etth2.sh;
32 | sh scripts/CycleNet/MLP-Input-96/ettm1.sh;
33 | sh scripts/CycleNet/MLP-Input-96/ettm2.sh;
34 | sh scripts/CycleNet/MLP-Input-96/weather.sh;
35 | sh scripts/CycleNet/MLP-Input-96/electricity.sh;
36 | sh scripts/CycleNet/MLP-Input-96/traffic.sh;
37 | sh scripts/CycleNet/MLP-Input-96/solar.sh;
38 |
39 | sh scripts/CycleNet/MLP-Input-336/etth1.sh;
40 | sh scripts/CycleNet/MLP-Input-336/etth2.sh;
41 | sh scripts/CycleNet/MLP-Input-336/ettm1.sh;
42 | sh scripts/CycleNet/MLP-Input-336/ettm2.sh;
43 | sh scripts/CycleNet/MLP-Input-336/weather.sh;
44 | sh scripts/CycleNet/MLP-Input-336/electricity.sh;
45 | sh scripts/CycleNet/MLP-Input-336/traffic.sh;
46 | sh scripts/CycleNet/MLP-Input-336/solar.sh;
47 |
48 | sh scripts/CycleNet/MLP-Input-720/etth1.sh;
49 | sh scripts/CycleNet/MLP-Input-720/etth2.sh;
50 | sh scripts/CycleNet/MLP-Input-720/ettm1.sh;
51 | sh scripts/CycleNet/MLP-Input-720/ettm2.sh;
52 | sh scripts/CycleNet/MLP-Input-720/weather.sh;
53 | sh scripts/CycleNet/MLP-Input-720/electricity.sh;
54 | sh scripts/CycleNet/MLP-Input-720/traffic.sh;
55 | sh scripts/CycleNet/MLP-Input-720/solar.sh;
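56 |
57 | # Note: these scripts assume execution from the repository root (they use
58 | # relative paths such as scripts/CycleNet/... and ./dataset/), with the
59 | # benchmark data files placed under ./dataset/ beforehand.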
--------------------------------------------------------------------------------
/run_pems.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sh scripts/CycleNet/PEMS/pems03.sh;
4 | sh scripts/CycleNet/PEMS/pems04.sh;
5 | sh scripts/CycleNet/PEMS/pems07.sh;
6 | sh scripts/CycleNet/PEMS/pems08.sh;
7 |
8 |
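9 | # Runs the CycleNet (MLP backbone) configuration on the four PEMS datasets
10 | # (PEMS03/04/07/08); see scripts/CycleNet/PEMS/ for the per-dataset settings.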
--------------------------------------------------------------------------------
/run_std.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sh scripts/CycleNet/STD/CycleNet.sh;
4 | sh scripts/CycleNet/STD/LDLinear.sh;
5 | sh scripts/CycleNet/STD/DLinear.sh;
6 | sh scripts/CycleNet/STD/SparseTSF.sh;
7 | sh scripts/CycleNet/STD/Linear.sh;
8 |
9 |
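10 | # Benchmarks CycleNet against the DLinear, LDLinear, SparseTSF, and Linear
11 | # baselines under a shared setup; see scripts/CycleNet/STD/ for details.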
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/electricity.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=electricity.csv
5 | model_id_name=Electricity
6 | data_name=custom
7 |
8 |
9 | model_type='linear'
10 | seq_len=336
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 321 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
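36 | # cycle=168: weekly periodicity of the hourly electricity data (24 * 7);
37 | # the hourly traffic data uses the same cycle length.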
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/etth1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh1.csv
5 | model_id_name=ETTh1
6 | data_name=ETTh1
7 |
8 |
9 | model_type='linear'
10 | seq_len=336
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 7 \
26 | --cycle 24 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
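36 | # cycle=24: daily periodicity of the hourly ETTh data.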
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/etth2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh2.csv
5 | model_id_name=ETTh2
6 | data_name=ETTh2
7 |
8 | model_type='linear'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 24 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/ettm1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm1.csv
5 | model_id_name=ETTm1
6 | data_name=ETTm1
7 |
8 | model_type='linear'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
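34 | # cycle=96: daily periodicity of the 15-minute ETTm data (4 * 24).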
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/ettm2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm2.csv
5 | model_id_name=ETTm2
6 | data_name=ETTm2
7 |
8 | model_type='linear'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/solar.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=solar_AL.txt
5 | model_id_name=Solar
6 | data_name=Solar
7 |
8 | model_type='linear'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 137 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --use_revin 0 \
30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
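35 | # cycle=144: daily periodicity of the 10-minute Solar data (6 * 24); the
36 | # 10-minute weather data uses the same cycle. --use_revin 0 disables the
37 | # RevIN-style instance normalization, which these scripts turn off for Solar.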
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=traffic.csv
5 | model_id_name=traffic
6 | data_name=custom
7 |
8 |
9 | model_type='linear'
10 | seq_len=336
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 862 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-336/weather.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=weather.csv
5 | model_id_name=weather
6 | data_name=custom
7 |
8 | model_type='linear'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 21 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/electricity.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=electricity.csv
5 | model_id_name=Electricity
6 | data_name=custom
7 |
8 |
9 | model_type='linear'
10 | seq_len=720
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 321 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/etth1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh1.csv
5 | model_id_name=ETTh1
6 | data_name=ETTh1
7 |
8 |
9 | model_type='linear'
10 | seq_len=720
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 7 \
26 | --cycle 24 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/etth2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh2.csv
5 | model_id_name=ETTh2
6 | data_name=ETTh2
7 |
8 | model_type='linear'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 24 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/ettm1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm1.csv
5 | model_id_name=ETTm1
6 | data_name=ETTm1
7 |
8 | model_type='linear'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/ettm2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm2.csv
5 | model_id_name=ETTm2
6 | data_name=ETTm2
7 |
8 | model_type='linear'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/solar.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=solar_AL.txt
5 | model_id_name=Solar
6 | data_name=Solar
7 |
8 | model_type='linear'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 137 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --use_revin 0 \
30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=traffic.csv
5 | model_id_name=traffic
6 | data_name=custom
7 |
8 |
9 | model_type='linear'
10 | seq_len=720
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 862 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-720/weather.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=weather.csv
5 | model_id_name=weather
6 | data_name=custom
7 |
8 | model_type='linear'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 21 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/electricity.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=electricity.csv
5 | model_id_name=Electricity
6 | data_name=custom
7 |
8 |
9 | model_type='linear'
10 | seq_len=96
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 321 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/etth1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh1.csv
5 | model_id_name=ETTh1
6 | data_name=ETTh1
7 |
8 |
9 | model_type='linear'
10 | seq_len=96
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 7 \
26 | --cycle 24 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/etth2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh2.csv
5 | model_id_name=ETTh2
6 | data_name=ETTh2
7 |
8 | model_type='linear'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 24 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/ettm1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm1.csv
5 | model_id_name=ETTm1
6 | data_name=ETTm1
7 |
8 | model_type='linear'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/ettm2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm2.csv
5 | model_id_name=ETTm2
6 | data_name=ETTm2
7 |
8 | model_type='linear'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/solar.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=solar_AL.txt
5 | model_id_name=Solar
6 | data_name=Solar
7 |
8 | model_type='linear'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 137 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --use_revin 0 \
30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=traffic.csv
5 | model_id_name=traffic
6 | data_name=custom
7 |
8 |
9 | model_type='linear'
10 | seq_len=96
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 862 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
31 | done
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/Linear-Input-96/weather.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=weather.csv
5 | model_id_name=weather
6 | data_name=custom
7 |
8 | model_type='linear'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 21 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/electricity.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=electricity.csv
5 | model_id_name=Electricity
6 | data_name=custom
7 |
8 |
9 | model_type='mlp'
10 | seq_len=336
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 321 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
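36 | # model_type='mlp' selects the MLP variant of the CycleNet backbone instead of
37 | # the linear one; the MLP scripts accordingly use smaller learning rates than
38 | # their linear counterparts.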
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/etth1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh1.csv
5 | model_id_name=ETTh1
6 | data_name=ETTh1
7 |
8 |
9 | model_type='mlp'
10 | seq_len=336
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 7 \
26 | --cycle 24 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/etth2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh2.csv
5 | model_id_name=ETTh2
6 | data_name=ETTh2
7 |
8 | model_type='mlp'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 24 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/ettm1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm1.csv
5 | model_id_name=ETTm1
6 | data_name=ETTm1
7 |
8 | model_type='mlp'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/ettm2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm2.csv
5 | model_id_name=ETTm2
6 | data_name=ETTm2
7 |
8 | model_type='mlp'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/solar.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=solar_AL.txt
5 | model_id_name=Solar
6 | data_name=Solar
7 |
8 | model_type='mlp'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 137 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --use_revin 0 \
30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=traffic.csv
5 | model_id_name=traffic
6 | data_name=custom
7 |
8 |
9 | model_type='mlp'
10 | seq_len=336
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 862 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed
31 | done
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-336/weather.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=weather.csv
5 | model_id_name=weather
6 | data_name=custom
7 |
8 | model_type='mlp'
9 | seq_len=336
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 21 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/electricity.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=electricity.csv
5 | model_id_name=Electricity
6 | data_name=custom
7 |
8 |
9 | model_type='mlp'
10 | seq_len=720
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 321 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.001 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/etth1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh1.csv
5 | model_id_name=ETTh1
6 | data_name=ETTh1
7 |
8 |
9 | model_type='mlp'
10 | seq_len=720
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 7 \
26 | --cycle 24 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/etth2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh2.csv
5 | model_id_name=ETTh2
6 | data_name=ETTh2
7 |
8 | model_type='mlp'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 24 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/ettm1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm1.csv
5 | model_id_name=ETTm1
6 | data_name=ETTm1
7 |
8 | model_type='mlp'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/ettm2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm2.csv
5 | model_id_name=ETTm2
6 | data_name=ETTm2
7 |
8 | model_type='mlp'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/solar.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=solar_AL.txt
5 | model_id_name=Solar
6 | data_name=Solar
7 |
8 | model_type='mlp'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 137 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --use_revin 0 \
30 | --itr 1 --batch_size 64 --learning_rate 0.0002 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=traffic.csv
5 | model_id_name=traffic
6 | data_name=custom
7 |
8 |
9 | model_type='mlp'
10 | seq_len=720
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 862 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.0005 --random_seed $random_seed
31 | done
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-720/weather.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=weather.csv
5 | model_id_name=weather
6 | data_name=custom
7 |
8 | model_type='mlp'
9 | seq_len=720
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 21 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/electricity.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=electricity.csv
5 | model_id_name=Electricity
6 | data_name=custom
7 |
8 |
9 | model_type='mlp'
10 | seq_len=96
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 321 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/etth1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh1.csv
5 | model_id_name=ETTh1
6 | data_name=ETTh1
7 |
8 |
9 | model_type='mlp'
10 | seq_len=96
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 7 \
26 | --cycle 24 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
35 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/etth2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTh2.csv
5 | model_id_name=ETTh2
6 | data_name=ETTh2
7 |
8 | model_type='mlp'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 24 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/ettm1.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm1.csv
5 | model_id_name=ETTm1
6 | data_name=ETTm1
7 |
8 | model_type='mlp'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/ettm2.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=ETTm2.csv
5 | model_id_name=ETTm2
6 | data_name=ETTm2
7 |
8 | model_type='mlp'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 7 \
25 | --cycle 96 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/solar.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=solar_AL.txt
5 | model_id_name=Solar
6 | data_name=Solar
7 |
8 | model_type='mlp'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 137 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --use_revin 0 \
30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
31 | done
32 | done
33 |
34 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=traffic.csv
5 | model_id_name=traffic
6 | data_name=custom
7 |
8 |
9 | model_type='mlp'
10 | seq_len=96
11 | for pred_len in 96 192 336 720
12 | do
13 | for random_seed in 2024 2025 2026 2027 2028
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 862 \
26 | --cycle 168 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed
31 | done
32 | done
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/MLP-Input-96/weather.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=weather.csv
5 | model_id_name=weather
6 | data_name=custom
7 |
8 | model_type='mlp'
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024 2025 2026 2027 2028
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 21 \
25 | --cycle 144 \
26 | --model_type $model_type \
27 | --train_epochs 30 \
28 | --patience 5 \
29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
30 | done
31 | done
32 |
33 |
--------------------------------------------------------------------------------
/scripts/CycleNet/PEMS/pems03.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=PEMS03.npz
5 | model_id_name=PEMS03
6 | data_name=PEMS
7 |
8 |
9 | model_type='mlp'
10 | seq_len=96
11 | for pred_len in 12 24 48 96
12 | do
13 | for random_seed in 2024
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 358 \
26 | --cycle 288 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --use_revin 0 \
31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
32 | done
33 | done
34 |
35 |
36 | #model_type='linear'
37 | #seq_len=96
38 | #for pred_len in 12 24 48 96
39 | #do
40 | #for random_seed in 2024
41 | #do
42 | # python -u run.py \
43 | # --is_training 1 \
44 | # --root_path $root_path_name \
45 | # --data_path $data_path_name \
46 | # --model_id $model_id_name'_'$seq_len'_'$pred_len \
47 | # --model $model_name \
48 | # --data $data_name \
49 | # --features M \
50 | # --seq_len $seq_len \
51 | # --pred_len $pred_len \
52 | # --enc_in 358 \
53 | # --cycle 288 \
54 | # --model_type $model_type \
55 | # --train_epochs 30 \
56 | # --patience 5 \
57 | # --use_revin 0 \
58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
59 | #done
60 | #done
61 |
62 |
63 |
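64 | # cycle=288: daily periodicity of the 5-minute PEMS data (12 * 24). RevIN-style
65 | # normalization is disabled here (--use_revin 0); the commented block above
66 | # keeps the linear-backbone configuration for reference.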
--------------------------------------------------------------------------------
/scripts/CycleNet/PEMS/pems04.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=PEMS04.npz
5 | model_id_name=PEMS04
6 | data_name=PEMS
7 |
8 |
9 | model_type='mlp'
10 | seq_len=96
11 | for pred_len in 12 24 48 96
12 | do
13 | for random_seed in 2024
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 307 \
26 | --cycle 288 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --use_revin 0 \
31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
32 | done
33 | done
34 |
35 |
36 | #model_type='linear'
37 | #seq_len=96
38 | #for pred_len in 12 24 48 96
39 | #do
40 | #for random_seed in 2024
41 | #do
42 | # python -u run.py \
43 | # --is_training 1 \
44 | # --root_path $root_path_name \
45 | # --data_path $data_path_name \
46 | # --model_id $model_id_name'_'$seq_len'_'$pred_len \
47 | # --model $model_name \
48 | # --data $data_name \
49 | # --features M \
50 | # --seq_len $seq_len \
51 | # --pred_len $pred_len \
52 | # --enc_in 307 \
53 | # --cycle 288 \
54 | # --model_type $model_type \
55 | # --train_epochs 30 \
56 | # --patience 5 \
57 | # --use_revin 0 \
58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
59 | #done
60 | #done
61 |
62 |
63 |
--------------------------------------------------------------------------------
/scripts/CycleNet/PEMS/pems07.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=PEMS07.npz
5 | model_id_name=PEMS07
6 | data_name=PEMS
7 |
8 |
9 | model_type='mlp'
10 | seq_len=96
11 | for pred_len in 12 24 48 96
12 | do
13 | for random_seed in 2024
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 883 \
26 | --cycle 288 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --patience 5 \
30 | --use_revin 0 \
31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
32 | done
33 | done
34 |
35 |
36 | #model_type='linear'
37 | #seq_len=96
38 | #for pred_len in 12 24 48 96
39 | #do
40 | #for random_seed in 2024
41 | #do
42 | # python -u run.py \
43 | # --is_training 1 \
44 | # --root_path $root_path_name \
45 | # --data_path $data_path_name \
46 | # --model_id $model_id_name'_'$seq_len'_'$pred_len \
47 | # --model $model_name \
48 | # --data $data_name \
49 | # --features M \
50 | # --seq_len $seq_len \
51 | # --pred_len $pred_len \
52 | # --enc_in 883 \
53 | # --cycle 288 \
54 | # --model_type $model_type \
55 | # --train_epochs 30 \
56 | # --use_revin 0 \
57 | # --patience 5 \
58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
59 | #done
60 | #done
61 |
62 |
63 |
--------------------------------------------------------------------------------
/scripts/CycleNet/PEMS/pems08.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleNet
2 |
3 | root_path_name=./dataset/
4 | data_path_name=PEMS08.npz
5 | model_id_name=PEMS08
6 | data_name=PEMS
7 |
8 |
9 | model_type='mlp'
10 | seq_len=96
11 | for pred_len in 12 24 48 96
12 | do
13 | for random_seed in 2024
14 | do
15 | python -u run.py \
16 | --is_training 1 \
17 | --root_path $root_path_name \
18 | --data_path $data_path_name \
19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
20 | --model $model_name \
21 | --data $data_name \
22 | --features M \
23 | --seq_len $seq_len \
24 | --pred_len $pred_len \
25 | --enc_in 170 \
26 | --cycle 288 \
27 | --model_type $model_type \
28 | --train_epochs 30 \
29 | --use_revin 0 \
30 | --patience 5 \
31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed
32 | done
33 | done
34 |
35 |
36 | #model_type='linear'
37 | #seq_len=96
38 | #for pred_len in 12 24 48 96
39 | #do
40 | #for random_seed in 2024
41 | #do
42 | # python -u run.py \
43 | # --is_training 1 \
44 | # --root_path $root_path_name \
45 | # --data_path $data_path_name \
46 | # --model_id $model_id_name'_'$seq_len'_'$pred_len \
47 | # --model $model_name \
48 | # --data $data_name \
49 | # --features M \
50 | # --seq_len $seq_len \
51 | # --pred_len $pred_len \
52 | # --enc_in 170 \
53 | # --cycle 288 \
54 | # --model_type $model_type \
55 | # --train_epochs 30 \
56 | # --patience 5 \
57 | # --use_revin 0 \
58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed
59 | #done
60 | #done
61 |
62 |
63 |
--------------------------------------------------------------------------------
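Note: the four PEMS scripts all pass --cycle 288 because the PEMS datasets are sampled every 5 minutes, so one day spans 24 * 60 / 5 = 288 steps. A minimal sketch of the kind of cyclic lookup this implies (illustrative only, not this repo's models/CycleNet.py):

import torch

# Illustrative sketch (assumption, not the repo's code): a CycleNet-style
# learnable cyclic component keeps one vector per position of the cycle,
# and an absolute time index t selects position t % cycle.
cycle, channels = 288, 883            # PEMS07: 5-minute data, 883 sensors
Q = torch.zeros(cycle, channels)      # learnable periodic component
t = torch.arange(96)                  # absolute time indices of a lookback window
cyclic_part = Q[t % cycle]            # (96, 883): periodic value at each step
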
/scripts/CycleNet/STD/CycleNet.sh:
--------------------------------------------------------------------------------
1 | root_path_name=./dataset/
2 | data_path_name=ETTh1.csv
3 | model_id_name=ETTh1
4 | data_name=ETTh1
5 |
6 | model_name=CycleNet
7 | model_type='linear'
8 | seq_len=336
9 | for pred_len in 96 192 336 720
10 | do
11 | for random_seed in 2024
12 | do
13 | python -u run.py \
14 | --is_training 1 \
15 | --root_path $root_path_name \
16 | --data_path $data_path_name \
17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
18 | --model $model_name \
19 | --data $data_name \
20 | --features M \
21 | --seq_len $seq_len \
22 | --pred_len $pred_len \
23 | --model_type $model_type \
24 | --enc_in 7 \
25 | --cycle 24 \
26 | --train_epochs 30 \
27 | --patience 5 \
28 | --use_revin 0 \
29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
30 | done
31 | done
32 |
33 | root_path_name=./dataset/
34 | data_path_name=ETTh2.csv
35 | model_id_name=ETTh2
36 | data_name=ETTh2
37 |
38 | model_name=CycleNet
39 | model_type='linear'
40 | seq_len=336
41 | for pred_len in 96 192 336 720
42 | do
43 | for random_seed in 2024
44 | do
45 | python -u run.py \
46 | --is_training 1 \
47 | --root_path $root_path_name \
48 | --data_path $data_path_name \
49 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
50 | --model $model_name \
51 | --data $data_name \
52 | --features M \
53 | --seq_len $seq_len \
54 | --pred_len $pred_len \
55 | --model_type $model_type \
56 | --enc_in 7 \
57 | --cycle 24 \
58 | --train_epochs 30 \
59 | --patience 5 \
60 | --use_revin 0 \
61 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
62 | done
63 | done
64 |
65 | root_path_name=./dataset/
66 | data_path_name=ETTm1.csv
67 | model_id_name=ETTm1
68 | data_name=ETTm1
69 |
70 | model_name=CycleNet
71 | model_type='linear'
72 | seq_len=336
73 | for pred_len in 96 192 336 720
74 | do
75 | for random_seed in 2024
76 | do
77 | python -u run.py \
78 | --is_training 1 \
79 | --root_path $root_path_name \
80 | --data_path $data_path_name \
81 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
82 | --model $model_name \
83 | --data $data_name \
84 | --features M \
85 | --seq_len $seq_len \
86 | --pred_len $pred_len \
87 | --model_type $model_type \
88 | --enc_in 7 \
89 | --cycle 96 \
90 | --train_epochs 30 \
91 | --patience 5 \
92 | --use_revin 0 \
93 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
94 | done
95 | done
96 |
97 | root_path_name=./dataset/
98 | data_path_name=ETTm2.csv
99 | model_id_name=ETTm2
100 | data_name=ETTm2
101 |
102 | model_name=CycleNet
103 | model_type='linear'
104 | seq_len=336
105 | for pred_len in 96 192 336 720
106 | do
107 | for random_seed in 2024
108 | do
109 | python -u run.py \
110 | --is_training 1 \
111 | --root_path $root_path_name \
112 | --data_path $data_path_name \
113 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
114 | --model $model_name \
115 | --data $data_name \
116 | --features M \
117 | --seq_len $seq_len \
118 | --pred_len $pred_len \
119 | --model_type $model_type \
120 | --enc_in 7 \
121 | --cycle 96 \
122 | --train_epochs 30 \
123 | --patience 5 \
124 | --use_revin 0 \
125 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
126 | done
127 | done
128 |
129 | root_path_name=./dataset/
130 | data_path_name=weather.csv
131 | model_id_name=weather
132 | data_name=custom
133 |
134 | model_name=CycleNet
135 | model_type='linear'
136 | seq_len=336
137 | for pred_len in 96 192 336 720
138 | do
139 | for random_seed in 2024
140 | do
141 | python -u run.py \
142 | --is_training 1 \
143 | --root_path $root_path_name \
144 | --data_path $data_path_name \
145 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
146 | --model $model_name \
147 | --data $data_name \
148 | --features M \
149 | --seq_len $seq_len \
150 | --pred_len $pred_len \
151 | --model_type $model_type \
152 | --enc_in 21 \
153 | --cycle 144 \
154 | --train_epochs 30 \
155 | --patience 5 \
156 | --use_revin 0 \
157 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed
158 | done
159 | done
160 |
161 | root_path_name=./dataset/
162 | data_path_name=solar_AL.txt
163 | model_id_name=Solar
164 | data_name=Solar
165 |
166 |
167 | model_name=CycleNet
168 | model_type='linear'
169 | seq_len=336
170 | for pred_len in 96 192 336 720
171 | do
172 | for random_seed in 2024
173 | do
174 | python -u run.py \
175 | --is_training 1 \
176 | --root_path $root_path_name \
177 | --data_path $data_path_name \
178 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
179 | --model $model_name \
180 | --data $data_name \
181 | --features M \
182 | --seq_len $seq_len \
183 | --pred_len $pred_len \
184 | --model_type $model_type \
185 | --enc_in 137 \
186 | --cycle 144 \
187 | --train_epochs 30 \
188 | --patience 5 \
189 | --use_revin 0 \
190 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
191 | done
192 | done
193 |
194 |
195 | root_path_name=./dataset/
196 | data_path_name=electricity.csv
197 | model_id_name=Electricity
198 | data_name=custom
199 |
200 | model_name=CycleNet
201 | model_type='linear'
202 | seq_len=336
203 | for pred_len in 96 192 336 720
204 | do
205 | for random_seed in 2024
206 | do
207 | python -u run.py \
208 | --is_training 1 \
209 | --root_path $root_path_name \
210 | --data_path $data_path_name \
211 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
212 | --model $model_name \
213 | --data $data_name \
214 | --features M \
215 | --seq_len $seq_len \
216 | --pred_len $pred_len \
217 | --model_type $model_type \
218 | --enc_in 321 \
219 | --cycle 168 \
220 | --train_epochs 30 \
221 | --patience 5 \
222 | --use_revin 0 \
223 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
224 | done
225 | done
226 |
227 |
228 |
229 | root_path_name=./dataset/
230 | data_path_name=traffic.csv
231 | model_id_name=traffic
232 | data_name=custom
233 |
234 | model_name=CycleNet
235 | model_type='linear'
236 | seq_len=336
237 | for pred_len in 96 192 336 720
238 | do
239 | for random_seed in 2024
240 | do
241 | python -u run.py \
242 | --is_training 1 \
243 | --root_path $root_path_name \
244 | --data_path $data_path_name \
245 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
246 | --model $model_name \
247 | --data $data_name \
248 | --features M \
249 | --seq_len $seq_len \
250 | --pred_len $pred_len \
251 | --model_type $model_type \
252 | --enc_in 862 \
253 | --cycle 168 \
254 | --train_epochs 30 \
255 | --patience 5 \
256 | --use_revin 0 \
257 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
258 | done
259 | done
260 |
--------------------------------------------------------------------------------
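Note: each dataset in the STD script above gets a fixed --cycle that follows directly from its sampling rate: 24 for hourly ETTh data (one day), 96 for 15-minute ETTm data, 144 for 10-minute weather/Solar data, and 168 for hourly electricity/traffic (one week). A small sanity check of that arithmetic (the helper name is illustrative, not part of the repo):

def steps_per_cycle(cycle_hours: float, sample_minutes: float) -> int:
    """Number of samples in one cycle, given the sampling interval."""
    return int(cycle_hours * 60 / sample_minutes)

assert steps_per_cycle(24, 60) == 24        # ETTh1/ETTh2: hourly, daily cycle
assert steps_per_cycle(24, 15) == 96        # ETTm1/ETTm2: 15-minute, daily cycle
assert steps_per_cycle(24, 10) == 144       # weather/Solar: 10-minute, daily cycle
assert steps_per_cycle(24 * 7, 60) == 168   # electricity/traffic: hourly, weekly cycle
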
/scripts/CycleNet/STD/DLinear.sh:
--------------------------------------------------------------------------------
1 |
2 | root_path_name=./dataset/
3 | data_path_name=ETTh1.csv
4 | model_id_name=ETTh1
5 | data_name=ETTh1
6 |
7 | model_name=DLinear
8 | seq_len=336
9 | for pred_len in 96 192 336 720
10 | do
11 | for random_seed in 2024
12 | do
13 | python -u run.py \
14 | --is_training 1 \
15 | --root_path $root_path_name \
16 | --data_path $data_path_name \
17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
18 | --model $model_name \
19 | --data $data_name \
20 | --features M \
21 | --seq_len $seq_len \
22 | --pred_len $pred_len \
23 | --enc_in 7 \
24 | --train_epochs 30 \
25 | --patience 5 \
26 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
27 | done
28 | done
29 |
30 | root_path_name=./dataset/
31 | data_path_name=ETTh2.csv
32 | model_id_name=ETTh2
33 | data_name=ETTh2
34 |
35 | model_name=DLinear
36 | seq_len=336
37 | for pred_len in 96 192 336 720
38 | do
39 | for random_seed in 2024
40 | do
41 | python -u run.py \
42 | --is_training 1 \
43 | --root_path $root_path_name \
44 | --data_path $data_path_name \
45 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
46 | --model $model_name \
47 | --data $data_name \
48 | --features M \
49 | --seq_len $seq_len \
50 | --pred_len $pred_len \
51 | --enc_in 7 \
52 | --train_epochs 30 \
53 | --patience 5 \
54 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
55 | done
56 | done
57 |
58 | root_path_name=./dataset/
59 | data_path_name=ETTm1.csv
60 | model_id_name=ETTm1
61 | data_name=ETTm1
62 |
63 | model_name=DLinear
64 | seq_len=336
65 | for pred_len in 96 192 336 720
66 | do
67 | for random_seed in 2024
68 | do
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path $root_path_name \
72 | --data_path $data_path_name \
73 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
74 | --model $model_name \
75 | --data $data_name \
76 | --features M \
77 | --seq_len $seq_len \
78 | --pred_len $pred_len \
79 | --enc_in 7 \
80 | --train_epochs 30 \
81 | --patience 5 \
82 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
83 | done
84 | done
85 |
86 |
87 | root_path_name=./dataset/
88 | data_path_name=ETTm2.csv
89 | model_id_name=ETTm2
90 | data_name=ETTm2
91 |
92 | model_name=DLinear
93 | seq_len=336
94 | for pred_len in 96 192 336 720
95 | do
96 | for random_seed in 2024
97 | do
98 | python -u run.py \
99 | --is_training 1 \
100 | --root_path $root_path_name \
101 | --data_path $data_path_name \
102 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
103 | --model $model_name \
104 | --data $data_name \
105 | --features M \
106 | --seq_len $seq_len \
107 | --pred_len $pred_len \
108 | --enc_in 7 \
109 | --train_epochs 30 \
110 | --patience 5 \
111 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
112 | done
113 | done
114 |
115 |
116 | root_path_name=./dataset/
117 | data_path_name=weather.csv
118 | model_id_name=weather
119 | data_name=custom
120 |
121 | model_name=DLinear
122 | seq_len=336
123 | for pred_len in 96 192 336 720
124 | do
125 | for random_seed in 2024
126 | do
127 | python -u run.py \
128 | --is_training 1 \
129 | --root_path $root_path_name \
130 | --data_path $data_path_name \
131 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
132 | --model $model_name \
133 | --data $data_name \
134 | --features M \
135 | --seq_len $seq_len \
136 | --pred_len $pred_len \
137 | --enc_in 21 \
138 | --train_epochs 30 \
139 | --patience 5 \
140 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed
141 | done
142 | done
143 |
144 |
145 | root_path_name=./dataset/
146 | data_path_name=solar_AL.txt
147 | model_id_name=Solar
148 | data_name=Solar
149 |
150 | model_name=DLinear
151 | seq_len=336
152 | for pred_len in 96 192 336 720
153 | do
154 | for random_seed in 2024
155 | do
156 | python -u run.py \
157 | --is_training 1 \
158 | --root_path $root_path_name \
159 | --data_path $data_path_name \
160 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
161 | --model $model_name \
162 | --data $data_name \
163 | --features M \
164 | --seq_len $seq_len \
165 | --pred_len $pred_len \
166 | --enc_in 137 \
167 | --train_epochs 30 \
168 | --patience 5 \
169 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
170 | done
171 | done
172 |
173 |
174 | root_path_name=./dataset/
175 | data_path_name=electricity.csv
176 | model_id_name=Electricity
177 | data_name=custom
178 |
179 | model_name=DLinear
180 | seq_len=336
181 | for pred_len in 96 192 336 720
182 | do
183 | for random_seed in 2024
184 | do
185 | python -u run.py \
186 | --is_training 1 \
187 | --root_path $root_path_name \
188 | --data_path $data_path_name \
189 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
190 | --model $model_name \
191 | --data $data_name \
192 | --features M \
193 | --seq_len $seq_len \
194 | --pred_len $pred_len \
195 | --enc_in 321 \
196 | --train_epochs 30 \
197 | --patience 5 \
198 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
199 | done
200 | done
201 |
202 | root_path_name=./dataset/
203 | data_path_name=traffic.csv
204 | model_id_name=traffic
205 | data_name=custom
206 |
207 | model_name=DLinear
208 | seq_len=336
209 | for pred_len in 96 192 336 720
210 | do
211 | for random_seed in 2024
212 | do
213 | python -u run.py \
214 | --is_training 1 \
215 | --root_path $root_path_name \
216 | --data_path $data_path_name \
217 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
218 | --model $model_name \
219 | --data $data_name \
220 | --features M \
221 | --seq_len $seq_len \
222 | --pred_len $pred_len \
223 | --enc_in 862 \
224 | --train_epochs 30 \
225 | --patience 5 \
226 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
227 | done
228 | done
--------------------------------------------------------------------------------
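Note: for reference, the DLinear baseline driven by the script above decomposes the lookback into a moving-average trend and a seasonal remainder, each mapped to the horizon by its own linear layer. A hedged sketch of that idea (simplified edge handling; kernel=25 is the commonly used default; see models/DLinear.py for the repo's exact version):

import torch
import torch.nn as nn

class DLinearSketch(nn.Module):
    """Hedged sketch of the DLinear idea (Zeng et al., 2023), not verbatim."""
    def __init__(self, seq_len, pred_len, kernel=25):
        super().__init__()
        self.avg = nn.AvgPool1d(kernel, stride=1, padding=kernel // 2,
                                count_include_pad=False)
        self.linear_trend = nn.Linear(seq_len, pred_len)
        self.linear_season = nn.Linear(seq_len, pred_len)

    def forward(self, x):               # x: (batch, seq_len, channels)
        x = x.permute(0, 2, 1)          # (batch, channels, seq_len)
        trend = self.avg(x)             # moving-average trend component
        season = x - trend              # seasonal remainder
        y = self.linear_trend(trend) + self.linear_season(season)
        return y.permute(0, 2, 1)       # (batch, pred_len, channels)
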
/scripts/CycleNet/STD/LDLinear.sh:
--------------------------------------------------------------------------------
1 |
2 | root_path_name=./dataset/
3 | data_path_name=ETTh1.csv
4 | model_id_name=ETTh1
5 | data_name=ETTh1
6 |
7 | model_name=LDLinear
8 | seq_len=336
9 | for pred_len in 96 192 336 720
10 | do
11 | for random_seed in 2024
12 | do
13 | python -u run.py \
14 | --is_training 1 \
15 | --root_path $root_path_name \
16 | --data_path $data_path_name \
17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
18 | --model $model_name \
19 | --data $data_name \
20 | --features M \
21 | --seq_len $seq_len \
22 | --pred_len $pred_len \
23 | --enc_in 7 \
24 | --train_epochs 30 \
25 | --patience 5 \
26 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
27 | done
28 | done
29 |
30 | root_path_name=./dataset/
31 | data_path_name=ETTh2.csv
32 | model_id_name=ETTh2
33 | data_name=ETTh2
34 |
35 | model_name=LDLinear
36 | seq_len=336
37 | for pred_len in 96 192 336 720
38 | do
39 | for random_seed in 2024
40 | do
41 | python -u run.py \
42 | --is_training 1 \
43 | --root_path $root_path_name \
44 | --data_path $data_path_name \
45 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
46 | --model $model_name \
47 | --data $data_name \
48 | --features M \
49 | --seq_len $seq_len \
50 | --pred_len $pred_len \
51 | --enc_in 7 \
52 | --train_epochs 30 \
53 | --patience 5 \
54 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
55 | done
56 | done
57 |
58 | root_path_name=./dataset/
59 | data_path_name=ETTm1.csv
60 | model_id_name=ETTm1
61 | data_name=ETTm1
62 |
63 | model_name=LDLinear
64 | seq_len=336
65 | for pred_len in 96 192 336 720
66 | do
67 | for random_seed in 2024
68 | do
69 | python -u run.py \
70 | --is_training 1 \
71 | --root_path $root_path_name \
72 | --data_path $data_path_name \
73 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
74 | --model $model_name \
75 | --data $data_name \
76 | --features M \
77 | --seq_len $seq_len \
78 | --pred_len $pred_len \
79 | --enc_in 7 \
80 | --train_epochs 30 \
81 | --patience 5 \
82 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
83 | done
84 | done
85 |
86 |
87 | root_path_name=./dataset/
88 | data_path_name=ETTm2.csv
89 | model_id_name=ETTm2
90 | data_name=ETTm2
91 |
92 | model_name=LDLinear
93 | seq_len=336
94 | for pred_len in 96 192 336 720
95 | do
96 | for random_seed in 2024
97 | do
98 | python -u run.py \
99 | --is_training 1 \
100 | --root_path $root_path_name \
101 | --data_path $data_path_name \
102 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
103 | --model $model_name \
104 | --data $data_name \
105 | --features M \
106 | --seq_len $seq_len \
107 | --pred_len $pred_len \
108 | --enc_in 7 \
109 | --train_epochs 30 \
110 | --patience 5 \
111 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
112 | done
113 | done
114 |
115 |
116 | root_path_name=./dataset/
117 | data_path_name=weather.csv
118 | model_id_name=weather
119 | data_name=custom
120 |
121 | model_name=LDLinear
122 | seq_len=336
123 | for pred_len in 96 192 336 720
124 | do
125 | for random_seed in 2024
126 | do
127 | python -u run.py \
128 | --is_training 1 \
129 | --root_path $root_path_name \
130 | --data_path $data_path_name \
131 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
132 | --model $model_name \
133 | --data $data_name \
134 | --features M \
135 | --seq_len $seq_len \
136 | --pred_len $pred_len \
137 | --enc_in 21 \
138 | --train_epochs 30 \
139 | --patience 5 \
140 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed
141 | done
142 | done
143 |
144 |
145 | root_path_name=./dataset/
146 | data_path_name=solar_AL.txt
147 | model_id_name=Solar
148 | data_name=Solar
149 |
150 | model_name=LDLinear
151 | seq_len=336
152 | for pred_len in 96 192 336 720
153 | do
154 | for random_seed in 2024
155 | do
156 | python -u run.py \
157 | --is_training 1 \
158 | --root_path $root_path_name \
159 | --data_path $data_path_name \
160 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
161 | --model $model_name \
162 | --data $data_name \
163 | --features M \
164 | --seq_len $seq_len \
165 | --pred_len $pred_len \
166 | --enc_in 137 \
167 | --train_epochs 30 \
168 | --patience 5 \
169 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
170 | done
171 | done
172 |
173 |
174 | root_path_name=./dataset/
175 | data_path_name=electricity.csv
176 | model_id_name=Electricity
177 | data_name=custom
178 |
179 | model_name=LDLinear
180 | seq_len=336
181 | for pred_len in 96 192 336 720
182 | do
183 | for random_seed in 2024
184 | do
185 | python -u run.py \
186 | --is_training 1 \
187 | --root_path $root_path_name \
188 | --data_path $data_path_name \
189 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
190 | --model $model_name \
191 | --data $data_name \
192 | --features M \
193 | --seq_len $seq_len \
194 | --pred_len $pred_len \
195 | --enc_in 321 \
196 | --train_epochs 30 \
197 | --patience 5 \
198 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
199 | done
200 | done
201 |
202 | root_path_name=./dataset/
203 | data_path_name=traffic.csv
204 | model_id_name=traffic
205 | data_name=custom
206 |
207 | model_name=LDLinear
208 | seq_len=336
209 | for pred_len in 96 192 336 720
210 | do
211 | for random_seed in 2024
212 | do
213 | python -u run.py \
214 | --is_training 1 \
215 | --root_path $root_path_name \
216 | --data_path $data_path_name \
217 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
218 | --model $model_name \
219 | --data $data_name \
220 | --features M \
221 | --seq_len $seq_len \
222 | --pred_len $pred_len \
223 | --enc_in 862 \
224 | --train_epochs 30 \
225 | --patience 5 \
226 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
227 | done
228 | done
--------------------------------------------------------------------------------
/scripts/CycleNet/STD/Linear.sh:
--------------------------------------------------------------------------------
1 |
2 | root_path_name=./dataset/
3 | data_path_name=ETTh1.csv
4 | model_id_name=ETTh1
5 | data_name=ETTh1
6 |
7 | model_name=Linear
8 | seq_len=336
9 | for pred_len in 96 192 336 720
10 | do
11 | for random_seed in 2024
12 | do
13 | python -u run.py \
14 | --is_training 1 \
15 | --root_path $root_path_name \
16 | --data_path $data_path_name \
17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
18 | --model $model_name \
19 | --data $data_name \
20 | --features M \
21 | --seq_len $seq_len \
22 | --pred_len $pred_len \
23 | --enc_in 7 \
24 | --train_epochs 30 \
25 | --patience 5 \
26 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
27 | done
28 | done
29 |
30 | root_path_name=./dataset/
31 | data_path_name=ETTh2.csv
32 | model_id_name=ETTh2
33 | data_name=ETTh2
34 |
35 | model_name=Linear
36 | seq_len=336
37 | for pred_len in 96 192 336 720
38 | do
39 | for random_seed in 2024
40 | do
41 | python -u run.py \
42 | --is_training 1 \
43 | --root_path $root_path_name \
44 | --data_path $data_path_name \
45 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
46 | --model $model_name \
47 | --data $data_name \
48 | --features M \
49 | --seq_len $seq_len \
50 | --pred_len $pred_len \
51 | --enc_in 7 \
52 | --train_epochs 30 \
53 | --patience 5 \
54 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
55 | done
56 | done
57 |
58 |
59 | root_path_name=./dataset/
60 | data_path_name=ETTm1.csv
61 | model_id_name=ETTm1
62 | data_name=ETTm1
63 |
64 | model_name=Linear
65 | seq_len=336
66 | for pred_len in 96 192 336 720
67 | do
68 | for random_seed in 2024
69 | do
70 | python -u run.py \
71 | --is_training 1 \
72 | --root_path $root_path_name \
73 | --data_path $data_path_name \
74 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
75 | --model $model_name \
76 | --data $data_name \
77 | --features M \
78 | --seq_len $seq_len \
79 | --pred_len $pred_len \
80 | --enc_in 7 \
81 | --train_epochs 30 \
82 | --patience 5 \
83 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
84 | done
85 | done
86 |
87 | root_path_name=./dataset/
88 | data_path_name=ETTm2.csv
89 | model_id_name=ETTm2
90 | data_name=ETTm2
91 |
92 | model_name=Linear
93 | seq_len=336
94 | for pred_len in 96 192 336 720
95 | do
96 | for random_seed in 2024
97 | do
98 | python -u run.py \
99 | --is_training 1 \
100 | --root_path $root_path_name \
101 | --data_path $data_path_name \
102 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
103 | --model $model_name \
104 | --data $data_name \
105 | --features M \
106 | --seq_len $seq_len \
107 | --pred_len $pred_len \
108 | --enc_in 7 \
109 | --train_epochs 30 \
110 | --patience 5 \
111 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
112 | done
113 | done
114 |
115 |
116 | root_path_name=./dataset/
117 | data_path_name=weather.csv
118 | model_id_name=weather
119 | data_name=custom
120 |
121 | model_name=Linear
122 | seq_len=336
123 | for pred_len in 96 192 336 720
124 | do
125 | for random_seed in 2024
126 | do
127 | python -u run.py \
128 | --is_training 1 \
129 | --root_path $root_path_name \
130 | --data_path $data_path_name \
131 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
132 | --model $model_name \
133 | --data $data_name \
134 | --features M \
135 | --seq_len $seq_len \
136 | --pred_len $pred_len \
137 | --enc_in 21 \
138 | --train_epochs 30 \
139 | --patience 5 \
140 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed
141 | done
142 | done
143 |
144 |
145 | root_path_name=./dataset/
146 | data_path_name=solar_AL.txt
147 | model_id_name=Solar
148 | data_name=Solar
149 |
150 | model_name=Linear
151 | seq_len=336
152 | for pred_len in 96 192 336 720
153 | do
154 | for random_seed in 2024
155 | do
156 | python -u run.py \
157 | --is_training 1 \
158 | --root_path $root_path_name \
159 | --data_path $data_path_name \
160 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
161 | --model $model_name \
162 | --data $data_name \
163 | --features M \
164 | --seq_len $seq_len \
165 | --pred_len $pred_len \
166 | --enc_in 137 \
167 | --train_epochs 30 \
168 | --patience 5 \
169 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
170 | done
171 | done
172 |
173 |
174 | root_path_name=./dataset/
175 | data_path_name=electricity.csv
176 | model_id_name=Electricity
177 | data_name=custom
178 |
179 | model_name=Linear
180 | seq_len=336
181 | for pred_len in 96 192 336 720
182 | do
183 | for random_seed in 2024
184 | do
185 | python -u run.py \
186 | --is_training 1 \
187 | --root_path $root_path_name \
188 | --data_path $data_path_name \
189 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
190 | --model $model_name \
191 | --data $data_name \
192 | --features M \
193 | --seq_len $seq_len \
194 | --pred_len $pred_len \
195 | --enc_in 321 \
196 | --train_epochs 30 \
197 | --patience 5 \
198 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
199 | done
200 | done
201 |
202 | root_path_name=./dataset/
203 | data_path_name=traffic.csv
204 | model_id_name=traffic
205 | data_name=custom
206 |
207 | model_name=Linear
208 | seq_len=336
209 | for pred_len in 96 192 336 720
210 | do
211 | for random_seed in 2024
212 | do
213 | python -u run.py \
214 | --is_training 1 \
215 | --root_path $root_path_name \
216 | --data_path $data_path_name \
217 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
218 | --model $model_name \
219 | --data $data_name \
220 | --features M \
221 | --seq_len $seq_len \
222 | --pred_len $pred_len \
223 | --enc_in 862 \
224 | --train_epochs 30 \
225 | --patience 5 \
226 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed
227 | done
228 | done
--------------------------------------------------------------------------------
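Note: the Linear baseline above is the simplest of the set: a single linear map from the lookback to the horizon, shared across channels. A short sketch (assumed shape convention; see models/Linear.py for the repo's version):

import torch
import torch.nn as nn

linear = nn.Linear(336, 96)                     # seq_len=336 -> pred_len=96
x = torch.randn(8, 336, 7)                      # (batch, seq_len, channels), e.g. ETTh1
y = linear(x.transpose(1, 2)).transpose(1, 2)   # (8, 96, 7): same map for every channel
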
/scripts/CycleNet/STD/SparseTSF.sh:
--------------------------------------------------------------------------------
1 | root_path_name=./dataset/
2 | data_path_name=ETTh1.csv
3 | model_id_name=ETTh1
4 | data_name=ETTh1
5 |
6 | model_name=SparseTSF
7 | seq_len=336
8 | for pred_len in 96 192 336 720
9 | do
10 | for random_seed in 2024
11 | do
12 | python -u run.py \
13 | --is_training 1 \
14 | --root_path $root_path_name \
15 | --data_path $data_path_name \
16 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
17 | --model $model_name \
18 | --data $data_name \
19 | --features M \
20 | --seq_len $seq_len \
21 | --pred_len $pred_len \
22 | --period_len 24 \
23 | --enc_in 7 \
24 | --train_epochs 30 \
25 | --patience 5 \
26 | --use_revin 0 \
27 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
28 | done
29 | done
30 |
31 |
32 |
33 | root_path_name=./dataset/
34 | data_path_name=ETTh2.csv
35 | model_id_name=ETTh2
36 | data_name=ETTh2
37 |
38 | model_name=SparseTSF
39 | seq_len=336
40 | for pred_len in 96 192 336 720
41 | do
42 | for random_seed in 2024
43 | do
44 | python -u run.py \
45 | --is_training 1 \
46 | --root_path $root_path_name \
47 | --data_path $data_path_name \
48 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
49 | --model $model_name \
50 | --data $data_name \
51 | --features M \
52 | --seq_len $seq_len \
53 | --pred_len $pred_len \
54 | --enc_in 7 \
55 | --period_len 24 \
56 | --train_epochs 30 \
57 | --patience 5 \
58 | --use_revin 0 \
59 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
60 | done
61 | done
62 |
63 | root_path_name=./dataset/
64 | data_path_name=ETTm1.csv
65 | model_id_name=ETTm1
66 | data_name=ETTm1
67 |
68 | model_name=SparseTSF
69 | seq_len=336
70 | for pred_len in 96 192 336 720
71 | do
72 | for random_seed in 2024
73 | do
74 | python -u run.py \
75 | --is_training 1 \
76 | --root_path $root_path_name \
77 | --data_path $data_path_name \
78 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
79 | --model $model_name \
80 | --data $data_name \
81 | --features M \
82 | --seq_len $seq_len \
83 | --pred_len $pred_len \
84 | --enc_in 7 \
85 | --period_len 4 \
86 | --train_epochs 30 \
87 | --patience 5 \
88 | --use_revin 0 \
89 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
90 | done
91 | done
92 |
93 | root_path_name=./dataset/
94 | data_path_name=ETTm2.csv
95 | model_id_name=ETTm2
96 | data_name=ETTm2
97 |
98 | model_name=SparseTSF
99 | seq_len=336
100 | for pred_len in 96 192 336 720
101 | do
102 | for random_seed in 2024
103 | do
104 | python -u run.py \
105 | --is_training 1 \
106 | --root_path $root_path_name \
107 | --data_path $data_path_name \
108 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
109 | --model $model_name \
110 | --data $data_name \
111 | --features M \
112 | --seq_len $seq_len \
113 | --pred_len $pred_len \
114 | --period_len 24 \
115 | --enc_in 7 \
116 | --train_epochs 30 \
117 | --patience 5 \
118 | --use_revin 0 \
119 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
120 | done
121 | done
122 |
123 |
124 |
125 | root_path_name=./dataset/
126 | data_path_name=weather.csv
127 | model_id_name=weather
128 | data_name=custom
129 |
130 | model_name=SparseTSF
131 | seq_len=336
132 | for pred_len in 96 192 336 720
133 | do
134 | for random_seed in 2024
135 | do
136 | python -u run.py \
137 | --is_training 1 \
138 | --root_path $root_path_name \
139 | --data_path $data_path_name \
140 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
141 | --model $model_name \
142 | --data $data_name \
143 | --features M \
144 | --seq_len $seq_len \
145 | --pred_len $pred_len \
146 | --period_len 4 \
147 | --enc_in 21 \
148 | --train_epochs 30 \
149 | --patience 5 \
150 | --use_revin 0 \
151 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed
152 | done
153 | done
154 |
155 |
156 | root_path_name=./dataset/
157 | data_path_name=solar_AL.txt
158 | model_id_name=Solar
159 | data_name=Solar
160 |
161 | model_name=SparseTSF
162 | seq_len=336
163 | for pred_len in 96 192 336 720
164 | do
165 | for random_seed in 2024
166 | do
167 | python -u run.py \
168 | --is_training 1 \
169 | --root_path $root_path_name \
170 | --data_path $data_path_name \
171 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
172 | --model $model_name \
173 | --data $data_name \
174 | --features M \
175 | --seq_len $seq_len \
176 | --pred_len $pred_len \
177 | --period_len 4 \
178 | --enc_in 137 \
179 | --train_epochs 30 \
180 | --patience 5 \
181 | --use_revin 0 \
182 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
183 | done
184 | done
185 |
186 |
187 |
188 | root_path_name=./dataset/
189 | data_path_name=electricity.csv
190 | model_id_name=Electricity
191 | data_name=custom
192 |
193 | model_name=SparseTSF
194 | seq_len=336
195 | for pred_len in 96 192 336 720
196 | do
197 | for random_seed in 2024
198 | do
199 | python -u run.py \
200 | --is_training 1 \
201 | --root_path $root_path_name \
202 | --data_path $data_path_name \
203 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
204 | --model $model_name \
205 | --data $data_name \
206 | --features M \
207 | --seq_len $seq_len \
208 | --pred_len $pred_len \
209 | --period_len 24 \
210 | --enc_in 321 \
211 | --train_epochs 30 \
212 | --patience 5 \
213 | --use_revin 0 \
214 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
215 | done
216 | done
217 |
218 | root_path_name=./dataset/
219 | data_path_name=traffic.csv
220 | model_id_name=traffic
221 | data_name=custom
222 |
223 |
224 | model_name=SparseTSF
225 | seq_len=336
226 | for pred_len in 96 192 336 720
227 | do
228 | for random_seed in 2024
229 | do
230 | python -u run.py \
231 | --is_training 1 \
232 | --root_path $root_path_name \
233 | --data_path $data_path_name \
234 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
235 | --model $model_name \
236 | --data $data_name \
237 | --features M \
238 | --seq_len $seq_len \
239 | --pred_len $pred_len \
240 | --period_len 24 \
241 | --enc_in 862 \
242 | --train_epochs 30 \
243 | --patience 5 \
244 | --use_revin 0 \
245 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
246 | done
247 | done
--------------------------------------------------------------------------------
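Note: SparseTSF's --period_len controls how the lookback is folded before its single linear layer: the series is split by phase within the period, and one small linear map (here 336/24=14 -> 96/24=4) is shared across all phases. A hedged sketch of that folding for one channel (not models/SparseTSF.py verbatim):

import torch
import torch.nn as nn

seq_len, pred_len, period = 336, 96, 24
linear = nn.Linear(seq_len // period, pred_len // period)   # 14 -> 4

x = torch.randn(32, seq_len)                     # one channel, batch of 32
sub = x.reshape(32, seq_len // period, period)   # (32, 14, 24): 14 periods of 24 steps
sub = sub.permute(0, 2, 1)                       # (32, 24, 14): one sub-series per phase
y = linear(sub)                                  # (32, 24, 4): predict 4 future periods
y = y.permute(0, 2, 1).reshape(32, pred_len)     # fold back to (32, 96)
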
/scripts/iTransformer/electricity.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleiTransformer
2 |
3 | root_path_name=./dataset/
4 | data_path_name=electricity.csv
5 | model_id_name=Electricity
6 | data_name=custom
7 |
8 |
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 321 \
25 | --cycle 168 \
26 | --d_model 512 \
27 | --d_ff 512 \
28 | --dropout 0.1 \
29 | --e_layers 3 \
30 | --train_epochs 10 \
31 | --patience 3 \
32 | --itr 1 --batch_size 16 --learning_rate 0.0005 --random_seed $random_seed
33 | done
34 | done
35 |
36 |
37 |
--------------------------------------------------------------------------------
/scripts/iTransformer/traffic.sh:
--------------------------------------------------------------------------------
1 | model_name=CycleiTransformer
2 |
3 | root_path_name=./dataset/
4 | data_path_name=traffic.csv
5 | model_id_name=traffic
6 | data_name=custom
7 |
8 |
9 | seq_len=96
10 | for pred_len in 96 192 336 720
11 | do
12 | for random_seed in 2024
13 | do
14 | python -u run.py \
15 | --is_training 1 \
16 | --root_path $root_path_name \
17 | --data_path $data_path_name \
18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \
19 | --model $model_name \
20 | --data $data_name \
21 | --features M \
22 | --seq_len $seq_len \
23 | --pred_len $pred_len \
24 | --enc_in 862 \
25 | --cycle 168 \
26 | --d_model 512 \
27 | --d_ff 512 \
28 | --e_layers 4 \
29 | --dropout 0.1 \
30 | --train_epochs 30 \
31 | --patience 5 \
32 | --itr 1 --batch_size 16 --learning_rate 0.001 --random_seed $random_seed
33 | done
34 | done
35 |
--------------------------------------------------------------------------------
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TriangularCausalMask():
5 | def __init__(self, B, L, device="cpu"):
6 | mask_shape = [B, 1, L, L]
7 | with torch.no_grad():
8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
9 |
10 | @property
11 | def mask(self):
12 | return self._mask
13 |
14 |
15 | class ProbMask():
16 | def __init__(self, B, H, L, index, scores, device="cpu"):
17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
18 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
19 | indicator = _mask_ex[torch.arange(B)[:, None, None],
20 | torch.arange(H)[None, :, None],
21 | index, :].to(device)
22 | self._mask = indicator.view(scores.shape).to(device)
23 |
24 | @property
25 | def mask(self):
26 | return self._mask
27 |
--------------------------------------------------------------------------------
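Usage sketch for the class above: TriangularCausalMask produces a (B, 1, L, L) boolean upper-triangular mask that broadcasts over attention heads.

import torch

B, H, L = 2, 4, 8
scores = torch.randn(B, H, L, L)                       # raw attention scores
mask = TriangularCausalMask(B, L)
scores = scores.masked_fill(mask.mask, float('-inf'))  # hide future positions
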
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def RSE(pred, true):
5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
6 |
7 |
8 | def CORR(pred, true):
9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
11 | d += 1e-12
12 | return 0.01*(u / d).mean(-1)
13 |
14 |
15 | def MAE(pred, true):
16 | return np.mean(np.abs(pred - true))
17 |
18 |
19 | def MSE(pred, true):
20 | return np.mean((pred - true) ** 2)
21 |
22 |
23 | def RMSE(pred, true):
24 | return np.sqrt(MSE(pred, true))
25 |
26 |
27 | def MAPE(pred, true):
28 | return np.mean(np.abs((pred - true) / true))
29 |
30 |
31 | def MSPE(pred, true):
32 | return np.mean(np.square((pred - true) / true))
33 |
34 |
35 | def metric(pred, true):
36 | mae = MAE(pred, true)
37 | mse = MSE(pred, true)
38 | rmse = RMSE(pred, true)
39 | mape = MAPE(pred, true)
40 | mspe = MSPE(pred, true)
41 | rse = RSE(pred, true)
42 | corr = CORR(pred, true)
43 |
44 | return mae, mse, rmse, mape, mspe, rse, corr
45 |
--------------------------------------------------------------------------------
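Usage sketch: all metrics take equally shaped numpy arrays and reduce over every element. Note that MAPE and MSPE divide by true, so they are only meaningful when the targets are nonzero.

import numpy as np

pred = np.array([[1.0, 2.0], [3.0, 4.0]])
true = np.array([[1.5, 2.0], [2.0, 4.0]])
mae, mse, rmse, mape, mspe, rse, corr = metric(pred, true)
print(f'MAE={mae:.3f} MSE={mse:.4f}')  # MAE=0.375 MSE=0.3125
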
/utils/timefeatures.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from pandas.tseries import offsets
6 | from pandas.tseries.frequencies import to_offset
7 |
8 |
9 | class TimeFeature:
10 | def __init__(self):
11 | pass
12 |
13 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
14 | pass
15 |
16 | def __repr__(self):
17 | return self.__class__.__name__ + "()"
18 |
19 |
20 | class SecondOfMinute(TimeFeature):
21 | """Minute of hour encoded as value between [-0.5, 0.5]"""
22 |
23 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
24 | return index.second / 59.0 - 0.5
25 |
26 |
27 | class MinuteOfHour(TimeFeature):
28 | """Minute of hour encoded as value between [-0.5, 0.5]"""
29 |
30 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
31 | return index.minute / 59.0 - 0.5
32 |
33 |
34 | class HourOfDay(TimeFeature):
35 | """Hour of day encoded as value between [-0.5, 0.5]"""
36 |
37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
38 | return index.hour / 23.0 - 0.5
39 |
40 |
41 | class DayOfWeek(TimeFeature):
42 | """Hour of day encoded as value between [-0.5, 0.5]"""
43 |
44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
45 | return index.dayofweek / 6.0 - 0.5
46 |
47 |
48 | class DayOfMonth(TimeFeature):
49 | """Day of month encoded as value between [-0.5, 0.5]"""
50 |
51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
52 | return (index.day - 1) / 30.0 - 0.5
53 |
54 |
55 | class DayOfYear(TimeFeature):
56 | """Day of year encoded as value between [-0.5, 0.5]"""
57 |
58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
59 | return (index.dayofyear - 1) / 365.0 - 0.5
60 |
61 |
62 | class MonthOfYear(TimeFeature):
63 | """Month of year encoded as value between [-0.5, 0.5]"""
64 |
65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
66 | return (index.month - 1) / 11.0 - 0.5
67 |
68 |
69 | class WeekOfYear(TimeFeature):
70 | """Week of year encoded as value between [-0.5, 0.5]"""
71 |
72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
73 | return (index.isocalendar().week - 1) / 52.0 - 0.5
74 |
75 |
76 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
77 | """
78 | Returns a list of time features that will be appropriate for the given frequency string.
79 | Parameters
80 | ----------
81 | freq_str
82 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
83 | """
84 |
85 | features_by_offsets = {
86 | offsets.YearEnd: [],
87 | offsets.QuarterEnd: [MonthOfYear],
88 | offsets.MonthEnd: [MonthOfYear],
89 | offsets.Week: [DayOfMonth, WeekOfYear],
90 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
91 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
92 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
93 | offsets.Minute: [
94 | MinuteOfHour,
95 | HourOfDay,
96 | DayOfWeek,
97 | DayOfMonth,
98 | DayOfYear,
99 | ],
100 | offsets.Second: [
101 | SecondOfMinute,
102 | MinuteOfHour,
103 | HourOfDay,
104 | DayOfWeek,
105 | DayOfMonth,
106 | DayOfYear,
107 | ],
108 | }
109 |
110 | offset = to_offset(freq_str)
111 |
112 | for offset_type, feature_classes in features_by_offsets.items():
113 | if isinstance(offset, offset_type):
114 | return [cls() for cls in feature_classes]
115 |
116 | supported_freq_msg = f"""
117 | Unsupported frequency {freq_str}
118 | The following frequencies are supported:
119 | Y - yearly
120 | alias: A
121 | M - monthly
122 | W - weekly
123 | D - daily
124 | B - business days
125 | H - hourly
126 | T - minutely
127 | alias: min
128 | S - secondly
129 | """
130 | raise RuntimeError(supported_freq_msg)
131 |
132 |
133 | def time_features(dates, freq='h'):
134 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
135 |
--------------------------------------------------------------------------------
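Usage sketch: an hourly frequency yields four features (HourOfDay, DayOfWeek, DayOfMonth, DayOfYear), each scaled to [-0.5, 0.5] and stacked row-wise.

import pandas as pd

idx = pd.date_range('2024-01-01', periods=48, freq='h')
feats = time_features(idx, freq='h')
print(feats.shape)  # (4, 48)
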
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | import time
5 |
6 | plt.switch_backend('agg')
7 |
8 |
9 | def adjust_learning_rate(optimizer, scheduler, epoch, args, printout=True):
10 | # lr = args.learning_rate * (0.2 ** (epoch // 2))
11 | if args.lradj == 'type1':
12 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
13 | elif args.lradj == 'type2':
14 | lr_adjust = {
15 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
16 | 10: 5e-7, 15: 1e-7, 20: 5e-8
17 | }
18 | elif args.lradj == 'type3':
19 | lr_adjust = {epoch: args.learning_rate if epoch < 3 else args.learning_rate * (0.8 ** ((epoch - 3) // 1))}
20 | elif args.lradj == 'constant':
21 | lr_adjust = {epoch: args.learning_rate}
22 | elif args.lradj == '3':
23 | lr_adjust = {epoch: args.learning_rate if epoch < 10 else args.learning_rate*0.1}
24 | elif args.lradj == '4':
25 | lr_adjust = {epoch: args.learning_rate if epoch < 15 else args.learning_rate*0.1}
26 | elif args.lradj == '5':
27 | lr_adjust = {epoch: args.learning_rate if epoch < 25 else args.learning_rate*0.1}
28 | elif args.lradj == '6':
29 | lr_adjust = {epoch: args.learning_rate if epoch < 5 else args.learning_rate*0.1}
30 | elif args.lradj == 'TST':
31 | lr_adjust = {epoch: scheduler.get_last_lr()[0]}
32 | else: lr_adjust = {}  # unrecognized lradj value: leave the learning rate unchanged
33 | if epoch in lr_adjust.keys():
34 | lr = lr_adjust[epoch]
35 | for param_group in optimizer.param_groups:
36 | param_group['lr'] = lr
37 | if printout: print('Updating learning rate to {}'.format(lr))
38 |
39 |
40 | class EarlyStopping:
41 | def __init__(self, patience=7, verbose=False, delta=0):
42 | self.patience = patience
43 | self.verbose = verbose
44 | self.counter = 0
45 | self.best_score = None
46 | self.early_stop = False
47 | self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
48 | self.delta = delta
49 |
50 | def __call__(self, val_loss, model, path):
51 | score = -val_loss
52 | if self.best_score is None:
53 | self.best_score = score
54 | self.save_checkpoint(val_loss, model, path)
55 | elif score < self.best_score + self.delta:
56 | self.counter += 1
57 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
58 | if self.counter >= self.patience:
59 | self.early_stop = True
60 | else:
61 | self.best_score = score
62 | self.save_checkpoint(val_loss, model, path)
63 | self.counter = 0
64 |
65 | def save_checkpoint(self, val_loss, model, path):
66 | if self.verbose:
67 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
68 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
69 | self.val_loss_min = val_loss
70 |
71 |
72 | class dotdict(dict):
73 | """dot.notation access to dictionary attributes"""
74 | __getattr__ = dict.get
75 | __setattr__ = dict.__setitem__
76 | __delattr__ = dict.__delitem__
77 |
78 |
79 | class StandardScaler():
80 | def __init__(self, mean, std):
81 | self.mean = mean
82 | self.std = std
83 |
84 | def transform(self, data):
85 | return (data - self.mean) / self.std
86 |
87 | def inverse_transform(self, data):
88 | return (data * self.std) + self.mean
89 |
90 |
91 | def visual(true, preds=None, name='./pic/test.pdf'):
92 | """
93 | Results visualization
94 | """
95 | plt.figure()
96 | plt.plot(true, label='GroundTruth', linewidth=2)
97 | if preds is not None:
98 | plt.plot(preds, label='Prediction', linewidth=2)
99 | plt.legend()
100 | plt.savefig(name, bbox_inches='tight')
101 |
102 | def test_params_flop(model,x_shape):
103 | """
104 | If you want to test a Transformer-family model's FLOPs, you need to give default values to the inputs in model.forward(); the following code can only pass one argument to forward().
105 | """
106 | # model_params = 0
107 | # for parameter in model.parameters():
108 | # model_params += parameter.numel()
109 | # print('INFO: Trainable parameter count: {:.2f}M'.format(model_params / 1000000.0))
110 | from ptflops import get_model_complexity_info
111 | with torch.cuda.device(0):
112 | macs, params = get_model_complexity_info(model.cuda(), x_shape, as_strings=True, print_per_layer_stat=False)
113 | print('{:<30} {:<8}'.format('Computational complexity: ', macs))
114 | print('{:<30} {:<8}'.format('Number of parameters: ', params))
115 |
--------------------------------------------------------------------------------
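Usage sketch of EarlyStopping inside a training loop (illustrative stand-in model and losses; the repo's exp/exp_main.py drives it the same general way):

import torch.nn as nn

model = nn.Linear(8, 1)  # stand-in model for the sketch
early_stopping = EarlyStopping(patience=3, verbose=True)
for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.97, 0.99]):
    early_stopping(val_loss, model, '.')   # saves ./checkpoint.pth on improvement
    if early_stopping.early_stop:
        print('Early stopping at epoch', epoch)
        break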