├── .gitattributes ├── .gitignore ├── Figures ├── Figure2.png ├── Figure3.png ├── Figure4.png ├── Figure6.png ├── Table2.png ├── Table4.png ├── Table5.png └── Table6.png ├── LICENSE ├── README.md ├── acf_plot.ipynb ├── checkpoints ├── ETTm1_96_96_CycleNet_ETTm1_ftM_sl96_pl96_cycle96_linear_seed2024 │ └── checkpoint.pth ├── Electricity_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024 │ └── checkpoint.pth ├── Solar_96_96_CycleNet_Solar_ftM_sl96_pl96_cycle144_linear_seed2024 │ └── checkpoint.pth ├── traffic_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024 │ └── checkpoint.pth └── weather_96_96_CycleNet_custom_ftM_sl96_pl96_cycle144_linear_seed2024 │ └── checkpoint.pth ├── data_provider ├── data_factory.py └── data_loader.py ├── exp ├── exp_basic.py └── exp_main.py ├── layers ├── AutoCorrelation.py ├── Autoformer_EncDec.py ├── Embed.py ├── PatchTST_backbone.py ├── PatchTST_layers.py ├── RevIN.py ├── SelfAttention_Family.py └── Transformer_EncDec.py ├── models ├── Autoformer.py ├── CycleNet.py ├── CycleiTransformer.py ├── DLinear.py ├── Informer.py ├── LDLinear.py ├── Linear.py ├── NLinear.py ├── PatchTST.py ├── RLinear.py ├── RMLP.py ├── SegRNN.py ├── SparseTSF.py └── Transformer.py ├── requirements.txt ├── result.txt ├── run.py ├── run_main.sh ├── run_pems.sh ├── run_std.sh ├── scripts ├── CycleNet │ ├── Linear-Input-336 │ │ ├── electricity.sh │ │ ├── etth1.sh │ │ ├── etth2.sh │ │ ├── ettm1.sh │ │ ├── ettm2.sh │ │ ├── solar.sh │ │ ├── traffic.sh │ │ └── weather.sh │ ├── Linear-Input-720 │ │ ├── electricity.sh │ │ ├── etth1.sh │ │ ├── etth2.sh │ │ ├── ettm1.sh │ │ ├── ettm2.sh │ │ ├── solar.sh │ │ ├── traffic.sh │ │ └── weather.sh │ ├── Linear-Input-96 │ │ ├── electricity.sh │ │ ├── etth1.sh │ │ ├── etth2.sh │ │ ├── ettm1.sh │ │ ├── ettm2.sh │ │ ├── solar.sh │ │ ├── traffic.sh │ │ └── weather.sh │ ├── MLP-Input-336 │ │ ├── electricity.sh │ │ ├── etth1.sh │ │ ├── etth2.sh │ │ ├── ettm1.sh │ │ ├── ettm2.sh │ │ ├── solar.sh │ │ ├── traffic.sh │ │ 
└── weather.sh │ ├── MLP-Input-720 │ │ ├── electricity.sh │ │ ├── etth1.sh │ │ ├── etth2.sh │ │ ├── ettm1.sh │ │ ├── ettm2.sh │ │ ├── solar.sh │ │ ├── traffic.sh │ │ └── weather.sh │ ├── MLP-Input-96 │ │ ├── electricity.sh │ │ ├── etth1.sh │ │ ├── etth2.sh │ │ ├── ettm1.sh │ │ ├── ettm2.sh │ │ ├── solar.sh │ │ ├── traffic.sh │ │ └── weather.sh │ ├── PEMS │ │ ├── pems03.sh │ │ ├── pems04.sh │ │ ├── pems07.sh │ │ └── pems08.sh │ └── STD │ │ ├── CycleNet.sh │ │ ├── DLinear.sh │ │ ├── LDLinear.sh │ │ ├── Linear.sh │ │ └── SparseTSF.sh └── iTransformer │ ├── electricity.sh │ └── traffic.sh ├── utils ├── masking.py ├── metrics.py ├── timefeatures.py └── tools.py └── visualization.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.pyc 3 | #*.pth 4 | dataset/ 5 | test_results/ 6 | results/ 7 | logs/ 8 | .DS_Store -------------------------------------------------------------------------------- /Figures/Figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure2.png -------------------------------------------------------------------------------- /Figures/Figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure3.png -------------------------------------------------------------------------------- /Figures/Figure4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure4.png -------------------------------------------------------------------------------- /Figures/Figure6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Figure6.png -------------------------------------------------------------------------------- /Figures/Table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table2.png -------------------------------------------------------------------------------- /Figures/Table4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table4.png -------------------------------------------------------------------------------- /Figures/Table5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table5.png -------------------------------------------------------------------------------- /Figures/Table6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/Figures/Table6.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CycleNet 2 | 3 | Welcome to the official repository of the CycleNet paper: "[CycleNet: Enhancing Time Series Forecasting through Modeling Periodic Patterns](https://arxiv.org/pdf/2409.18479)". 
4 | 5 | [[Poster|海报]](https://drive.google.com/file/d/1dBnmrjtTab4M5L9qfrdAp2Y_xr53hBZ1/view?usp=drive_link) - 6 | [[Slides|幻灯片]](https://drive.google.com/file/d/1QcCCxRFtnFYPtaXmZiF4Zby5r-nl5tR5/view?usp=drive_link) - 7 | [[中文解读]](https://zhuanlan.zhihu.com/p/984766136) 8 | 9 | ## Updates 10 | 🚩 **News** (2025.05): Our latest work, [**TQNet**](https://github.com/ACAT-SCUT/TQNet), has been accepted to **ICML 2025**. TQNet is a powerful successor to CycleNet, addressing its limitation in *modeling inter-variable correlations* effectively. 11 | 12 | 🚩 **News** (2024.10): Thanks to the contribution of [wayhoww](https://github.com/wayhoww), CycleNet has been updated with a [new implementation](https://github.com/ACAT-SCUT/CycleNet/blob/f1deb5e1329970bf0c97b8fa593bb02c6d894587/models/CycleNet.py#L17) for generating cyclic components, achieving a 2x to 3x acceleration. 13 | 14 | 🚩 **News** (2024.09): CycleNet has been accepted as **NeurIPS 2024 _Spotlight_** (_average rating 7.25_, Top 1%). 15 | 16 | **Please note that** [**SparseTSF**](https://github.com/lss-1138/SparseTSF), [**CycleNet**](https://github.com/ACAT-SCUT/CycleNet), and [**TQNet**](https://github.com/ACAT-SCUT/TQNet) represent our continued exploration of **leveraging periodicity** for long-term time series forecasting (LTSF). 17 | The differences and connections among them are as follows: 18 | 19 | | Model | Use of Periodicity | Technique | Effect | Efficiency | Strengths | Limitation | 20 | | :----------------------------------------------------------: | :-------------------------------: | :------------------------------: | :----------------------------------------: | :------------------: | :-------------------------------------------: | :---------------------------------------------------: | 21 | | [**SparseTSF**](https://github.com/lss-1138/SparseTSF)
**(ICML 2024 Oral)** | Indirectly via downsampling | Cross-Period Sparse Forecasting | Ultra-light design | < 1k parameters | Extremely lightweight, near SOTA | Fails to cover multi-periods **(solved by CycleNet)** | 22 | | [**CycleNet**](https://github.com/ACAT-SCUT/CycleNet)
**(NeurIPS 2024 Spotlight)** | Explicit via learnable parameters | Residual Cycle Forecasting (RCF) | Better use of periodicity | 100k ~ 1M parameters | Strong performance on periodic data | Fails in multivariate modeling **(solved by TQNet)** | 23 | | [**TQNet**](https://github.com/ACAT-SCUT/TQNet)
**(ICML 2025)** | Serve as global correlations | Temporal Query in attention mechanism | Robust inter-variable correlation modeling | ~1M parameters | Enhanced multivariate forecasting performance | Hard to scale to ultra-long look-back inputs | 24 | 25 | ## Introduction 26 | 27 | CycleNet pioneers the **explicit modeling of periodicity** to enhance model performance in long-term time series forecasting (LTSF) tasks. Specifically, we introduce the Residual Cycle Forecasting (RCF) technique, which uses **learnable recurrent cycles** to capture inherent periodic patterns in sequences and then makes predictions on the *residual components* of the modeled cycles. 28 | 29 | ![image](Figures/Figure2.png) 30 | 31 | The RCF technique comprises two steps: the first step involves modeling the periodic patterns of sequences through globally **learnable recurrent cycles** within independent channels, and the second step entails predicting the *residual components* of the modeled cycles. 32 | 33 | ![image](Figures/Figure3.png) 34 | 35 | The learnable recurrent cycles _**Q**_ are initialized to ***zeros*** and then undergo gradient backpropagation training along with the backbone module for prediction, yielding *learned representations* (different from the initial zeros) that uncover the cyclic patterns embedded within the sequence. Here, we have provided the code implementation [[visualization.ipynb](https://github.com/ACAT-SCUT/CycleNet/blob/main/visualization.ipynb)] to visualize the learned periodic patterns. 36 | 37 | ![image](Figures/Figure4.png) 38 | 39 | RCF can be regarded as a novel approach to achieve Seasonal-Trend Decomposition (STD). Compared to other existing methods, such as moving averages (used in DLinear, FEDformer, and Autoformer), it offers significant advantages. 
40 | 41 | ![image](Figures/Table5.png) 42 | 43 | As a result, CycleNet can achieve current state-of-the-art performance using only a simple Linear or dual-layer MLP as its backbone, and it also provides substantial computational efficiency. 44 | 45 | ![image](Figures/Table2.png) 46 | 47 | In addition to simple models like Linear and MLP, RCF can also improve the performance of more advanced algorithms. 48 | 49 | ![image](Figures/Table4.png) 50 | 51 | Finally, as an explicit periodic modeling technique, RCF requires that the period length of the data is identified beforehand. 52 | For RCF to be effective, the length ***W*** of the learnable recurrent cycles _**Q**_ must accurately match the intrinsic period length of the data. 53 | 54 | ![image](Figures/Table6.png) 55 | 56 | To identify the data's inherent period length, a straightforward method is manual inference. For instance, in the Electricity dataset, we know there is a weekly periodic pattern and that the sampling granularity is hourly, so the period length can be deduced to be 168. 57 | 58 | Moreover, a more scientific approach is to use the Autocorrelation Function (ACF), for which we provide an example ([ACF_ETTh1.ipynb](https://github.com/lss-1138/SparseTSF/blob/main/ACF_ETTh1.ipynb)) in the [SparseTSF](https://github.com/lss-1138/SparseTSF) repository. The hyperparameter ***W*** should be set to the lag corresponding to the observed maximum peak. 59 | 60 | ![image](Figures/Figure6.png) 61 | 62 | ## Model Implementation 63 | 64 | The key technique of CycleNet (or RCF) is to use learnable recurrent cycle to explicitly model the periodic patterns within the data, and then model the residual components of the modeled cycles using either a single-layer Linear or a dual-layer MLP. 
The core implementation code of CycleNet (or RCF) is available at: 65 | 66 | ``` 67 | models/CycleNet.py 68 | ``` 69 | 70 | To identify the relative position of each sample within the recurrent cycles, we additionally need to generate a cycle index (i.e., **_t_ mod _W_** mentioned in the paper) for each data sample. The code for this part is available at: 71 | 72 | ``` 73 | data_provider/data_loader.py 74 | ``` 75 | The specific implementation of the cycle index includes: 76 | ```python 77 | def __read_data__(self): 78 | ... 79 | self.cycle_index = (np.arange(len(data)) % self.cycle)[border1:border2] 80 | 81 | def __getitem__(self, index): 82 | ... 83 | cycle_index = torch.tensor(self.cycle_index[s_end]) 84 | return ..., cycle_index 85 | ``` 86 | 87 | This simple implementation requires ensuring that the sequences have **no missing values**. In practical usage, a more elegant approach can be employed to generate the cycle index, such as mapping real-time timestamps to indices, for example: 88 | ```python 89 | def getCycleIndex(timestamp, frequency, cycle_len): 90 | return (timestamp / frequency) % cycle_len 91 | ``` 92 | 93 | 94 | 95 | ## Getting Started 96 | 97 | ### Environment Requirements 98 | 99 | To get started, ensure you have Conda installed on your system and follow these steps to set up the environment: 100 | 101 | ``` 102 | conda create -n CycleNet python=3.8 103 | conda activate CycleNet 104 | pip install -r requirements.txt 105 | ``` 106 | 107 | ### Data Preparation 108 | 109 | All the datasets needed for CycleNet can be obtained from the [[Google Drive]](https://drive.google.com/file/d/1bNbw1y8VYp-8pkRTqbjoW-TA-G8T0EQf/view) that was introduced in previous works such as Autoformer and SCINet. 110 | Create a separate folder named ```./dataset``` and place all the CSV files in this directory. 
111 | **Note**: Place the CSV files directly into this directory, such as "./dataset/ETTh1.csv" 112 | 113 | ### Training Example 114 | 115 | You can easily reproduce the results from the paper by running the provided script command. For instance, to reproduce the main results, execute the following command: 116 | 117 | ``` 118 | sh run_main.sh 119 | ``` 120 | 121 | **For your convenience**, we have provided the execution results of "sh run_main.sh" in [**[result.txt](https://github.com/ACAT-SCUT/CycleNet/blob/main/result.txt)**], which contain the results of running CycleNet/Linear and CycleNet/MLP five times each with various input lengths {96, 336, 720} and random seeds {2024, 2025, 2026, 2027, 2028}. 122 | 123 | You can also run the following command to reproduce the results of various STD techniques as well as the performance on the PEMS datasets (the PEMS datasets can be obtained from [SCINet](https://github.com/cure-lab/SCINet)): 124 | 125 | ``` 126 | sh run_std.sh 127 | sh run_pems.sh 128 | ``` 129 | 130 | Furthermore, you can specify separate scripts to run independent tasks, such as obtaining results on etth1: 131 | 132 | ``` 133 | sh scripts/CycleNet/Linear-Input-96/etth1.sh 134 | ``` 135 | 136 | 137 | 138 | 139 | 140 | ## Citation 141 | If you find this repo useful, please cite our paper. 
142 | ``` 143 | @inproceedings{cyclenet, 144 | title={CycleNet: Enhancing Time Series Forecasting through Modeling Periodic Patterns}, 145 | author={Lin, Shengsheng and Lin, Weiwei and Hu, Xinyi and Wu, Wentai and Mo, Ruichao and Zhong, Haocheng}, 146 | booktitle={Thirty-eighth Conference on Neural Information Processing Systems}, 147 | year={2024} 148 | } 149 | ``` 150 | 151 | 152 | 153 | 154 | 155 | ## Contact 156 | If you have any questions or suggestions, feel free to contact: 157 | - Shengsheng Lin ([cslinshengsheng@mail.scut.edu.cn]()) 158 | - Weiwei Lin ([linww@scut.edu.cn]()) 159 | - Xinyi Hu (xyhu@cse.cuhk.edu.hk) 160 | 161 | ## Acknowledgement 162 | 163 | We extend our heartfelt appreciation to the following GitHub repositories for providing valuable code bases and datasets: 164 | 165 | https://github.com/lss-1138/SparseTSF 166 | 167 | https://github.com/thuml/iTransformer 168 | 169 | https://github.com/lss-1138/SegRNN 170 | 171 | https://github.com/yuqinie98/patchtst 172 | 173 | https://github.com/cure-lab/LTSF-Linear 174 | 175 | https://github.com/zhouhaoyi/Informer2020 176 | 177 | https://github.com/thuml/Autoformer 178 | 179 | https://github.com/MAZiqing/FEDformer 180 | 181 | https://github.com/alipay/Pyraformer 182 | 183 | https://github.com/ts-kim/RevIN 184 | 185 | https://github.com/timeseriesAI/tsai 186 | 187 | -------------------------------------------------------------------------------- /checkpoints/ETTm1_96_96_CycleNet_ETTm1_ftM_sl96_pl96_cycle96_linear_seed2024/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/ETTm1_96_96_CycleNet_ETTm1_ftM_sl96_pl96_cycle96_linear_seed2024/checkpoint.pth -------------------------------------------------------------------------------- /checkpoints/Electricity_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/Electricity_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth -------------------------------------------------------------------------------- /checkpoints/Solar_96_96_CycleNet_Solar_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/Solar_96_96_CycleNet_Solar_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth -------------------------------------------------------------------------------- /checkpoints/traffic_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/traffic_96_96_CycleNet_custom_ftM_sl96_pl96_cycle168_linear_seed2024/checkpoint.pth -------------------------------------------------------------------------------- /checkpoints/weather_96_96_CycleNet_custom_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ACAT-SCUT/CycleNet/af6e06ccf8389f8fadb524cbf5c89ef0349dc30f/checkpoints/weather_96_96_CycleNet_custom_ftM_sl96_pl96_cycle144_linear_seed2024/checkpoint.pth -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred, Dataset_Solar, Dataset_PEMS 2 | from torch.utils.data import DataLoader 3 | 
data_dict = {
    'ETTh1': Dataset_ETT_hour,
    'ETTh2': Dataset_ETT_hour,
    'ETTm1': Dataset_ETT_minute,
    'ETTm2': Dataset_ETT_minute,
    'custom': Dataset_Custom,
    'Solar': Dataset_Solar,
    'PEMS': Dataset_PEMS
}


def data_provider(args, flag):
    """Build the dataset and its DataLoader for the requested split.

    Args:
        args: Namespace-like object; this function reads ``data``, ``embed``,
            ``batch_size``, ``freq``, ``root_path``, ``data_path``,
            ``seq_len``, ``label_len``, ``pred_len``, ``features``,
            ``target``, ``cycle`` and ``num_workers`` from it.
        flag: Split selector. ``'test'`` and ``'pred'`` disable shuffling and
            keep the last partial batch; any other value enables both
            shuffling and ``drop_last``. ``'pred'`` additionally forces a
            batch size of 1 and uses ``Dataset_Pred``.

    Returns:
        A ``(data_set, data_loader)`` tuple.

    Raises:
        KeyError: If ``args.data`` is not a key of ``data_dict`` (the lookup
            happens even when ``flag == 'pred'``).
    """
    dataset_cls = data_dict[args.data]
    timeenc = 1 if args.embed == 'timeF' else 0

    # Evaluation-style splits keep sample order and every sample.
    evaluating = flag in ('test', 'pred')
    shuffle_flag = not evaluating
    drop_last = not evaluating

    if flag == 'pred':
        batch_size = 1
        dataset_cls = Dataset_Pred
    else:
        batch_size = args.batch_size
    freq = args.freq

    dataset_kwargs = dict(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=freq,
        cycle=args.cycle,
    )
    data_set = dataset_cls(**dataset_kwargs)
    print(flag, len(data_set))

    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    Args:
        mask_flag: Stored on the module; not read by the methods in this class.
        factor: Multiplier for the number of delays kept,
            ``top_k = int(factor * log(length))``.
        scale: Stored on the module; not read by the methods in this class.
        attention_dropout: Probability for the ``nn.Dropout`` stored as
            ``self.dropout`` (not applied inside this class).
        output_attention: When True, ``forward`` also returns the correlation
            tensor; otherwise it returns ``None`` in its place.
    """
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        The same ``top_k`` delays (chosen from the correlation averaged over
        batch, head and channel) are shared by every sample in the batch.

        Args:
            values: Tensor indexed as (batch, head, channel, length).
            corr: Correlation tensor with the same layout as ``values``.

        Returns:
            Tensor shaped like ``values`` holding the weighted sum of the
            rolled series.
        """
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            # Per-sample weight broadcast over head, channel and length.
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Each sample selects its own ``top_k`` delays from the correlation
        averaged over head and channel.

        Args:
            values: Tensor indexed as (batch, head, channel, length).
            corr: Correlation tensor with the same layout as ``values``.

        Returns:
            Tensor shaped like ``values``.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init -- created on the device of `values` (previously a
        # hard-coded .cuda(), which broke CPU-only inference)
        init_index = torch.arange(length, device=values.device) \
            .unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: the series is doubled so that index + delay never wraps
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].reshape(-1, 1, 1, 1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Delays are selected independently for every (batch, head, channel)
        series instead of being shared.

        Args:
            values: Tensor indexed as (batch, head, channel, length).
            corr: Correlation tensor with the same layout as ``values``.

        Returns:
            Tensor shaped like ``values``.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init -- follows the device of `values` instead of forcing CUDA
        init_index = torch.arange(length, device=values.device) \
            .unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        # find top k
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """Compute auto-correlation based aggregation of ``values``.

        Args:
            queries: Tensor unpacked as (B, L, H, E).
            keys: Tensor whose second axis is the source length S.
            values: Tensor unpacked as (B, S, H, D).
            attn_mask: Accepted for interface compatibility; not used here.

        Returns:
            ``(V, corr)`` where ``V`` is laid out as (B, L, H, D) and ``corr``
            is the correlation permuted to (B, L, H, E) when
            ``output_attention`` is True, otherwise ``None``.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # Zero-pad keys/values along time up to the query length.
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        # n=L keeps the output length equal to L even when L is odd
        # (irfft otherwise defaults to an even length of 2 * (bins - 1)).
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)
class series_decomp(nn.Module):
    """
    Series decomposition block

    Splits the input into a moving-average component and the remainder left
    after subtracting it.
    """
    def __init__(self, kernel_size):
        super().__init__()
        # Stride 1 so the smoothed series can be subtracted from the input.
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(x - moving_mean, moving_mean)`` for input ``x``."""
        trend = self.moving_avg(x)
        return x - trend, trend
class Encoder(nn.Module):
    """
    Autoformer encoder

    Chains attention layers, optionally interleaved with conv layers, and
    collects the attention output of every attention layer.
    """
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Run ``x`` through the stack.

        Returns:
            ``(x, attns)`` where ``attns`` is the list of second return values
            of the attention layers, in call order.
        """
        collected = []
        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, weights = layer(x, attn_mask=attn_mask)
                collected.append(weights)
        else:
            # Paired attention/conv stages; zip stops at the shorter list.
            for layer, conv in zip(self.attn_layers, self.conv_layers):
                x, weights = layer(x, attn_mask=attn_mask)
                x = conv(x)
                collected.append(weights)
            # The final attention layer is applied once more, without a mask.
            x, weights = self.attn_layers[-1](x)
            collected.append(weights)

        if self.norm is not None:
            x = self.norm(x)

        return x, collected
class Decoder(nn.Module):
    """
    Autoformer decoder.

    Applies each decoder layer in turn, accumulating the residual trend that
    every layer returns, then applies the optional norm and projection to the
    first (non-trend) output.
    """
    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        """
        :param x: decoder input passed to the first layer
        :param cross: second input forwarded unchanged to every layer
        :param x_mask: forwarded to every layer as ``x_mask``
        :param cross_mask: forwarded to every layer as ``cross_mask``
        :param trend: initial trend to accumulate onto; when ``None`` the
            accumulation starts from the first layer's residual trend
        :return: ``(x, trend)``
        """
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # The signature defaults trend to None, which cannot be added to
            # a tensor; start the running sum from the first residual instead.
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
class FixedEmbedding(nn.Module):
    """
    Embedding lookup whose table is a fixed (non-trainable) sinusoidal code.

    Row ``p`` of the table holds ``sin(p * div_term)`` in the even columns
    and ``cos(p * div_term)`` in the odd columns.
    """
    def __init__(self, c_in, d_model):
        """
        :param c_in: number of rows in the table (size of the index vocabulary)
        :param d_model: embedding width; may be even or odd
        """
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine part must drop the last frequency to fit.
        w[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # requires_grad=False on the Parameter is what actually freezes the table.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up the fixed codes for integer indices ``x``; the result is detached."""
        return self.emb(x).detach()
class TimeFeatureEmbedding(nn.Module):
    """
    Bias-free linear projection of time features to ``d_model``.

    The input width is chosen from the ``freq`` code; an unknown code raises
    ``KeyError``. ``embed_type`` is accepted but not used here.
    """
    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super().__init__()

        # freq code -> number of input time features
        freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        self.embed = nn.Linear(freq_map[freq], d_model, bias=False)

    def forward(self, x):
        """Project the last dimension of ``x`` to ``d_model``."""
        projected = self.embed(x)
        return projected
class DataEmbedding_wo_pos_temp(nn.Module):
    """
    Data embedding that uses only the value embedding in ``forward``.

    The positional and temporal embedding modules are still constructed (and
    therefore still registered as submodules), but ``forward`` ignores them.
    """
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x`` by value only, then apply dropout; ``x_mark`` is unused."""
        embedded = self.value_embedding(x)
        return self.dropout(embedded)
class DataEmbedding_inverted(nn.Module):
    """
    Inverted embedding: swaps the last two axes of the input and projects
    the new last axis (size ``c_in``) to ``d_model`` with a linear layer.

    ``embed_type`` and ``freq`` are accepted but not used here.
    """
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """
        :param x: 3-D tensor; axes 1 and 2 are swapped before projecting
        :param x_mark: optional 3-D tensor; when given, its rows (after the
            same axis swap) are appended along dim 1 as extra tokens
        :return: projected tokens after dropout
        """
        # x: [Batch Variate Time]
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            tokens = torch.cat([tokens, x_mark.permute(0, 2, 1)], 1)
        # tokens: [Batch Variate d_model]
        embedded = self.value_embedding(tokens)
        return self.dropout(embedded)
def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
    """
    Build a (q_len, d_model) coordinate-based positional encoding.

    The encoding is ``2 * r**x * c**x - 1`` where ``r`` and ``c`` are evenly
    spaced values in [0, 1] along the rows and columns. The exponent ``x``
    starts at 0.5 (``exponential=True``) or 1 and is nudged by 0.001 per
    iteration, for at most 100 iterations, until the mean of the encoding is
    within ``eps`` of zero.

    :param q_len: number of rows
    :param d_model: number of columns
    :param exponential: start the exponent search at 0.5 instead of 1
    :param normalize: subtract the mean and divide by ``10 * std``
    :param eps: tolerance on the absolute mean used to stop the search
    :param verbose: print the iteration, exponent and mean at every step
    :return: tensor of shape (q_len, d_model)
    """
    x = .5 if exponential else 1
    # The coordinate grids do not depend on the exponent; build them once.
    rows = torch.linspace(0, 1, q_len).reshape(-1, 1)
    cols = torch.linspace(0, 1, d_model).reshape(1, -1)
    for i in range(100):
        cpe = 2 * (rows ** x) * (cols ** x) - 1
        mean = cpe.mean()
        # Plain conditional print: no logging helper is defined in this module.
        if verbose:
            print(f'{i:4.0f} {x:5.3f} {mean.item():+6.3f}')
        if abs(mean) <= eps:
            break
        elif mean > eps:
            x += .001
        else:
            x -= .001
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe
class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, centre on the last step of dim 1
            instead of the mean
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """
        ``mode='norm'`` records statistics of ``x`` and normalizes it;
        ``mode='denorm'`` inverts the normalization using the statistics
        recorded by the most recent ``'norm'`` call. Any other mode raises
        ``NotImplementedError``.
        """
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # reduce over every axis except the first and the last
        reduce_dims = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        # stdev is needed by both centring modes
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        offset = self.last if self.subtract_last else self.mean
        x = x - offset
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps**2 keeps the division defined if a weight is exactly zero
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        offset = self.last if self.subtract_last else self.mean
        return x + offset
class ProbAttention(nn.Module):
    """
    ProbSparse attention.

    Only the ``n_top`` queries with the largest sparsity measurement attend
    to the keys; every other position keeps an initial context (the mean of
    V when ``mask_flag`` is False, the cumulative sum of V otherwise).
    """
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        """
        Score the queries on a random sample of keys and return the full
        query-key products for the ``n_top`` highest-scoring queries.

        :return: ``(Q_K, M_top)`` with shapes [B, H, n_top, L_K] and
            [B, H, n_top]
        """
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k))  # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        # Squeeze only the singleton axis added by Q.unsqueeze(-2). A bare
        # squeeze() would also drop B, H, L_Q or sample_k whenever one of
        # them is 1, breaking the indexing below.
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

        # find the Top_k query with sparsity measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None],
                     torch.arange(H)[None, :, None],
                     M_top, :]  # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))  # factor*ln(L_q)*L_k

        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Return the context used for queries that are not selected."""
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            V_mean = V.mean(dim=-2)
            contex = V_mean.unsqueeze(-2).expand(B, H, L_Q, V_mean.shape[-1]).clone()
        else:  # use mask
            assert (L_Q == L_V)  # requires that L_Q == L_V, i.e. for self-attention only
            contex = V.cumsum(dim=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """
        Overwrite the rows of ``context_in`` selected by ``index`` with the
        attention output computed from ``scores``.

        :return: ``(context_in, attns)``; ``attns`` is None unless
            ``output_attention`` is set
        """
        B, H, L_V, D = V.shape

        if self.mask_flag:
            # the incoming attn_mask is replaced by a mask built for the selected queries
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1)  # nn.Softmax(dim=-1)(scores)

        context_in[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            # unselected rows are reported as uniform attention
            attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)

    def forward(self, queries, keys, values, attn_mask):
        """
        :param queries: [B, L_Q, H, D]
        :param keys: [B, L_K, H, D]
        :param values: tensor with the same layout as ``keys``
        :param attn_mask: accepted for interface parity; a mask is rebuilt
            internally when ``mask_flag`` is set
        :return: ``(context, attn)``; ``context`` is [B, H, L_Q, D]
        """
        # NOTE(review): context keeps the [B, H, L_Q, D] layout, while
        # FullAttention in this file returns [B, L, H, D] - confirm that
        # callers expect this.
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item()  # c*ln(L_q)

        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # add scale factor (falls back to 1/sqrt(D) when no scale is configured)
        scale = self.scale or 1. / sqrt(D)
        scores_top = scores_top * scale
        # get the context
        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

        return context.contiguous(), attn
class EncoderLayer(nn.Module):
    """
    Transformer encoder layer: self-attention with a residual connection,
    then a two-step pointwise (kernel size 1) conv feed-forward with a
    second residual connection, each followed by LayerNorm.
    """
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        if not d_ff:
            d_ff = 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # anything other than "relu" selects gelu
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        """Return ``(output, attn)`` where ``attn`` comes from the attention module."""
        attended, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        normed = self.norm1(x + self.dropout(attended))

        # convs act on the channel axis, so swap it into position 1 and back
        hidden = self.conv1(normed.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(normed + hidden), attn
class Decoder(nn.Module):
    """
    Transformer decoder: applies every layer in order, then the optional
    norm layer and the optional projection.
    """
    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        """Return the decoded (and, if configured, normed and projected) tensor."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, cross, x_mask=x_mask, cross_mask=cross_mask)

        # norm first, projection second; either may be absent
        for post in (self.norm, self.projection):
            if post is not None:
                hidden = post(hidden)
        return hidden
class Model(nn.Module):
    """
    Autoformer is the first method to achieve the series-wise connection,
    with inherent O(LlogL) complexity
    """
    def __init__(self, configs):
        """
        :param configs: namespace providing seq_len, label_len, pred_len,
            output_attention, moving_avg, embed_type, enc_in, dec_in, c_out,
            d_model, embed, freq, dropout, factor, n_heads, d_ff, activation,
            e_layers and d_layers
        :raises ValueError: if ``configs.embed_type`` is not one of 0-4
        """
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention

        # Decomp
        kernel_size = configs.moving_avg
        self.decomp = series_decomp(kernel_size)

        # Embedding
        # The series-wise connection inherently contains the sequential information.
        # Thus, we can discard the position embedding of transformers.
        embedding_by_type = {
            0: DataEmbedding_wo_pos,
            1: DataEmbedding,
            2: DataEmbedding_wo_pos,
            3: DataEmbedding_wo_temp,
            4: DataEmbedding_wo_pos_temp,
        }
        embedding_cls = embedding_by_type.get(configs.embed_type)
        if embedding_cls is None:
            # Fail at construction time rather than with a missing-attribute
            # error on the first forward pass.
            raise ValueError(
                f'unsupported embed_type {configs.embed_type!r}; expected one of {sorted(embedding_by_type)}')
        self.enc_embedding = embedding_cls(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        self.dec_embedding = embedding_cls(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model)
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                        configs.d_model, configs.n_heads),
                    AutoCorrelationLayer(
                        AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.c_out,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for _ in range(configs.d_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
        )

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """
        Forecast the last ``pred_len`` steps.

        :return: the prediction, or ``(prediction, attns)`` when
            ``output_attention`` is set
        """
        # decomp init
        mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
        zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device)
        seasonal_init, trend_init = self.decomp(x_enc)
        # decoder input: last label_len steps, padded with the mean (trend)
        # or with zeros (seasonal) over the prediction horizon
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
        # dec
        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask,
                                                 trend=trend_init)
        # final
        dec_out = trend_part + seasonal_part

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
class RecurrentCycle(torch.nn.Module):
    # Thanks for the contribution of wayhoww.
    # The new implementation uses index arithmetic with modulo to directly gather cyclic data in a single operation,
    # while the original implementation manually rolls and repeats the data through looping.
    # It achieves a significant speed improvement (2x ~ 3x acceleration).
    # See https://github.com/ACAT-SCUT/CycleNet/pull/4 for more details.
    def __init__(self, cycle_len, channel_size):
        super().__init__()
        self.cycle_len = cycle_len
        self.channel_size = channel_size
        # learnable table with one row per position in the cycle, initialised to zero
        self.data = torch.nn.Parameter(torch.zeros(cycle_len, channel_size), requires_grad=True)

    def forward(self, index, length):
        """
        Gather ``length`` consecutive rows of the cycle table for each start
        position in ``index``, wrapping around at ``cycle_len``.
        """
        starts = index.view(-1, 1)
        offsets = torch.arange(length, device=index.device).view(1, -1)
        positions = (starts + offsets) % self.cycle_len
        return self.data[positions]
### Integration of Cycle
class RecurrentCycle(torch.nn.Module):
    """Per-channel learnable cycle, stored as a ``(cycle_len, channel_size)`` table.

    The table starts at zero and is indexed cyclically: position ``p`` reads
    row ``p % cycle_len``.
    """

    def __init__(self, cycle_len, channel_size):
        super().__init__()
        self.cycle_len, self.channel_size = cycle_len, channel_size
        self.data = torch.nn.Parameter(
            torch.zeros(cycle_len, channel_size),
            requires_grad=True,
        )

    def forward(self, index, length):
        """Return ``data`` rows for ``length`` steps after each start in ``index``."""
        # steps: (length,) created on the same device as the incoming indices
        steps = torch.arange(length, device=index.device)
        # (n, 1) + (1, length) -> (n, length) absolute positions
        positions = index.view(-1, 1) + steps.unsqueeze(0)
        # wrap positions into [0, cycle_len) and gather -> (n, length, channel_size)
        return self.data[positions % self.cycle_len]
DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 31 | configs.dropout) 32 | # Encoder-only architecture 33 | self.encoder = Encoder( 34 | [ 35 | EncoderLayer( 36 | AttentionLayer( 37 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 38 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 39 | configs.d_model, 40 | configs.d_ff, 41 | dropout=configs.dropout, 42 | activation=configs.activation 43 | ) for l in range(configs.e_layers) 44 | ], 45 | norm_layer=torch.nn.LayerNorm(configs.d_model) 46 | ) 47 | self.projector = nn.Linear(configs.d_model, configs.pred_len, bias=True) 48 | 49 | ### ### Intergration of Cycle 50 | self.cycle_len = configs.cycle 51 | self.enc_in = configs.enc_in 52 | self.cycleQueue = RecurrentCycle(cycle_len=self.cycle_len, channel_size=self.enc_in) 53 | 54 | def forecast(self, x_enc, cycle_index, x_mark_enc=None, x_dec=None, x_mark_dec=None): 55 | 56 | if self.use_norm: 57 | # Normalization from Non-stationary Transformer 58 | means = x_enc.mean(1, keepdim=True).detach() 59 | x_enc = x_enc - means 60 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 61 | x_enc /= stdev 62 | 63 | ### Intergration of Cycle 64 | # remove the cycle of the input data 65 | x_enc = x_enc - self.cycleQueue(cycle_index, self.seq_len) 66 | 67 | _, _, N = x_enc.shape # B L N 68 | # B: batch_size; E: d_model; 69 | # L: seq_len; S: pred_len; 70 | # N: number of variate (tokens), can also includes covariates 71 | 72 | # Embedding 73 | # B L N -> B N E (B L N -> B L E in the vanilla Transformer) 74 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # covariates (e.g timestamp) can be also embedded as tokens 75 | 76 | # B N E -> B N E (B L E -> B L E in the vanilla Transformer) 77 | # the dimensions of embedded time series has been inverted, and then processed by native attn, layernorm and ffn modules 78 | enc_out, attns = self.encoder(enc_out, 
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The input ``(batch, length, channels)`` is padded along the time axis by
    repeating its first and last step, then average-pooled with the given
    ``kernel_size`` and ``stride``. With ``stride=1`` the output has the same
    length as the input for both odd and even kernel sizes.
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Pad both ends of the series with replicated edge values. The total
        # padding must be kernel_size - 1 to keep the length at stride 1; for
        # even kernels the extra step goes on the trailing side. (Padding
        # (kernel_size - 1) // 2 on each side, as before, came up one step
        # short for even kernels.) Odd kernels are padded exactly as before.
        front_len = (self.kernel_size - 1) // 2
        end_len = self.kernel_size - 1 - front_len
        front = x[:, 0:1, :].repeat(1, front_len, 1)
        end = x[:, -1:, :].repeat(1, end_len, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
class Model(nn.Module):
    """
    Decomposition-Linear

    The input is split by ``series_decomp`` (kernel size 25) into a seasonal
    residual and a moving-average trend. Each part is mapped from ``seq_len``
    to ``pred_len`` by a linear layer and the two forecasts are summed.
    With ``configs.individual`` set, every channel gets its own pair of linear
    layers; otherwise one pair is shared by all channels.

    Config fields read: seq_len, pred_len, individual, enc_in.
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # Decomposition kernel size
        kernel_size = 25
        self.decompsition = series_decomp(kernel_size)
        self.individual = configs.individual
        self.channels = configs.enc_in

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()

            # Layers are created seasonal/trend interleaved per channel; keep
            # this order so seeded initialisation stays reproducible.
            for i in range(self.channels):
                self.Linear_Seasonal.append(nn.Linear(self.seq_len, self.pred_len))
                self.Linear_Trend.append(nn.Linear(self.seq_len, self.pred_len))

                # Use these two lines if you want to visualize the weights
                # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
                # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

            # Use these two lines if you want to visualize the weights
            # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
            # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        seasonal_init, trend_init = self.decompsition(x)
        # Move time to the last axis so nn.Linear maps seq_len -> pred_len.
        seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
        if self.individual:
            # Allocate the output buffers directly with the input's dtype and
            # device instead of creating them on CPU and moving them over.
            out_shape = (seasonal_init.size(0), seasonal_init.size(1), self.pred_len)
            seasonal_output = seasonal_init.new_zeros(out_shape)
            trend_output = trend_init.new_zeros(out_shape)
            heads = zip(self.Linear_Seasonal, self.Linear_Trend)
            for i, (seasonal_head, trend_head) in enumerate(heads):
                seasonal_output[:, i, :] = seasonal_head(seasonal_init[:, i, :])
                trend_output[:, i, :] = trend_head(trend_init[:, i, :])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)

        x = seasonal_output + trend_output

        return x.permute(0, 2, 1)  # to [Batch, Output length, Channel]
configs.embed, configs.freq, 30 | configs.dropout) 31 | elif configs.embed_type == 2: 32 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 33 | configs.dropout) 34 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 35 | configs.dropout) 36 | 37 | elif configs.embed_type == 3: 38 | self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, 39 | configs.dropout) 40 | self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, 41 | configs.dropout) 42 | elif configs.embed_type == 4: 43 | self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, 44 | configs.dropout) 45 | self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, 46 | configs.dropout) 47 | # Encoder 48 | self.encoder = Encoder( 49 | [ 50 | EncoderLayer( 51 | AttentionLayer( 52 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, 53 | output_attention=configs.output_attention), 54 | configs.d_model, configs.n_heads), 55 | configs.d_model, 56 | configs.d_ff, 57 | dropout=configs.dropout, 58 | activation=configs.activation 59 | ) for l in range(configs.e_layers) 60 | ], 61 | [ 62 | ConvLayer( 63 | configs.d_model 64 | ) for l in range(configs.e_layers - 1) 65 | ] if configs.distil else None, 66 | norm_layer=torch.nn.LayerNorm(configs.d_model) 67 | ) 68 | # Decoder 69 | self.decoder = Decoder( 70 | [ 71 | DecoderLayer( 72 | AttentionLayer( 73 | ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 74 | configs.d_model, configs.n_heads), 75 | AttentionLayer( 76 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 77 | configs.d_model, configs.n_heads), 78 | configs.d_model, 79 | configs.d_ff, 80 | 
class LD(nn.Module):
    """Learnable decomposition kernel: one 1-D convolution shared by all channels.

    The convolution has a single input and output channel, replicate padding of
    ``kernel_size // 2`` on each side, and a zero bias. Its weights start as the
    softmax of ``exp(-((i - kernel_size // 2) / (2 * sigma)) ** 2)`` with
    ``sigma = 1``, so they sum to one.
    """

    def __init__(self, kernel_size=25):
        super(LD, self).__init__()
        # Define a shared convolution layer for all channels
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, stride=1, padding=int(kernel_size // 2),
                              padding_mode='replicate', bias=True)
        # Define the parameters for Gaussian initialization
        kernel_size_half = kernel_size // 2
        sigma = 1.0  # 1 for variance
        weights = torch.tensor(
            [math.exp(-((i - kernel_size_half) / (2 * sigma)) ** 2) for i in range(kernel_size)]
        ).view(1, 1, kernel_size)

        # Set the weights of the convolution layer
        with torch.no_grad():
            self.conv.weight.copy_(F.softmax(weights, dim=-1))
            self.conv.bias.zero_()

    def forward(self, inp):
        """Smooth every channel of ``inp`` (batch, time, channels) with the shared kernel."""
        batch_size, seq_len, num_channels = inp.shape
        # (B, T, N) -> (B, N, T) -> (B*N, 1, T): folding channels into the batch
        # axis lets the single-channel conv process all of them in one call
        # instead of one call per channel.
        folded = inp.permute(0, 2, 1).reshape(batch_size * num_channels, 1, seq_len)
        smoothed = self.conv(folded)
        # (B*N, 1, T') -> (B, N, T') -> (B, T', N)
        out = smoothed.reshape(batch_size, num_channels, smoothed.size(-1))
        out = out.permute(0, 2, 1)
        return out
class Model(nn.Module):
    """
    Normalization-Linear

    Subtracts the last observed step from the input window, applies one linear
    map from ``seq_len`` to ``pred_len`` along the time axis (shared by all
    channels), and adds the last step back to the forecast.
    """
    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.Linear = nn.Linear(self.seq_len, self.pred_len)
        # Use this line if you want to visualize the weights
        # self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        # The anchor is detached, so no gradient flows through it.
        anchor = x[:, -1:, :].detach()
        centered = x - anchor
        # Put time on the last axis for nn.Linear, then restore the layout.
        time_last = centered.permute(0, 2, 1)
        projected = self.Linear(time_last)
        forecast = projected.permute(0, 2, 1) + anchor
        return forecast  # [Batch, Output length, Channel]
class Model(nn.Module): 16 | def __init__(self, configs, max_seq_len:Optional[int]=1024, d_k:Optional[int]=None, d_v:Optional[int]=None, norm:str='BatchNorm', attn_dropout:float=0., 17 | act:str="gelu", key_padding_mask:bool='auto',padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, 18 | pre_norm:bool=False, store_attn:bool=False, pe:str='zeros', learn_pe:bool=True, pretrain_head:bool=False, head_type = 'flatten', verbose:bool=False, **kwargs): 19 | 20 | super().__init__() 21 | 22 | # load parameters 23 | c_in = configs.enc_in 24 | context_window = configs.seq_len 25 | target_window = configs.pred_len 26 | 27 | n_layers = configs.e_layers 28 | n_heads = configs.n_heads 29 | d_model = configs.d_model 30 | d_ff = configs.d_ff 31 | dropout = configs.dropout 32 | fc_dropout = configs.fc_dropout 33 | head_dropout = configs.head_dropout 34 | 35 | individual = configs.individual 36 | 37 | patch_len = configs.patch_len 38 | stride = configs.stride 39 | padding_patch = configs.padding_patch 40 | 41 | revin = configs.revin 42 | affine = configs.affine 43 | subtract_last = configs.subtract_last 44 | 45 | decomposition = configs.decomposition 46 | kernel_size = configs.kernel_size 47 | 48 | 49 | # model 50 | self.decomposition = decomposition 51 | if self.decomposition: 52 | self.decomp_module = series_decomp(kernel_size) 53 | self.model_trend = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, 54 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, 55 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, 56 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, 57 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, 58 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, 59 | 
pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, 60 | subtract_last=subtract_last, verbose=verbose, **kwargs) 61 | self.model_res = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, 62 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, 63 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, 64 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, 65 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, 66 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, 67 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, 68 | subtract_last=subtract_last, verbose=verbose, **kwargs) 69 | else: 70 | self.model = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, 71 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, 72 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, 73 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, 74 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, 75 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, 76 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, 77 | subtract_last=subtract_last, verbose=verbose, **kwargs) 78 | 79 | 80 | def forward(self, x): # x: [Batch, Input length, Channel] 81 | if self.decomposition: 82 | res_init, trend_init = self.decomp_module(x) 83 | res_init, trend_init = res_init.permute(0,2,1), trend_init.permute(0,2,1) # x: [Batch, Channel, Input length] 84 | res = 
class Model(nn.Module):
    """Linear forecaster wrapped in per-instance normalisation.

    Each channel of each sample is standardised with its own mean and variance
    over the input window (variance + 1e-5), mapped from ``seq_len`` to
    ``pred_len`` by one shared linear layer, then rescaled with the same
    statistics.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.Linear = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        # Statistics over the time axis, kept as [Batch, 1, Channel].
        seq_mean = x.mean(dim=1, keepdim=True)
        seq_var = x.var(dim=1, keepdim=True) + 1e-5

        normalized = (x - seq_mean) / seq_var.sqrt()

        # nn.Linear acts on the last axis: go to [Batch, Channel, Time] and back.
        projected = self.Linear(normalized.permute(0, 2, 1))
        y = projected.permute(0, 2, 1)

        # Undo the normalisation on the forecast.
        y = y * seq_var.sqrt() + seq_mean
        return y  # [Batch, Output length, Channel]
class Model(nn.Module):
    """SegRNN forecaster (complete version, including the ablation switches).

    The input window is cut into ``seg_len``-long segments per channel, each
    segment is embedded to ``d_model`` and the sequence of segments is encoded
    by a single-layer RNN/GRU/LSTM. The final hidden state is decoded either
    recurrently, one segment at a time (``dec_way='rmf'``), or in parallel from
    learned position (and optionally channel) embeddings (``dec_way='pmf'``).

    Config fields read: seq_len, pred_len, enc_in, d_model, dropout, rnn_type,
    dec_way, seg_len, channel_id, revin.

    Raises:
        AssertionError: if ``rnn_type`` or ``dec_way`` is not a supported value.
        ValueError: if ``seq_len`` or ``pred_len`` is not a multiple of
            ``seg_len``, or if ``d_model`` is odd while ``dec_way='pmf'`` and
            ``channel_id`` is set.
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # get parameters
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        self.d_model = configs.d_model
        self.dropout = configs.dropout

        self.rnn_type = configs.rnn_type
        self.dec_way = configs.dec_way
        self.seg_len = configs.seg_len
        self.channel_id = configs.channel_id
        self.revin = configs.revin

        assert self.rnn_type in ['rnn', 'gru', 'lstm']
        assert self.dec_way in ['rmf', 'pmf']

        # The forward pass reshapes with -1, so a window that is not a whole
        # number of segments would not necessarily fail: it could silently mix
        # values from different channels. Reject such configurations up front.
        if self.seq_len % self.seg_len != 0 or self.pred_len % self.seg_len != 0:
            raise ValueError(
                f"seq_len ({self.seq_len}) and pred_len ({self.pred_len}) must both be "
                f"multiples of seg_len ({self.seg_len})"
            )
        # With channel_id, position and channel embeddings of width d_model // 2
        # are concatenated and must add up to exactly d_model.
        if self.dec_way == 'pmf' and self.channel_id and self.d_model % 2 != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be even when dec_way='pmf' and channel_id is set"
            )

        self.seg_num_x = self.seq_len // self.seg_len
        self.seg_num_y = self.pred_len // self.seg_len

        # build model
        # NOTE: parameters are created in the same order as before
        # (embedding, rnn, pmf embeddings, prediction head) so that seeded
        # initialisation is unchanged.
        self.valueEmbedding = nn.Sequential(
            nn.Linear(self.seg_len, self.d_model),
            nn.ReLU()
        )

        rnn_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[self.rnn_type]
        self.rnn = rnn_cls(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True,
                           batch_first=True, bidirectional=False)

        if self.dec_way == "pmf":
            if self.channel_id:
                self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2))
                self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2))
            else:
                self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model))

        # Both decoding modes use the same head: hidden state -> one segment.
        self.predict = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model, self.seg_len)
        )

        if self.revin:
            self.revinLayer = RevIN(self.enc_in, affine=False, subtract_last=False)

    def forward(self, x):
        """Forecast ``pred_len`` steps from ``x`` of shape (batch, seq_len, enc_in)."""
        # b:batch_size c:channel_size s:seq_len / pred_len
        # d:d_model w:seg_len n:seg_num_x m:seg_num_y
        batch_size = x.size(0)

        # normalization and permute     b,s,c -> b,c,s
        if self.revin:
            x = self.revinLayer(x, 'norm').permute(0, 2, 1)
        else:
            # Without RevIN, subtract the (detached) last step instead.
            seq_last = x[:, -1:, :].detach()
            x = (x - seq_last).permute(0, 2, 1)  # b,c,s

        # segment and embedding    b,c,s -> bc,n,w -> bc,n,d
        x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len))

        # encoding
        if self.rnn_type == "lstm":
            _, (hn, cn) = self.rnn(x)
        else:
            _, hn = self.rnn(x)  # bc,n,d  1,bc,d

        # decoding
        if self.dec_way == "rmf":
            # Recurrent decoding: predict a segment, embed it, feed it back.
            y = []
            for i in range(self.seg_num_y):
                yy = self.predict(hn)  # 1,bc,w
                yy = yy.permute(1, 0, 2)  # bc,1,w
                y.append(yy)
                yy = self.valueEmbedding(yy)
                if self.rnn_type == "lstm":
                    _, (hn, cn) = self.rnn(yy, (hn, cn))
                else:
                    _, hn = self.rnn(yy, hn)
            # m x (bc,1,w) -> bc,m,w -> b,c,s
            y = torch.cat(y, dim=1).reshape(-1, self.enc_in, self.pred_len)
        elif self.dec_way == "pmf":
            if self.channel_id:
                # m,d//2 -> 1,m,d//2 -> c,m,d//2
                # c,d//2 -> c,1,d//2 -> c,m,d//2
                # c,m,d -> cm,1,d -> bcm, 1, d
                pos_emb = torch.cat([
                    self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
                    self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
                ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size, 1, 1)
            else:
                # m,d -> bcm,d -> bcm, 1, d
                pos_emb = self.pos_emb.repeat(batch_size * self.enc_in, 1).unsqueeze(1)

            # pos_emb: m,d -> bcm,d -> bcm,1,d
            # hn, cn: 1,bc,d -> 1,bc,md -> 1,bcm,d
            if self.rnn_type == "lstm":
                _, (hy, cy) = self.rnn(pos_emb,
                                       (hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model),
                                        cn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model)))
            else:
                _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model))
            # 1,bcm,d -> 1,bcm,w -> b,c,s
            y = self.predict(hy).view(-1, self.enc_in, self.pred_len)

        # permute and denorm
        if self.revin:
            y = self.revinLayer(y.permute(0, 2, 1), 'denorm')
        else:
            y = y.permute(0, 2, 1) + seq_last

        return y
class Model(nn.Module):
    """SparseTSF forecaster.

    Each channel's window is passed through a residual 1-D convolution, cut
    into ``seq_len // period_len`` periods, and the values at each position
    within the period are mapped across periods by one shared bias-free linear
    layer to ``pred_len // period_len`` future periods.

    Config fields read: seq_len, pred_len, enc_in, use_revin, period_len.
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        # get parameters
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        self.use_revin = configs.use_revin

        self.period_len = configs.period_len

        self.seg_num_x = self.seq_len // self.period_len
        self.seg_num_y = self.pred_len // self.period_len

        # The kernel must be odd (1 + 2 * padding) for the convolution to keep
        # the sequence length. The parentheses matter: without them
        # `1 + 2 * period_len // 2` evaluates to `1 + period_len`, which is even
        # for odd periods and shortens the output by one step. Even periods
        # give the same kernel size either way.
        half_period = self.period_len // 2
        self.conv1d = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1 + 2 * half_period,
                                stride=1, padding=half_period, padding_mode="zeros", bias=False)

        self.linear = nn.Linear(self.seg_num_x, self.seg_num_y, bias=False)

    def forward(self, x):
        """Forecast ``pred_len`` steps from ``x`` of shape (batch, seq_len, enc_in)."""
        batch_size = x.shape[0]

        # normalization and permute     b,s,c -> b,c,s
        if self.use_revin:
            # Only the per-channel mean is removed (and added back at the end).
            seq_mean = torch.mean(x, dim=1).unsqueeze(1)
            x = (x - seq_mean).permute(0, 2, 1)
        else:
            x = x.permute(0, 2, 1)

        # Residual convolution applied to every channel independently:
        # b,c,s -> bc,1,s -> conv -> b,c,s, then add the input back.
        x = self.conv1d(x.reshape(-1, 1, self.seq_len)).reshape(-1, self.enc_in, self.seq_len) + x

        # b,c,s -> bc,n,w -> bc,w,n
        x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)

        y = self.linear(x)  # bc,w,m

        # bc,w,m -> bc,m,w -> b,c,s
        y = y.permute(0, 2, 1).reshape(batch_size, self.enc_in, self.pred_len)

        if self.use_revin:
            y = y.permute(0, 2, 1) + seq_mean
        else:
            y = y.permute(0, 2, 1)

        return y
self.output_attention = configs.output_attention 18 | 19 | # Embedding 20 | if configs.embed_type == 0: 21 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 22 | configs.dropout) 23 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 24 | configs.dropout) 25 | elif configs.embed_type == 1: 26 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 27 | configs.dropout) 28 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 29 | configs.dropout) 30 | elif configs.embed_type == 2: 31 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 32 | configs.dropout) 33 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 34 | configs.dropout) 35 | 36 | elif configs.embed_type == 3: 37 | self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, 38 | configs.dropout) 39 | self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, 40 | configs.dropout) 41 | elif configs.embed_type == 4: 42 | self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, 43 | configs.dropout) 44 | self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, 45 | configs.dropout) 46 | # Encoder 47 | self.encoder = Encoder( 48 | [ 49 | EncoderLayer( 50 | AttentionLayer( 51 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 52 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 53 | configs.d_model, 54 | configs.d_ff, 55 | dropout=configs.dropout, 56 | activation=configs.activation 57 | ) for l in range(configs.e_layers) 58 | ], 59 | 
"""Command-line entry point: parse options, seed the RNGs, then train and/or test."""
import argparse
import os
import random

import numpy as np
import torch


def str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    ``True``: every non-empty string would enable the flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling
            of true or false.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_parser():
    """Return the argument parser holding every experiment option."""
    parser = argparse.ArgumentParser(description='Model family for Time Series Forecasting')

    # random seed
    parser.add_argument('--random_seed', type=int, default=2024, help='random seed')

    # basic config
    # NOTE: these options are required, so their defaults are never used.
    parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
    parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
    parser.add_argument('--model', type=str, required=True, default='Autoformer',
                        help='model name, options: [Autoformer, Informer, Transformer]')

    # data loader
    parser.add_argument('--data', type=str, required=True, default='ETTh1', help='dataset type')
    parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=0, help='start token length')  # fixed
    parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')

    # CycleNet
    parser.add_argument('--cycle', type=int, default=24, help='cycle length')
    parser.add_argument('--model_type', type=str, default='mlp', help='model type, options: [linear, mlp]')
    parser.add_argument('--use_revin', type=int, default=1, help='1: use revin or 0: no revin')

    # PatchTST (`--individual` below is the only definition of that option)
    parser.add_argument('--fc_dropout', type=float, default=0.05, help='fully connected dropout')
    parser.add_argument('--head_dropout', type=float, default=0.0, help='head dropout')
    parser.add_argument('--patch_len', type=int, default=16, help='patch length')
    parser.add_argument('--stride', type=int, default=8, help='stride')
    parser.add_argument('--padding_patch', default='end', help='None: None; end: padding on the end')
    parser.add_argument('--revin', type=int, default=0, help='RevIN; True 1 False 0')
    parser.add_argument('--affine', type=int, default=0, help='RevIN-affine; True 1 False 0')
    parser.add_argument('--subtract_last', type=int, default=0, help='0: subtract mean; 1: subtract last')
    parser.add_argument('--decomposition', type=int, default=0, help='decomposition; True 1 False 0')
    parser.add_argument('--kernel_size', type=int, default=25, help='decomposition-kernel')
    parser.add_argument('--individual', type=int, default=0, help='individual head; True 1 False 0')

    # SegRNN
    parser.add_argument('--rnn_type', default='gru', help='rnn_type')
    parser.add_argument('--dec_way', default='pmf', help='decode way')
    parser.add_argument('--seg_len', type=int, default=48, help='segment length')
    parser.add_argument('--channel_id', type=int, default=1, help='Whether to enable channel position encoding')

    # SparseTSF
    parser.add_argument('--period_len', type=int, default=24, help='period_len')

    # Formers
    parser.add_argument('--embed_type', type=int, default=0, help='0: default 1: value embedding + temporal embedding + positional embedding 2: value embedding + temporal embedding 3: value embedding + positional embedding 4: value embedding')
    # DLinear with --individual uses enc_in as the number of channels
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=7, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    parser.add_argument('--distil', action='store_false',
                        help='whether to use distilling in encoder, using this argument means not using distilling',
                        default=True)
    parser.add_argument('--dropout', type=float, default=0, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
    parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')

    # optimization
    parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
    parser.add_argument('--itr', type=int, default=1, help='experiments times')
    parser.add_argument('--train_epochs', type=int, default=30, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size of train input data')
    parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
    parser.add_argument('--des', type=str, default='test', help='exp description')
    parser.add_argument('--loss', type=str, default='mse', help='loss function')
    parser.add_argument('--lradj', type=str, default='type3', help='adjust learning rate')
    parser.add_argument('--pct_start', type=float, default=0.3, help='pct_start')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

    # GPU
    # str2bool instead of bool, so that `--use_gpu False` really disables it.
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='use gpu')
    parser.add_argument('--gpu', type=int, default=0, help='gpu')
    parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
    parser.add_argument('--devices', type=str, default='0,1', help='device ids of multile gpus')
    parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')

    return parser


def format_setting(args, seed):
    """Build the experiment identifier passed to train/test/predict."""
    return '{}_{}_{}_ft{}_sl{}_pl{}_cycle{}_{}_seed{}'.format(
        args.model_id,
        args.model,
        args.data,
        args.features,
        args.seq_len,
        args.pred_len,
        args.cycle,
        args.model_type,
        seed)


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and run the experiment."""
    # Imported here so that merely importing this module has no dependency on
    # the experiment package and triggers no work.
    from exp.exp_main import Exp_Main

    args = build_parser().parse_args(argv)

    # random seed
    fix_seed = args.random_seed
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)

    # A GPU is used only when it was requested and one is actually available.
    args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)

    if args.use_gpu and args.use_multi_gpu:
        args.devices = args.devices.replace(' ', '')
        args.device_ids = [int(id_) for id_ in args.devices.split(',')]
        args.gpu = args.device_ids[0]

    print('Args in experiment:')
    print(args)

    # The identifier does not depend on the iteration index, so it is built once.
    setting = format_setting(args, fix_seed)

    if args.is_training:
        for _ in range(args.itr):
            exp = Exp_Main(args)  # set experiments
            print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
            exp.train(setting)

            print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
            exp.test(setting)

            if args.do_predict:
                print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
                exp.predict(setting, True)

            torch.cuda.empty_cache()
    else:
        exp = Exp_Main(args)  # set experiments
        print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        exp.test(setting, test=1)
        torch.cuda.empty_cache()


# The guard keeps the experiment from re-running when this module is imported,
# e.g. by data loader worker processes that are started with "spawn".
if __name__ == '__main__':
    main()
scripts/CycleNet/Linear-Input-336/etth2.sh; 14 | sh scripts/CycleNet/Linear-Input-336/ettm1.sh; 15 | sh scripts/CycleNet/Linear-Input-336/ettm2.sh; 16 | sh scripts/CycleNet/Linear-Input-336/weather.sh; 17 | sh scripts/CycleNet/Linear-Input-336/electricity.sh; 18 | sh scripts/CycleNet/Linear-Input-336/traffic.sh; 19 | sh scripts/CycleNet/Linear-Input-336/solar.sh; 20 | 21 | sh scripts/CycleNet/Linear-Input-720/etth1.sh; 22 | sh scripts/CycleNet/Linear-Input-720/etth2.sh; 23 | sh scripts/CycleNet/Linear-Input-720/ettm1.sh; 24 | sh scripts/CycleNet/Linear-Input-720/ettm2.sh; 25 | sh scripts/CycleNet/Linear-Input-720/weather.sh; 26 | sh scripts/CycleNet/Linear-Input-720/electricity.sh; 27 | sh scripts/CycleNet/Linear-Input-720/traffic.sh; 28 | sh scripts/CycleNet/Linear-Input-720/solar.sh; 29 | 30 | sh scripts/CycleNet/MLP-Input-96/etth1.sh; 31 | sh scripts/CycleNet/MLP-Input-96/etth2.sh; 32 | sh scripts/CycleNet/MLP-Input-96/ettm1.sh; 33 | sh scripts/CycleNet/MLP-Input-96/ettm2.sh; 34 | sh scripts/CycleNet/MLP-Input-96/weather.sh; 35 | sh scripts/CycleNet/MLP-Input-96/electricity.sh; 36 | sh scripts/CycleNet/MLP-Input-96/traffic.sh; 37 | sh scripts/CycleNet/MLP-Input-96/solar.sh; 38 | 39 | sh scripts/CycleNet/MLP-Input-336/etth1.sh; 40 | sh scripts/CycleNet/MLP-Input-336/etth2.sh; 41 | sh scripts/CycleNet/MLP-Input-336/ettm1.sh; 42 | sh scripts/CycleNet/MLP-Input-336/ettm2.sh; 43 | sh scripts/CycleNet/MLP-Input-336/weather.sh; 44 | sh scripts/CycleNet/MLP-Input-336/electricity.sh; 45 | sh scripts/CycleNet/MLP-Input-336/traffic.sh; 46 | sh scripts/CycleNet/MLP-Input-336/solar.sh; 47 | 48 | sh scripts/CycleNet/MLP-Input-720/etth1.sh; 49 | sh scripts/CycleNet/MLP-Input-720/etth2.sh; 50 | sh scripts/CycleNet/MLP-Input-720/ettm1.sh; 51 | sh scripts/CycleNet/MLP-Input-720/ettm2.sh; 52 | sh scripts/CycleNet/MLP-Input-720/weather.sh; 53 | sh scripts/CycleNet/MLP-Input-720/electricity.sh; 54 | sh scripts/CycleNet/MLP-Input-720/traffic.sh; 55 | sh 
scripts/CycleNet/MLP-Input-720/solar.sh; -------------------------------------------------------------------------------- /run_pems.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sh scripts/CycleNet/PEMS/pems03.sh; 4 | sh scripts/CycleNet/PEMS/pems04.sh; 5 | sh scripts/CycleNet/PEMS/pems07.sh; 6 | sh scripts/CycleNet/PEMS/pems08.sh; 7 | 8 | -------------------------------------------------------------------------------- /run_std.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sh scripts/CycleNet/STD/CycleNet.sh; 4 | sh scripts/CycleNet/STD/LDLinear.sh; 5 | sh scripts/CycleNet/STD/DLinear.sh; 6 | sh scripts/CycleNet/STD/SparseTSF.sh; 7 | sh scripts/CycleNet/STD/Linear.sh; 8 | 9 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-336/electricity.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=electricity.csv 5 | model_id_name=Electricity 6 | data_name=custom 7 | 8 | 9 | model_type='linear' 10 | seq_len=336 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 321 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-336/etth1.sh: 
-------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh1.csv 5 | model_id_name=ETTh1 6 | data_name=ETTh1 7 | 8 | 9 | model_type='linear' 10 | seq_len=336 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 7 \ 26 | --cycle 24 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-336/etth2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh2.csv 5 | model_id_name=ETTh2 6 | data_name=ETTh2 7 | 8 | model_type='linear' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 24 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | 34 | -------------------------------------------------------------------------------- 
/scripts/CycleNet/Linear-Input-336/ettm1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm1.csv 5 | model_id_name=ETTm1 6 | data_name=ETTm1 7 | 8 | model_type='linear' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-336/ettm2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm2.csv 5 | model_id_name=ETTm2 6 | data_name=ETTm2 7 | 8 | model_type='linear' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | 
-------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-336/solar.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=solar_AL.txt 5 | model_id_name=Solar 6 | data_name=Solar 7 | 8 | model_type='linear' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 137 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --use_revin 0 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-336/traffic.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=traffic.csv 5 | model_id_name=traffic 6 | data_name=custom 7 | 8 | 9 | model_type='linear' 10 | seq_len=336 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 862 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.005 
--random_seed $random_seed 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-336/weather.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=weather.csv 5 | model_id_name=weather 6 | data_name=custom 7 | 8 | model_type='linear' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 21 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/electricity.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=electricity.csv 5 | model_id_name=Electricity 6 | data_name=custom 7 | 8 | 9 | model_type='linear' 10 | seq_len=720 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 321 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 
--batch_size 64 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/etth1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh1.csv 5 | model_id_name=ETTh1 6 | data_name=ETTh1 7 | 8 | 9 | model_type='linear' 10 | seq_len=720 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 7 \ 26 | --cycle 24 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/etth2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh2.csv 5 | model_id_name=ETTh2 6 | data_name=ETTh2 7 | 8 | model_type='linear' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 24 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | 
--patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/ettm1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm1.csv 5 | model_id_name=ETTm1 6 | data_name=ETTm1 7 | 8 | model_type='linear' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/ettm2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm2.csv 5 | model_id_name=ETTm2 6 | data_name=ETTm2 7 | 8 | model_type='linear' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | 
--train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/solar.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=solar_AL.txt 5 | model_id_name=Solar 6 | data_name=Solar 7 | 8 | model_type='linear' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 137 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --use_revin 0 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/traffic.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=traffic.csv 5 | model_id_name=traffic 6 | data_name=custom 7 | 8 | 9 | model_type='linear' 10 | seq_len=720 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 862 \ 26 | 
--cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-720/weather.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=weather.csv 5 | model_id_name=weather 6 | data_name=custom 7 | 8 | model_type='linear' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 21 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/electricity.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=electricity.csv 5 | model_id_name=Electricity 6 | data_name=custom 7 | 8 | 9 | model_type='linear' 10 | seq_len=96 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len 
$pred_len \ 25 | --enc_in 321 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/etth1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh1.csv 5 | model_id_name=ETTh1 6 | data_name=ETTh1 7 | 8 | 9 | model_type='linear' 10 | seq_len=96 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 7 \ 26 | --cycle 24 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/etth2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh2.csv 5 | model_id_name=ETTh2 6 | data_name=ETTh2 7 | 8 | model_type='linear' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len 
$seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 24 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 30 | done 31 | done 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/ettm1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm1.csv 5 | model_id_name=ETTm1 6 | data_name=ETTm1 7 | 8 | model_type='linear' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/ettm2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm2.csv 5 | model_id_name=ETTm2 6 | data_name=ETTm2 7 | 8 | model_type='linear' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | 
--seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/solar.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=solar_AL.txt 5 | model_id_name=Solar 6 | data_name=Solar 7 | 8 | model_type='linear' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 137 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --use_revin 0 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 31 | done 32 | done 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/traffic.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=traffic.csv 5 | model_id_name=traffic 6 | data_name=custom 7 | 8 | 9 | model_type='linear' 10 | seq_len=96 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | 
--data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 862 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/Linear-Input-96/weather.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=weather.csv 5 | model_id_name=weather 6 | data_name=custom 7 | 8 | model_type='linear' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 21 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/electricity.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=electricity.csv 5 | model_id_name=Electricity 6 | data_name=custom 7 | 8 | 9 | model_type='mlp' 10 | seq_len=336 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 
| --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 321 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/etth1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh1.csv 5 | model_id_name=ETTh1 6 | data_name=ETTh1 7 | 8 | 9 | model_type='mlp' 10 | seq_len=336 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 7 \ 26 | --cycle 24 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/etth2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh2.csv 5 | model_id_name=ETTh2 6 | data_name=ETTh2 7 | 8 | model_type='mlp' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id 
$model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 24 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed 30 | done 31 | done 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/ettm1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm1.csv 5 | model_id_name=ETTm1 6 | data_name=ETTm1 7 | 8 | model_type='mlp' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/ettm2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm2.csv 5 | model_id_name=ETTm2 6 | data_name=ETTm2 7 | 8 | model_type='mlp' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | 
--model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/solar.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=solar_AL.txt 5 | model_id_name=Solar 6 | data_name=Solar 7 | 8 | model_type='mlp' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 137 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --use_revin 0 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed 31 | done 32 | done 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/traffic.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=traffic.csv 5 | model_id_name=traffic 6 | data_name=custom 7 | 8 | 9 | model_type='mlp' 10 | seq_len=336 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | 
--data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 862 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-336/weather.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=weather.csv 5 | model_id_name=weather 6 | data_name=custom 7 | 8 | model_type='mlp' 9 | seq_len=336 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 21 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.002 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/electricity.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=electricity.csv 5 | model_id_name=Electricity 6 | data_name=custom 7 | 8 | 9 | model_type='mlp' 10 | seq_len=720 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | 
--root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 321 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.001 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/etth1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh1.csv 5 | model_id_name=ETTh1 6 | data_name=ETTh1 7 | 8 | 9 | model_type='mlp' 10 | seq_len=720 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 7 \ 26 | --cycle 24 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/etth2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh2.csv 5 | model_id_name=ETTh2 6 | data_name=ETTh2 7 | 8 | model_type='mlp' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | 
--is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 24 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/ettm1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm1.csv 5 | model_id_name=ETTm1 6 | data_name=ETTm1 7 | 8 | model_type='mlp' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/ettm2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm2.csv 5 | model_id_name=ETTm2 6 | data_name=ETTm2 7 | 8 | model_type='mlp' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 
15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/solar.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=solar_AL.txt 5 | model_id_name=Solar 6 | data_name=Solar 7 | 8 | model_type='mlp' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 137 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --use_revin 0 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.0002 --random_seed $random_seed 31 | done 32 | done 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/traffic.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=traffic.csv 5 | model_id_name=traffic 6 | data_name=custom 7 | 8 | 9 | model_type='mlp' 10 | seq_len=720 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 
2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 862 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.0005 --random_seed $random_seed 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-720/weather.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=weather.csv 5 | model_id_name=weather 6 | data_name=custom 7 | 8 | model_type='mlp' 9 | seq_len=720 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 21 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.0005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/electricity.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=electricity.csv 5 | model_id_name=Electricity 6 | data_name=custom 7 | 8 | 9 | model_type='mlp' 10 | seq_len=96 11 | for pred_len in 96 192 336 720 12 | do 13 | for 
random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 321 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/etth1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh1.csv 5 | model_id_name=ETTh1 6 | data_name=ETTh1 7 | 8 | 9 | model_type='mlp' 10 | seq_len=96 11 | for pred_len in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 7 \ 26 | --cycle 24 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 31 | done 32 | done 33 | 34 | 35 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/etth2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTh2.csv 5 | model_id_name=ETTh2 6 | data_name=ETTh2 7 | 8 | model_type='mlp' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 
12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 24 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/ettm1.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm1.csv 5 | model_id_name=ETTm1 6 | data_name=ETTm1 7 | 8 | model_type='mlp' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/ettm2.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=ETTm2.csv 5 | model_id_name=ETTm2 6 | data_name=ETTm2 7 | 8 | model_type='mlp' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 
| for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 7 \ 25 | --cycle 96 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/solar.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=solar_AL.txt 5 | model_id_name=Solar 6 | data_name=Solar 7 | 8 | model_type='mlp' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 137 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --use_revin 0 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 31 | done 32 | done 33 | 34 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/traffic.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=traffic.csv 5 | model_id_name=traffic 6 | data_name=custom 7 | 8 | 9 | model_type='mlp' 10 | seq_len=96 11 | for pred_len 
in 96 192 336 720 12 | do 13 | for random_seed in 2024 2025 2026 2027 2028 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 862 \ 26 | --cycle 168 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --itr 1 --batch_size 64 --learning_rate 0.002 --random_seed $random_seed 31 | done 32 | done 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/MLP-Input-96/weather.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=weather.csv 5 | model_id_name=weather 6 | data_name=custom 7 | 8 | model_type='mlp' 9 | seq_len=96 10 | for pred_len in 96 192 336 720 11 | do 12 | for random_seed in 2024 2025 2026 2027 2028 13 | do 14 | python -u run.py \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 19 | --model $model_name \ 20 | --data $data_name \ 21 | --features M \ 22 | --seq_len $seq_len \ 23 | --pred_len $pred_len \ 24 | --enc_in 21 \ 25 | --cycle 144 \ 26 | --model_type $model_type \ 27 | --train_epochs 30 \ 28 | --patience 5 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 30 | done 31 | done 32 | 33 | -------------------------------------------------------------------------------- /scripts/CycleNet/PEMS/pems03.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=PEMS03.npz 5 | model_id_name=PEMS03 6 | data_name=PEMS 7 | 8 | 9 | model_type='mlp' 10 | seq_len=96 11 | for pred_len in 
12 24 48 96 12 | do 13 | for random_seed in 2024 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 358 \ 26 | --cycle 288 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --use_revin 0 \ 31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 32 | done 33 | done 34 | 35 | 36 | #model_type='linear' 37 | #seq_len=96 38 | #for pred_len in 12 24 48 96 39 | #do 40 | #for random_seed in 2024 41 | #do 42 | # python -u run.py \ 43 | # --is_training 1 \ 44 | # --root_path $root_path_name \ 45 | # --data_path $data_path_name \ 46 | # --model_id $model_id_name'_'$seq_len'_'$pred_len \ 47 | # --model $model_name \ 48 | # --data $data_name \ 49 | # --features M \ 50 | # --seq_len $seq_len \ 51 | # --pred_len $pred_len \ 52 | # --enc_in 358 \ 53 | # --cycle 288 \ 54 | # --model_type $model_type \ 55 | # --train_epochs 30 \ 56 | # --patience 5 \ 57 | # --use_revin 0 \ 58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 59 | #done 60 | #done 61 | 62 | 63 | -------------------------------------------------------------------------------- /scripts/CycleNet/PEMS/pems04.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=PEMS04.npz 5 | model_id_name=PEMS04 6 | data_name=PEMS 7 | 8 | 9 | model_type='mlp' 10 | seq_len=96 11 | for pred_len in 12 24 48 96 12 | do 13 | for random_seed in 2024 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | 
--features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 307 \ 26 | --cycle 288 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --use_revin 0 \ 31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 32 | done 33 | done 34 | 35 | 36 | #model_type='linear' 37 | #seq_len=96 38 | #for pred_len in 12 24 48 96 39 | #do 40 | #for random_seed in 2024 41 | #do 42 | # python -u run.py \ 43 | # --is_training 1 \ 44 | # --root_path $root_path_name \ 45 | # --data_path $data_path_name \ 46 | # --model_id $model_id_name'_'$seq_len'_'$pred_len \ 47 | # --model $model_name \ 48 | # --data $data_name \ 49 | # --features M \ 50 | # --seq_len $seq_len \ 51 | # --pred_len $pred_len \ 52 | # --enc_in 307 \ 53 | # --cycle 288 \ 54 | # --model_type $model_type \ 55 | # --train_epochs 30 \ 56 | # --patience 5 \ 57 | # --use_revin 0 \ 58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 59 | #done 60 | #done 61 | 62 | 63 | -------------------------------------------------------------------------------- /scripts/CycleNet/PEMS/pems07.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=PEMS07.npz 5 | model_id_name=PEMS07 6 | data_name=PEMS 7 | 8 | 9 | model_type='mlp' 10 | seq_len=96 11 | for pred_len in 12 24 48 96 12 | do 13 | for random_seed in 2024 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 883 \ 26 | --cycle 288 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --patience 5 \ 30 | --use_revin 0 \ 31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 32 | done 
33 | done 34 | 35 | 36 | #model_type='linear' 37 | #seq_len=96 38 | #for pred_len in 12 24 48 96 39 | #do 40 | #for random_seed in 2024 41 | #do 42 | # python -u run.py \ 43 | # --is_training 1 \ 44 | # --root_path $root_path_name \ 45 | # --data_path $data_path_name \ 46 | # --model_id $model_id_name'_'$seq_len'_'$pred_len \ 47 | # --model $model_name \ 48 | # --data $data_name \ 49 | # --features M \ 50 | # --seq_len $seq_len \ 51 | # --pred_len $pred_len \ 52 | # --enc_in 883 \ 53 | # --cycle 288 \ 54 | # --model_type $model_type \ 55 | # --train_epochs 30 \ 56 | # --use_revin 0 \ 57 | # --patience 5 \ 58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 59 | #done 60 | #done 61 | 62 | 63 | -------------------------------------------------------------------------------- /scripts/CycleNet/PEMS/pems08.sh: -------------------------------------------------------------------------------- 1 | model_name=CycleNet 2 | 3 | root_path_name=./dataset/ 4 | data_path_name=PEMS08.npz 5 | model_id_name=PEMS08 6 | data_name=PEMS 7 | 8 | 9 | model_type='mlp' 10 | seq_len=96 11 | for pred_len in 12 24 48 96 12 | do 13 | for random_seed in 2024 14 | do 15 | python -u run.py \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 20 | --model $model_name \ 21 | --data $data_name \ 22 | --features M \ 23 | --seq_len $seq_len \ 24 | --pred_len $pred_len \ 25 | --enc_in 170 \ 26 | --cycle 288 \ 27 | --model_type $model_type \ 28 | --train_epochs 30 \ 29 | --use_revin 0 \ 30 | --patience 5 \ 31 | --itr 1 --batch_size 64 --learning_rate 0.005 --random_seed $random_seed 32 | done 33 | done 34 | 35 | 36 | #model_type='linear' 37 | #seq_len=96 38 | #for pred_len in 12 24 48 96 39 | #do 40 | #for random_seed in 2024 41 | #do 42 | # python -u run.py \ 43 | # --is_training 1 \ 44 | # --root_path $root_path_name \ 45 | # --data_path $data_path_name \ 46 | # --model_id 
$model_id_name'_'$seq_len'_'$pred_len \ 47 | # --model $model_name \ 48 | # --data $data_name \ 49 | # --features M \ 50 | # --seq_len $seq_len \ 51 | # --pred_len $pred_len \ 52 | # --enc_in 170 \ 53 | # --cycle 288 \ 54 | # --model_type $model_type \ 55 | # --train_epochs 30 \ 56 | # --patience 5 \ 57 | # --use_revin 0 \ 58 | # --itr 1 --batch_size 64 --learning_rate 0.01 --random_seed $random_seed 59 | #done 60 | #done 61 | 62 | 63 | -------------------------------------------------------------------------------- /scripts/CycleNet/STD/CycleNet.sh: -------------------------------------------------------------------------------- 1 | root_path_name=./dataset/ 2 | data_path_name=ETTh1.csv 3 | model_id_name=ETTh1 4 | data_name=ETTh1 5 | 6 | model_name=CycleNet 7 | model_type='linear' 8 | seq_len=336 9 | for pred_len in 96 192 336 720 10 | do 11 | for random_seed in 2024 12 | do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path $root_path_name \ 16 | --data_path $data_path_name \ 17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 18 | --model $model_name \ 19 | --data $data_name \ 20 | --features M \ 21 | --seq_len $seq_len \ 22 | --pred_len $pred_len \ 23 | --model_type $model_type \ 24 | --enc_in 7 \ 25 | --cycle 24 \ 26 | --train_epochs 30 \ 27 | --patience 5 \ 28 | --use_revin 0 \ 29 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 30 | done 31 | done 32 | 33 | root_path_name=./dataset/ 34 | data_path_name=ETTh2.csv 35 | model_id_name=ETTh2 36 | data_name=ETTh2 37 | 38 | model_name=CycleNet 39 | model_type='linear' 40 | seq_len=336 41 | for pred_len in 96 192 336 720 42 | do 43 | for random_seed in 2024 44 | do 45 | python -u run.py \ 46 | --is_training 1 \ 47 | --root_path $root_path_name \ 48 | --data_path $data_path_name \ 49 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 50 | --model $model_name \ 51 | --data $data_name \ 52 | --features M \ 53 | --seq_len $seq_len \ 54 | --pred_len $pred_len \ 55 | 
--model_type $model_type \ 56 | --enc_in 7 \ 57 | --cycle 24 \ 58 | --train_epochs 30 \ 59 | --patience 5 \ 60 | --use_revin 0 \ 61 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 62 | done 63 | done 64 | 65 | root_path_name=./dataset/ 66 | data_path_name=ETTm1.csv 67 | model_id_name=ETTm1 68 | data_name=ETTm1 69 | 70 | model_name=CycleNet 71 | model_type='linear' 72 | seq_len=336 73 | for pred_len in 96 192 336 720 74 | do 75 | for random_seed in 2024 76 | do 77 | python -u run.py \ 78 | --is_training 1 \ 79 | --root_path $root_path_name \ 80 | --data_path $data_path_name \ 81 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 82 | --model $model_name \ 83 | --data $data_name \ 84 | --features M \ 85 | --seq_len $seq_len \ 86 | --pred_len $pred_len \ 87 | --model_type $model_type \ 88 | --enc_in 7 \ 89 | --cycle 96 \ 90 | --train_epochs 30 \ 91 | --patience 5 \ 92 | --use_revin 0 \ 93 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 94 | done 95 | done 96 | 97 | root_path_name=./dataset/ 98 | data_path_name=ETTm2.csv 99 | model_id_name=ETTm2 100 | data_name=ETTm2 101 | 102 | model_name=CycleNet 103 | model_type='linear' 104 | seq_len=336 105 | for pred_len in 96 192 336 720 106 | do 107 | for random_seed in 2024 108 | do 109 | python -u run.py \ 110 | --is_training 1 \ 111 | --root_path $root_path_name \ 112 | --data_path $data_path_name \ 113 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 114 | --model $model_name \ 115 | --data $data_name \ 116 | --features M \ 117 | --seq_len $seq_len \ 118 | --pred_len $pred_len \ 119 | --model_type $model_type \ 120 | --enc_in 7 \ 121 | --cycle 96 \ 122 | --train_epochs 30 \ 123 | --patience 5 \ 124 | --use_revin 0 \ 125 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 126 | done 127 | done 128 | 129 | root_path_name=./dataset/ 130 | data_path_name=weather.csv 131 | model_id_name=weather 132 | data_name=custom 133 | 134 | 
model_name=CycleNet 135 | model_type='linear' 136 | seq_len=336 137 | for pred_len in 96 192 336 720 138 | do 139 | for random_seed in 2024 140 | do 141 | python -u run.py \ 142 | --is_training 1 \ 143 | --root_path $root_path_name \ 144 | --data_path $data_path_name \ 145 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 146 | --model $model_name \ 147 | --data $data_name \ 148 | --features M \ 149 | --seq_len $seq_len \ 150 | --pred_len $pred_len \ 151 | --model_type $model_type \ 152 | --enc_in 21 \ 153 | --cycle 144 \ 154 | --train_epochs 30 \ 155 | --patience 5 \ 156 | --use_revin 0 \ 157 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed 158 | done 159 | done 160 | 161 | root_path_name=./dataset/ 162 | data_path_name=solar_AL.txt 163 | model_id_name=Solar 164 | data_name=Solar 165 | 166 | 167 | model_name=CycleNet 168 | model_type='linear' 169 | seq_len=336 170 | for pred_len in 96 192 336 720 171 | do 172 | for random_seed in 2024 173 | do 174 | python -u run.py \ 175 | --is_training 1 \ 176 | --root_path $root_path_name \ 177 | --data_path $data_path_name \ 178 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 179 | --model $model_name \ 180 | --data $data_name \ 181 | --features M \ 182 | --seq_len $seq_len \ 183 | --pred_len $pred_len \ 184 | --model_type $model_type \ 185 | --enc_in 137 \ 186 | --cycle 144 \ 187 | --train_epochs 30 \ 188 | --patience 5 \ 189 | --use_revin 0 \ 190 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 191 | done 192 | done 193 | 194 | 195 | root_path_name=./dataset/ 196 | data_path_name=electricity.csv 197 | model_id_name=Electricity 198 | data_name=custom 199 | 200 | model_name=CycleNet 201 | model_type='linear' 202 | seq_len=336 203 | for pred_len in 96 192 336 720 204 | do 205 | for random_seed in 2024 206 | do 207 | python -u run.py \ 208 | --is_training 1 \ 209 | --root_path $root_path_name \ 210 | --data_path $data_path_name \ 211 | --model_id 
$model_id_name'_'$seq_len'_'$pred_len \ 212 | --model $model_name \ 213 | --data $data_name \ 214 | --features M \ 215 | --seq_len $seq_len \ 216 | --pred_len $pred_len \ 217 | --model_type $model_type \ 218 | --enc_in 321 \ 219 | --cycle 168 \ 220 | --train_epochs 30 \ 221 | --patience 5 \ 222 | --use_revin 0 \ 223 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 224 | done 225 | done 226 | 227 | 228 | 229 | root_path_name=./dataset/ 230 | data_path_name=traffic.csv 231 | model_id_name=traffic 232 | data_name=custom 233 | 234 | model_name=CycleNet 235 | model_type='linear' 236 | seq_len=336 237 | for pred_len in 96 192 336 720 238 | do 239 | for random_seed in 2024 240 | do 241 | python -u run.py \ 242 | --is_training 1 \ 243 | --root_path $root_path_name \ 244 | --data_path $data_path_name \ 245 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 246 | --model $model_name \ 247 | --data $data_name \ 248 | --features M \ 249 | --seq_len $seq_len \ 250 | --pred_len $pred_len \ 251 | --model_type $model_type \ 252 | --enc_in 862 \ 253 | --cycle 168 \ 254 | --train_epochs 30 \ 255 | --patience 5 \ 256 | --use_revin 0 \ 257 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 258 | done 259 | done 260 | -------------------------------------------------------------------------------- /scripts/CycleNet/STD/DLinear.sh: -------------------------------------------------------------------------------- 1 | 2 | root_path_name=./dataset/ 3 | data_path_name=ETTh1.csv 4 | model_id_name=ETTh1 5 | data_name=ETTh1 6 | 7 | model_name=DLinear 8 | seq_len=336 9 | for pred_len in 96 192 336 720 10 | do 11 | for random_seed in 2024 12 | do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path $root_path_name \ 16 | --data_path $data_path_name \ 17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 18 | --model $model_name \ 19 | --data $data_name \ 20 | --features M \ 21 | --seq_len $seq_len \ 22 | --pred_len $pred_len \ 23 | 
--enc_in 7 \ 24 | --train_epochs 30 \ 25 | --patience 5 \ 26 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 27 | done 28 | done 29 | 30 | root_path_name=./dataset/ 31 | data_path_name=ETTh2.csv 32 | model_id_name=ETTh2 33 | data_name=ETTh2 34 | 35 | model_name=DLinear 36 | seq_len=336 37 | for pred_len in 96 192 336 720 38 | do 39 | for random_seed in 2024 40 | do 41 | python -u run.py \ 42 | --is_training 1 \ 43 | --root_path $root_path_name \ 44 | --data_path $data_path_name \ 45 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 46 | --model $model_name \ 47 | --data $data_name \ 48 | --features M \ 49 | --seq_len $seq_len \ 50 | --pred_len $pred_len \ 51 | --enc_in 7 \ 52 | --train_epochs 30 \ 53 | --patience 5 \ 54 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 55 | done 56 | done 57 | 58 | root_path_name=./dataset/ 59 | data_path_name=ETTm1.csv 60 | model_id_name=ETTm1 61 | data_name=ETTm1 62 | 63 | model_name=DLinear 64 | seq_len=336 65 | for pred_len in 96 192 336 720 66 | do 67 | for random_seed in 2024 68 | do 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path $root_path_name \ 72 | --data_path $data_path_name \ 73 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 74 | --model $model_name \ 75 | --data $data_name \ 76 | --features M \ 77 | --seq_len $seq_len \ 78 | --pred_len $pred_len \ 79 | --enc_in 7 \ 80 | --train_epochs 30 \ 81 | --patience 5 \ 82 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 83 | done 84 | done 85 | 86 | 87 | root_path_name=./dataset/ 88 | data_path_name=ETTm2.csv 89 | model_id_name=ETTm2 90 | data_name=ETTm2 91 | 92 | model_name=DLinear 93 | seq_len=336 94 | for pred_len in 96 192 336 720 95 | do 96 | for random_seed in 2024 97 | do 98 | python -u run.py \ 99 | --is_training 1 \ 100 | --root_path $root_path_name \ 101 | --data_path $data_path_name \ 102 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 103 | --model $model_name 
\ 104 | --data $data_name \ 105 | --features M \ 106 | --seq_len $seq_len \ 107 | --pred_len $pred_len \ 108 | --enc_in 7 \ 109 | --train_epochs 30 \ 110 | --patience 5 \ 111 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 112 | done 113 | done 114 | 115 | 116 | root_path_name=./dataset/ 117 | data_path_name=weather.csv 118 | model_id_name=weather 119 | data_name=custom 120 | 121 | model_name=DLinear 122 | seq_len=336 123 | for pred_len in 96 192 336 720 124 | do 125 | for random_seed in 2024 126 | do 127 | python -u run.py \ 128 | --is_training 1 \ 129 | --root_path $root_path_name \ 130 | --data_path $data_path_name \ 131 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 132 | --model $model_name \ 133 | --data $data_name \ 134 | --features M \ 135 | --seq_len $seq_len \ 136 | --pred_len $pred_len \ 137 | --enc_in 21 \ 138 | --train_epochs 30 \ 139 | --patience 5 \ 140 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed 141 | done 142 | done 143 | 144 | 145 | root_path_name=./dataset/ 146 | data_path_name=solar_AL.txt 147 | model_id_name=Solar 148 | data_name=Solar 149 | 150 | model_name=DLinear 151 | seq_len=336 152 | for pred_len in 96 192 336 720 153 | do 154 | for random_seed in 2024 155 | do 156 | python -u run.py \ 157 | --is_training 1 \ 158 | --root_path $root_path_name \ 159 | --data_path $data_path_name \ 160 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 161 | --model $model_name \ 162 | --data $data_name \ 163 | --features M \ 164 | --seq_len $seq_len \ 165 | --pred_len $pred_len \ 166 | --enc_in 137 \ 167 | --train_epochs 30 \ 168 | --patience 5 \ 169 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 170 | done 171 | done 172 | 173 | 174 | root_path_name=./dataset/ 175 | data_path_name=electricity.csv 176 | model_id_name=Electricity 177 | data_name=custom 178 | 179 | model_name=DLinear 180 | seq_len=336 181 | for pred_len in 96 192 336 720 182 | do 183 | for 
random_seed in 2024 184 | do 185 | python -u run.py \ 186 | --is_training 1 \ 187 | --root_path $root_path_name \ 188 | --data_path $data_path_name \ 189 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 190 | --model $model_name \ 191 | --data $data_name \ 192 | --features M \ 193 | --seq_len $seq_len \ 194 | --pred_len $pred_len \ 195 | --enc_in 321 \ 196 | --train_epochs 30 \ 197 | --patience 5 \ 198 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 199 | done 200 | done 201 | 202 | root_path_name=./dataset/ 203 | data_path_name=traffic.csv 204 | model_id_name=traffic 205 | data_name=custom 206 | # traffic.csv has 862 variates, so enc_in must be 862 (electricity's 321 does not apply here) 207 | model_name=DLinear 208 | seq_len=336 209 | for pred_len in 96 192 336 720 210 | do 211 | for random_seed in 2024 212 | do 213 | python -u run.py \ 214 | --is_training 1 \ 215 | --root_path $root_path_name \ 216 | --data_path $data_path_name \ 217 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 218 | --model $model_name \ 219 | --data $data_name \ 220 | --features M \ 221 | --seq_len $seq_len \ 222 | --pred_len $pred_len \ 223 | --enc_in 862 \ 224 | --train_epochs 30 \ 225 | --patience 5 \ 226 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 227 | done 228 | done -------------------------------------------------------------------------------- /scripts/CycleNet/STD/LDLinear.sh: -------------------------------------------------------------------------------- 1 | 2 | root_path_name=./dataset/ 3 | data_path_name=ETTh1.csv 4 | model_id_name=ETTh1 5 | data_name=ETTh1 6 | 7 | model_name=LDLinear 8 | seq_len=336 9 | for pred_len in 96 192 336 720 10 | do 11 | for random_seed in 2024 12 | do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path $root_path_name \ 16 | --data_path $data_path_name \ 17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 18 | --model $model_name \ 19 | --data $data_name \ 20 | --features M \ 21 | --seq_len $seq_len \ 22 | --pred_len $pred_len \ 23 | --enc_in 7 \ 24 | 
--train_epochs 30 \ 25 | --patience 5 \ 26 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 27 | done 28 | done 29 | 30 | root_path_name=./dataset/ 31 | data_path_name=ETTh2.csv 32 | model_id_name=ETTh2 33 | data_name=ETTh2 34 | 35 | model_name=LDLinear 36 | seq_len=336 37 | for pred_len in 96 192 336 720 38 | do 39 | for random_seed in 2024 40 | do 41 | python -u run.py \ 42 | --is_training 1 \ 43 | --root_path $root_path_name \ 44 | --data_path $data_path_name \ 45 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 46 | --model $model_name \ 47 | --data $data_name \ 48 | --features M \ 49 | --seq_len $seq_len \ 50 | --pred_len $pred_len \ 51 | --enc_in 7 \ 52 | --train_epochs 30 \ 53 | --patience 5 \ 54 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 55 | done 56 | done 57 | 58 | root_path_name=./dataset/ 59 | data_path_name=ETTm1.csv 60 | model_id_name=ETTm1 61 | data_name=ETTm1 62 | 63 | model_name=LDLinear 64 | seq_len=336 65 | for pred_len in 96 192 336 720 66 | do 67 | for random_seed in 2024 68 | do 69 | python -u run.py \ 70 | --is_training 1 \ 71 | --root_path $root_path_name \ 72 | --data_path $data_path_name \ 73 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 74 | --model $model_name \ 75 | --data $data_name \ 76 | --features M \ 77 | --seq_len $seq_len \ 78 | --pred_len $pred_len \ 79 | --enc_in 7 \ 80 | --train_epochs 30 \ 81 | --patience 5 \ 82 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 83 | done 84 | done 85 | 86 | 87 | root_path_name=./dataset/ 88 | data_path_name=ETTm2.csv 89 | model_id_name=ETTm2 90 | data_name=ETTm2 91 | 92 | model_name=LDLinear 93 | seq_len=336 94 | for pred_len in 96 192 336 720 95 | do 96 | for random_seed in 2024 97 | do 98 | python -u run.py \ 99 | --is_training 1 \ 100 | --root_path $root_path_name \ 101 | --data_path $data_path_name \ 102 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 103 | --model $model_name \ 104 | --data 
$data_name \ 105 | --features M \ 106 | --seq_len $seq_len \ 107 | --pred_len $pred_len \ 108 | --enc_in 7 \ 109 | --train_epochs 30 \ 110 | --patience 5 \ 111 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 112 | done 113 | done 114 | 115 | 116 | root_path_name=./dataset/ 117 | data_path_name=weather.csv 118 | model_id_name=weather 119 | data_name=custom 120 | 121 | model_name=LDLinear 122 | seq_len=336 123 | for pred_len in 96 192 336 720 124 | do 125 | for random_seed in 2024 126 | do 127 | python -u run.py \ 128 | --is_training 1 \ 129 | --root_path $root_path_name \ 130 | --data_path $data_path_name \ 131 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 132 | --model $model_name \ 133 | --data $data_name \ 134 | --features M \ 135 | --seq_len $seq_len \ 136 | --pred_len $pred_len \ 137 | --enc_in 21 \ 138 | --train_epochs 30 \ 139 | --patience 5 \ 140 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed 141 | done 142 | done 143 | 144 | 145 | root_path_name=./dataset/ 146 | data_path_name=solar_AL.txt 147 | model_id_name=Solar 148 | data_name=Solar 149 | 150 | model_name=LDLinear 151 | seq_len=336 152 | for pred_len in 96 192 336 720 153 | do 154 | for random_seed in 2024 155 | do 156 | python -u run.py \ 157 | --is_training 1 \ 158 | --root_path $root_path_name \ 159 | --data_path $data_path_name \ 160 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 161 | --model $model_name \ 162 | --data $data_name \ 163 | --features M \ 164 | --seq_len $seq_len \ 165 | --pred_len $pred_len \ 166 | --enc_in 137 \ 167 | --train_epochs 30 \ 168 | --patience 5 \ 169 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 170 | done 171 | done 172 | 173 | 174 | root_path_name=./dataset/ 175 | data_path_name=electricity.csv 176 | model_id_name=Electricity 177 | data_name=custom 178 | 179 | model_name=LDLinear 180 | seq_len=336 181 | for pred_len in 96 192 336 720 182 | do 183 | for random_seed in 2024 
184 | do 185 | python -u run.py \ 186 | --is_training 1 \ 187 | --root_path $root_path_name \ 188 | --data_path $data_path_name \ 189 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 190 | --model $model_name \ 191 | --data $data_name \ 192 | --features M \ 193 | --seq_len $seq_len \ 194 | --pred_len $pred_len \ 195 | --enc_in 321 \ 196 | --train_epochs 30 \ 197 | --patience 5 \ 198 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 199 | done 200 | done 201 | 202 | root_path_name=./dataset/ 203 | data_path_name=traffic.csv 204 | model_id_name=traffic 205 | data_name=custom 206 | # traffic.csv has 862 variates, so enc_in must be 862 (electricity's 321 does not apply here) 207 | model_name=LDLinear 208 | seq_len=336 209 | for pred_len in 96 192 336 720 210 | do 211 | for random_seed in 2024 212 | do 213 | python -u run.py \ 214 | --is_training 1 \ 215 | --root_path $root_path_name \ 216 | --data_path $data_path_name \ 217 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 218 | --model $model_name \ 219 | --data $data_name \ 220 | --features M \ 221 | --seq_len $seq_len \ 222 | --pred_len $pred_len \ 223 | --enc_in 862 \ 224 | --train_epochs 30 \ 225 | --patience 5 \ 226 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 227 | done 228 | done -------------------------------------------------------------------------------- /scripts/CycleNet/STD/Linear.sh: -------------------------------------------------------------------------------- 1 | 2 | root_path_name=./dataset/ 3 | data_path_name=ETTh1.csv 4 | model_id_name=ETTh1 5 | data_name=ETTh1 6 | 7 | model_name=Linear 8 | seq_len=336 9 | for pred_len in 96 192 336 720 10 | do 11 | for random_seed in 2024 12 | do 13 | python -u run.py \ 14 | --is_training 1 \ 15 | --root_path $root_path_name \ 16 | --data_path $data_path_name \ 17 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 18 | --model $model_name \ 19 | --data $data_name \ 20 | --features M \ 21 | --seq_len $seq_len \ 22 | --pred_len $pred_len \ 23 | --enc_in 7 \ 24 | --train_epochs 30 \ 25 | 
--patience 5 \ 26 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 27 | done 28 | done 29 | 30 | root_path_name=./dataset/ 31 | data_path_name=ETTh2.csv 32 | model_id_name=ETTh2 33 | data_name=ETTh2 34 | 35 | model_name=Linear 36 | seq_len=336 37 | for pred_len in 96 192 336 720 38 | do 39 | for random_seed in 2024 40 | do 41 | python -u run.py \ 42 | --is_training 1 \ 43 | --root_path $root_path_name \ 44 | --data_path $data_path_name \ 45 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 46 | --model $model_name \ 47 | --data $data_name \ 48 | --features M \ 49 | --seq_len $seq_len \ 50 | --pred_len $pred_len \ 51 | --enc_in 7 \ 52 | --train_epochs 30 \ 53 | --patience 5 \ 54 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 55 | done 56 | done 57 | 58 | 59 | root_path_name=./dataset/ 60 | data_path_name=ETTm1.csv 61 | model_id_name=ETTm1 62 | data_name=ETTm1 63 | 64 | model_name=Linear 65 | seq_len=336 66 | for pred_len in 96 192 336 720 67 | do 68 | for random_seed in 2024 69 | do 70 | python -u run.py \ 71 | --is_training 1 \ 72 | --root_path $root_path_name \ 73 | --data_path $data_path_name \ 74 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 75 | --model $model_name \ 76 | --data $data_name \ 77 | --features M \ 78 | --seq_len $seq_len \ 79 | --pred_len $pred_len \ 80 | --enc_in 7 \ 81 | --train_epochs 30 \ 82 | --patience 5 \ 83 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 84 | done 85 | done 86 | 87 | root_path_name=./dataset/ 88 | data_path_name=ETTm2.csv 89 | model_id_name=ETTm2 90 | data_name=ETTm2 91 | 92 | model_name=Linear 93 | seq_len=336 94 | for pred_len in 96 192 336 720 95 | do 96 | for random_seed in 2024 97 | do 98 | python -u run.py \ 99 | --is_training 1 \ 100 | --root_path $root_path_name \ 101 | --data_path $data_path_name \ 102 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 103 | --model $model_name \ 104 | --data $data_name \ 105 | --features 
M \ 106 | --seq_len $seq_len \ 107 | --pred_len $pred_len \ 108 | --enc_in 7 \ 109 | --train_epochs 30 \ 110 | --patience 5 \ 111 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 112 | done 113 | done 114 | 115 | 116 | root_path_name=./dataset/ 117 | data_path_name=weather.csv 118 | model_id_name=weather 119 | data_name=custom 120 | 121 | model_name=Linear 122 | seq_len=336 123 | for pred_len in 96 192 336 720 124 | do 125 | for random_seed in 2024 126 | do 127 | python -u run.py \ 128 | --is_training 1 \ 129 | --root_path $root_path_name \ 130 | --data_path $data_path_name \ 131 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 132 | --model $model_name \ 133 | --data $data_name \ 134 | --features M \ 135 | --seq_len $seq_len \ 136 | --pred_len $pred_len \ 137 | --enc_in 21 \ 138 | --train_epochs 30 \ 139 | --patience 5 \ 140 | --itr 1 --batch_size 256 --learning_rate 0.001 --random_seed $random_seed 141 | done 142 | done 143 | 144 | 145 | root_path_name=./dataset/ 146 | data_path_name=solar_AL.txt 147 | model_id_name=Solar 148 | data_name=Solar 149 | 150 | model_name=Linear 151 | seq_len=336 152 | for pred_len in 96 192 336 720 153 | do 154 | for random_seed in 2024 155 | do 156 | python -u run.py \ 157 | --is_training 1 \ 158 | --root_path $root_path_name \ 159 | --data_path $data_path_name \ 160 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 161 | --model $model_name \ 162 | --data $data_name \ 163 | --features M \ 164 | --seq_len $seq_len \ 165 | --pred_len $pred_len \ 166 | --enc_in 137 \ 167 | --train_epochs 30 \ 168 | --patience 5 \ 169 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 170 | done 171 | done 172 | 173 | 174 | root_path_name=./dataset/ 175 | data_path_name=electricity.csv 176 | model_id_name=Electricity 177 | data_name=custom 178 | 179 | model_name=Linear 180 | seq_len=336 181 | for pred_len in 96 192 336 720 182 | do 183 | for random_seed in 2024 184 | do 185 | python -u run.py \ 
186 | --is_training 1 \ 187 | --root_path $root_path_name \ 188 | --data_path $data_path_name \ 189 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 190 | --model $model_name \ 191 | --data $data_name \ 192 | --features M \ 193 | --seq_len $seq_len \ 194 | --pred_len $pred_len \ 195 | --enc_in 321 \ 196 | --train_epochs 30 \ 197 | --patience 5 \ 198 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 199 | done 200 | done 201 | 202 | root_path_name=./dataset/ 203 | data_path_name=traffic.csv 204 | model_id_name=traffic 205 | data_name=custom 206 | # traffic.csv has 862 variates, so enc_in must be 862 (electricity's 321 does not apply here) 207 | model_name=Linear 208 | seq_len=336 209 | for pred_len in 96 192 336 720 210 | do 211 | for random_seed in 2024 212 | do 213 | python -u run.py \ 214 | --is_training 1 \ 215 | --root_path $root_path_name \ 216 | --data_path $data_path_name \ 217 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 218 | --model $model_name \ 219 | --data $data_name \ 220 | --features M \ 221 | --seq_len $seq_len \ 222 | --pred_len $pred_len \ 223 | --enc_in 862 \ 224 | --train_epochs 30 \ 225 | --patience 5 \ 226 | --itr 1 --batch_size 256 --learning_rate 0.01 --random_seed $random_seed 227 | done 228 | done -------------------------------------------------------------------------------- /scripts/CycleNet/STD/SparseTSF.sh: -------------------------------------------------------------------------------- 1 | root_path_name=./dataset/ 2 | data_path_name=ETTh1.csv 3 | model_id_name=ETTh1 4 | data_name=ETTh1 5 | 6 | model_name=SparseTSF 7 | seq_len=336 8 | for pred_len in 96 192 336 720 9 | do 10 | for random_seed in 2024 11 | do 12 | python -u run.py \ 13 | --is_training 1 \ 14 | --root_path $root_path_name \ 15 | --data_path $data_path_name \ 16 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 17 | --model $model_name \ 18 | --data $data_name \ 19 | --features M \ 20 | --seq_len $seq_len \ 21 | --pred_len $pred_len \ 22 | --period_len 24 \ 23 | --enc_in 7 \ 24 | --train_epochs 30 \ 25 | --patience 5 \ 26 
| --use_revin 0 \ 27 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed 28 | done 29 | done 30 | 31 | 32 | 33 | root_path_name=./dataset/ 34 | data_path_name=ETTh2.csv 35 | model_id_name=ETTh2 36 | data_name=ETTh2 37 | 38 | model_name=SparseTSF 39 | seq_len=336 40 | for pred_len in 96 192 336 720 41 | do 42 | for random_seed in 2024 43 | do 44 | python -u run.py \ 45 | --is_training 1 \ 46 | --root_path $root_path_name \ 47 | --data_path $data_path_name \ 48 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 49 | --model $model_name \ 50 | --data $data_name \ 51 | --features M \ 52 | --seq_len $seq_len \ 53 | --pred_len $pred_len \ 54 | --enc_in 7 \ 55 | --period_len 24 \ 56 | --train_epochs 30 \ 57 | --patience 5 \ 58 | --use_revin 0 \ 59 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed 60 | done 61 | done 62 | 63 | root_path_name=./dataset/ 64 | data_path_name=ETTm1.csv 65 | model_id_name=ETTm1 66 | data_name=ETTm1 67 | 68 | model_name=SparseTSF 69 | seq_len=336 70 | for pred_len in 96 192 336 720 71 | do 72 | for random_seed in 2024 73 | do 74 | python -u run.py \ 75 | --is_training 1 \ 76 | --root_path $root_path_name \ 77 | --data_path $data_path_name \ 78 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 79 | --model $model_name \ 80 | --data $data_name \ 81 | --features M \ 82 | --seq_len $seq_len \ 83 | --pred_len $pred_len \ 84 | --enc_in 7 \ 85 | --period_len 4 \ 86 | --train_epochs 30 \ 87 | --patience 5 \ 88 | --use_revin 0 \ 89 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed 90 | done 91 | done 92 | 93 | root_path_name=./dataset/ 94 | data_path_name=ETTm2.csv 95 | model_id_name=ETTm2 96 | data_name=ETTm2 97 | 98 | model_name=SparseTSF 99 | seq_len=336 100 | for pred_len in 96 192 336 720 101 | do 102 | for random_seed in 2024 103 | do 104 | python -u run.py \ 105 | --is_training 1 \ 106 | --root_path $root_path_name \ 107 | --data_path $data_path_name \ 108 | --model_id 
$model_id_name'_'$seq_len'_'$pred_len \ 109 | --model $model_name \ 110 | --data $data_name \ 111 | --features M \ 112 | --seq_len $seq_len \ 113 | --pred_len $pred_len \ 114 | --period_len 24 \ 115 | --enc_in 7 \ 116 | --train_epochs 30 \ 117 | --patience 5 \ 118 | --use_revin 0 \ 119 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed 120 | done 121 | done 122 | 123 | 124 | 125 | root_path_name=./dataset/ 126 | data_path_name=weather.csv 127 | model_id_name=weather 128 | data_name=custom 129 | 130 | model_name=SparseTSF 131 | seq_len=336 132 | for pred_len in 96 192 336 720 133 | do 134 | for random_seed in 2024 135 | do 136 | python -u run.py \ 137 | --is_training 1 \ 138 | --root_path $root_path_name \ 139 | --data_path $data_path_name \ 140 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 141 | --model $model_name \ 142 | --data $data_name \ 143 | --features M \ 144 | --seq_len $seq_len \ 145 | --pred_len $pred_len \ 146 | --period_len 4 \ 147 | --enc_in 21 \ 148 | --train_epochs 30 \ 149 | --patience 5 \ 150 | --use_revin 0 \ 151 | --itr 1 --batch_size 256 --learning_rate 0.005 --random_seed $random_seed 152 | done 153 | done 154 | 155 | 156 | root_path_name=./dataset/ 157 | data_path_name=solar_AL.txt 158 | model_id_name=Solar 159 | data_name=Solar 160 | 161 | model_name=SparseTSF 162 | seq_len=336 163 | for pred_len in 96 192 336 720 164 | do 165 | for random_seed in 2024 166 | do 167 | python -u run.py \ 168 | --is_training 1 \ 169 | --root_path $root_path_name \ 170 | --data_path $data_path_name \ 171 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 172 | --model $model_name \ 173 | --data $data_name \ 174 | --features M \ 175 | --seq_len $seq_len \ 176 | --pred_len $pred_len \ 177 | --period_len 4 \ 178 | --enc_in 137 \ 179 | --train_epochs 30 \ 180 | --patience 5 \ 181 | --use_revin 0 \ 182 | --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed 183 | done 184 | done 185 | 186 | 187 | 188 | 
# ---------------------------------------------------------------------------
# SparseTSF on Electricity.
# --enc_in must equal the number of series in electricity.csv. The value 321
# matches scripts/iTransformer/electricity.sh; the previous 862 was the
# traffic channel count carried over from the traffic run below.
# ---------------------------------------------------------------------------
root_path_name=./dataset/
data_path_name=electricity.csv
model_id_name=Electricity
data_name=custom

model_name=SparseTSF
seq_len=336
for pred_len in 96 192 336 720
do
for random_seed in 2024
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id $model_id_name'_'$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --period_len 24 \
      --enc_in 321 \
      --train_epochs 30 \
      --patience 5 \
      --use_revin 0 \
      --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
done
done

# ---------------------------------------------------------------------------
# SparseTSF on traffic (862 input channels).
# ---------------------------------------------------------------------------
root_path_name=./dataset/
data_path_name=traffic.csv
model_id_name=traffic
data_name=custom


model_name=SparseTSF
seq_len=336
for pred_len in 96 192 336 720
do
for random_seed in 2024
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id $model_id_name'_'$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --period_len 24 \
      --enc_in 862 \
      --train_epochs 30 \
      --patience 5 \
      --use_revin 0 \
      --itr 1 --batch_size 256 --learning_rate 0.05 --random_seed $random_seed
done
done
# Train CycleiTransformer on the traffic dataset, one run per forecast horizon.
model_name=CycleiTransformer

root_path_name=./dataset/
data_path_name=traffic.csv
model_id_name=traffic
data_name=custom


seq_len=96
for pred_len in 96 192 336 720; do
  for random_seed in 2024; do
    # model_id is "<dataset>_<seq_len>_<pred_len>"
    python -u run.py \
      --is_training 1 \
      --root_path "${root_path_name}" \
      --data_path "${data_path_name}" \
      --model_id "${model_id_name}_${seq_len}_${pred_len}" \
      --model "${model_name}" \
      --data "${data_name}" \
      --features M \
      --seq_len "${seq_len}" \
      --pred_len "${pred_len}" \
      --enc_in 862 \
      --cycle 168 \
      --d_model 512 \
      --d_ff 512 \
      --e_layers 4 \
      --dropout 0.1 \
      --train_epochs 30 \
      --patience 5 \
      --itr 1 --batch_size 16 --learning_rate 0.001 --random_seed "${random_seed}"
  done
done
def RSE(pred, true):
    """Return the root relative squared error.

    The numerator is the L2 norm of ``true - pred``; the denominator is the
    L2 norm of ``true`` centred on its global mean.
    """
    return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))


def CORR(pred, true):
    """Return the scaled mean Pearson correlation between ``pred`` and ``true``.

    Correlation is computed along axis 0, averaged over the last axis of the
    result, and multiplied by 0.01.
    """
    true_centered = true - true.mean(0)
    pred_centered = pred - pred.mean(0)
    u = (true_centered * pred_centered).sum(0)
    # Pearson denominator: sqrt(sum(a^2) * sum(b^2)). The previous form,
    # sqrt(sum(a^2 * b^2)), is not a correlation normaliser and could give
    # ratios outside [-1, 1].
    d = np.sqrt((true_centered ** 2).sum(0) * (pred_centered ** 2).sum(0))
    d = d + 1e-12  # avoid division by zero for constant series
    # NOTE(review): the 0.01 scale factor is kept from the original
    # implementation; its purpose is not evident here -- confirm before relying
    # on the absolute value.
    return 0.01 * (u / d).mean(-1)


def MAE(pred, true):
    """Return the mean absolute error."""
    return np.mean(np.abs(pred - true))


def MSE(pred, true):
    """Return the mean squared error."""
    return np.mean((pred - true) ** 2)


def RMSE(pred, true):
    """Return the root mean squared error."""
    return np.sqrt(MSE(pred, true))


def MAPE(pred, true):
    """Return the mean absolute percentage error (as a fraction).

    Divides by ``true`` element-wise, so zeros in ``true`` produce inf/nan.
    """
    return np.mean(np.abs((pred - true) / true))


def MSPE(pred, true):
    """Return the mean squared percentage error (as a fraction).

    Divides by ``true`` element-wise, so zeros in ``true`` produce inf/nan.
    """
    return np.mean(np.square((pred - true) / true))


def metric(pred, true):
    """Return ``(mae, mse, rmse, mape, mspe, rse, corr)`` for the two arrays."""
    mae = MAE(pred, true)
    mse = MSE(pred, true)
    rmse = RMSE(pred, true)
    mape = MAPE(pred, true)
    mspe = MSPE(pred, true)
    rse = RSE(pred, true)
    corr = CORR(pred, true)

    return mae, mse, rmse, mape, mspe, rse, corr
class TimeFeature:
    """Base class for callables that map a DatetimeIndex to a numeric feature."""

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # The base implementation does nothing; subclasses override it.
        pass

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        scaled = index.second / 59.0
        return scaled - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        scaled = index.minute / 59.0
        return scaled - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        scaled = index.hour / 23.0
        return scaled - 0.5


class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        scaled = index.dayofweek / 6.0
        return scaled - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based = index.day - 1
        return zero_based / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based = index.dayofyear - 1
        return zero_based / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        zero_based = index.month - 1
        return zero_based / 11.0 - 0.5
class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        iso_week = index.isocalendar().week
        return (iso_week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Build the time-feature extractors that suit a pandas frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Returns
    -------
    A list of freshly constructed ``TimeFeature`` instances for the first
    offset type in the table that ``to_offset(freq_str)`` is an instance of.

    Raises
    ------
    RuntimeError
        If the parsed offset matches none of the offset types in the table.
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    # First table entry (in insertion order) whose offset type matches.
    matched = next(
        (
            feature_classes
            for offset_type, feature_classes in features_by_offsets.items()
            if isinstance(offset, offset_type)
        ),
        None,
    )
    if matched is not None:
        return [cls() for cls in matched]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)


def time_features(dates, freq='h'):
    """Stack every feature for ``freq`` into one array, one row per feature."""
    rows = [feat(dates) for feat in time_features_from_frequency_str(freq)]
    return np.vstack(rows)
/utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import time 5 | 6 | plt.switch_backend('agg') 7 | 8 | 9 | def adjust_learning_rate(optimizer, scheduler, epoch, args, printout=True): 10 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 11 | if args.lradj == 'type1': 12 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 13 | elif args.lradj == 'type2': 14 | lr_adjust = { 15 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 16 | 10: 5e-7, 15: 1e-7, 20: 5e-8 17 | } 18 | elif args.lradj == 'type3': 19 | lr_adjust = {epoch: args.learning_rate if epoch < 3 else args.learning_rate * (0.8 ** ((epoch - 3) // 1))} 20 | elif args.lradj == 'constant': 21 | lr_adjust = {epoch: args.learning_rate} 22 | elif args.lradj == '3': 23 | lr_adjust = {epoch: args.learning_rate if epoch < 10 else args.learning_rate*0.1} 24 | elif args.lradj == '4': 25 | lr_adjust = {epoch: args.learning_rate if epoch < 15 else args.learning_rate*0.1} 26 | elif args.lradj == '5': 27 | lr_adjust = {epoch: args.learning_rate if epoch < 25 else args.learning_rate*0.1} 28 | elif args.lradj == '6': 29 | lr_adjust = {epoch: args.learning_rate if epoch < 5 else args.learning_rate*0.1} 30 | elif args.lradj == 'TST': 31 | lr_adjust = {epoch: scheduler.get_last_lr()[0]} 32 | 33 | if epoch in lr_adjust.keys(): 34 | lr = lr_adjust[epoch] 35 | for param_group in optimizer.param_groups: 36 | param_group['lr'] = lr 37 | if printout: print('Updating learning rate to {}'.format(lr)) 38 | 39 | 40 | class EarlyStopping: 41 | def __init__(self, patience=7, verbose=False, delta=0): 42 | self.patience = patience 43 | self.verbose = verbose 44 | self.counter = 0 45 | self.best_score = None 46 | self.early_stop = False 47 | self.val_loss_min = np.Inf 48 | self.delta = delta 49 | 50 | def __call__(self, val_loss, model, path): 51 | score = -val_loss 52 | if self.best_score is None: 53 | 
self.best_score = score 54 | self.save_checkpoint(val_loss, model, path) 55 | elif score < self.best_score + self.delta: 56 | self.counter += 1 57 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 58 | if self.counter >= self.patience: 59 | self.early_stop = True 60 | else: 61 | self.best_score = score 62 | self.save_checkpoint(val_loss, model, path) 63 | self.counter = 0 64 | 65 | def save_checkpoint(self, val_loss, model, path): 66 | if self.verbose: 67 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 68 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 69 | self.val_loss_min = val_loss 70 | 71 | 72 | class dotdict(dict): 73 | """dot.notation access to dictionary attributes""" 74 | __getattr__ = dict.get 75 | __setattr__ = dict.__setitem__ 76 | __delattr__ = dict.__delitem__ 77 | 78 | 79 | class StandardScaler(): 80 | def __init__(self, mean, std): 81 | self.mean = mean 82 | self.std = std 83 | 84 | def transform(self, data): 85 | return (data - self.mean) / self.std 86 | 87 | def inverse_transform(self, data): 88 | return (data * self.std) + self.mean 89 | 90 | 91 | def visual(true, preds=None, name='./pic/test.pdf'): 92 | """ 93 | Results visualization 94 | """ 95 | plt.figure() 96 | plt.plot(true, label='GroundTruth', linewidth=2) 97 | if preds is not None: 98 | plt.plot(preds, label='Prediction', linewidth=2) 99 | plt.legend() 100 | plt.savefig(name, bbox_inches='tight') 101 | 102 | def test_params_flop(model,x_shape): 103 | """ 104 | If you want to thest former's flop, you need to give default value to inputs in model.forward(), the following code can only pass one argument to forward() 105 | """ 106 | # model_params = 0 107 | # for parameter in model.parameters(): 108 | # model_params += parameter.numel() 109 | # print('INFO: Trainable parameter count: {:.2f}M'.format(model_params / 1000000.0)) 110 | from ptflops import get_model_complexity_info 111 | with 
torch.cuda.device(0): 112 | macs, params = get_model_complexity_info(model.cuda(), x_shape, as_strings=True, print_per_layer_stat=False) 113 | print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 114 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 115 | --------------------------------------------------------------------------------