├── README.md ├── data_provider ├── __init__.py ├── data_factory.py ├── data_loader.py ├── m4.py └── uea.py ├── dataset └── EPF │ ├── BE.csv │ ├── DE.csv │ ├── FR.csv │ ├── NP.csv │ └── PJM.csv ├── exp ├── __init__.py ├── exp_anomaly_detection.py ├── exp_basic.py ├── exp_classification.py ├── exp_imputation.py ├── exp_long_term_forecasting.py └── exp_short_term_forecasting.py ├── figures ├── ERA5.png ├── Introduction.png ├── Result_EPF.png ├── Result_Multivariate.png └── TimeXer.png ├── layers ├── AutoCorrelation.py ├── Autoformer_EncDec.py ├── Conv_Blocks.py ├── Crossformer_EncDec.py ├── ETSformer_EncDec.py ├── Embed.py ├── FourierCorrelation.py ├── MultiWaveletCorrelation.py ├── Pyraformer_EncDec.py ├── SelfAttention_Family.py ├── StandardNorm.py ├── Transformer_EncDec.py └── __init__.py ├── models ├── Autoformer.py ├── Crossformer.py ├── DLinear.py ├── ETSformer.py ├── FEDformer.py ├── FiLM.py ├── FreTS.py ├── Informer.py ├── Koopa.py ├── LightTS.py ├── MICN.py ├── Mamba.py ├── MambaSimple.py ├── Nonstationary_Transformer.py ├── PatchTST.py ├── Pyraformer.py ├── Reformer.py ├── SCINet.py ├── SegRNN.py ├── TSMixer.py ├── TemporalFusionTransformer.py ├── TiDE.py ├── TimeMixer.py ├── TimeXer.py ├── TimesNet.py ├── Transformer.py ├── __init__.py └── iTransformer.py ├── requirements.txt ├── run.py ├── scripts ├── forecast_exogenous │ ├── ECL │ │ └── TimeXer.sh │ ├── EPF │ │ └── TimeXer.sh │ ├── ETTh1 │ │ └── TimeXer.sh │ ├── ETTh2 │ │ └── TimeXer.sh │ ├── ETTm1 │ │ └── TimeXer.sh │ ├── ETTm2 │ │ └── TimeXer.sh │ ├── Traffic │ │ └── TimeXer.sh │ ├── Weather │ │ └── TimeXer.sh │ └── meteorology │ │ ├── temp.sh │ │ └── wind.sh └── multivariate │ ├── ECL │ └── TimeXer.sh │ ├── ETT │ ├── TimeXer_ETTh1.sh │ ├── TimeXer_ETTh2.sh │ ├── TimeXer_ETTm1.sh │ └── TimeXer_ETTm2.sh │ ├── Traffic │ └── TimeXer.sh │ └── Weather │ └── TimeXer.sh └── utils ├── __init__.py ├── augmentation.py ├── losses.py ├── m4_summary.py ├── masking.py ├── metrics.py ├── print_args.py ├── 
timefeatures.py └── tools.py /README.md: -------------------------------------------------------------------------------- 1 | # TimeXer 2 | 3 | This repo is the official implementation for the paper: [TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables](https://arxiv.org/abs/2402.19072). 4 | 5 | ## Introduction 6 | This paper focuses on forecasting with exogenous variables, a practical forecasting paradigm that is widely applied in real-world scenarios. TimeXer empowers the canonical Transformer with the ability to reconcile endogenous and exogenous information without any architectural modifications, and achieves consistent state-of-the-art performance on twelve real-world forecasting benchmarks. 7 | 8 |

9 | 10 |

11 | 12 | ## Overall Architecture 13 | TimeXer employs patch-level and variate-level representations respectively for endogenous and exogenous variables, with an endogenous global token as a bridge in-between. With this design, TimeXer can jointly capture intra-endogenous temporal dependencies and exogenous-to-endogenous correlations. 14 | 15 |

16 | 17 |

18 | 19 | ## Usage 20 | 21 | 1. The short-term Electricity Price Forecasting (EPF) datasets are already included in "./dataset/EPF". Multivariate datasets can be obtained from [[Google Drive]](https://drive.google.com/drive/folders/13Cg1KYOlzM5C7K8gK8NfC-F3EYxkM3D2?usp=sharing) or [[Baidu Drive]](https://pan.baidu.com/s/1r3KhGd0Q9PJIUZdfEYoymg?pwd=i9iy). 22 | 23 | 2. Install PyTorch and other necessary dependencies. 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 3. Train and evaluate the model. We provide the experiment scripts under the folder ./scripts/. You can reproduce the experiment results as in the following example: 28 | 29 | ``` 30 | bash ./scripts/forecast_exogenous/EPF/TimeXer.sh 31 | ``` 32 | 33 | ## Main Results 34 | We evaluate TimeXer on short-term forecasting with exogenous variables and long-term multivariate forecasting benchmarks. Comprehensive forecasting results demonstrate that TimeXer effectively ingests exogenous information to facilitate the prediction of endogenous series. 35 | 36 | ### Forecasting with Exogenous Variables 37 | 38 |

39 | 40 |

41 | 42 | ### Multivariate Forecasting 43 | 44 |

45 | 46 |

47 | 48 | ## Experiments on Large-scale Meteorology Dataset 49 | In this paper, we build a large-scale weather dataset for forecasting with exogenous variables, where the endogenous series is the hourly temperature of 3,850 stations worldwide obtained from the National Centers for Environmental Information (NCEI), and the exogenous variables are meteorological indicators of its adjacent area from the ERA5 dataset. You can obtain this meteorology dataset from [[Google Drive]](https://drive.google.com/file/d/1EuEedepUV2A_cia1plAHwA6fJXNio47i/view?usp=drive_link). 50 | 51 |

52 | 53 |

54 | 55 | ## Citation 56 | If you find this repo helpful, please cite our paper. 57 | 58 | ``` 59 | @article{wang2024timexer, 60 | title={Timexer: Empowering transformers for time series forecasting with exogenous variables}, 61 | author={Wang, Yuxuan and Wu, Haixu and Dong, Jiaxiang and Liu, Yong and Qiu, Yunzhong and Zhang, Haoran and Wang, Jianmin and Long, Mingsheng}, 62 | journal={Advances in Neural Information Processing Systems}, 63 | year={2024} 64 | } 65 | ``` 66 | 67 | ## Acknowledgement 68 | We appreciate the following GitHub repos a lot for their valuable code and efforts. 69 | 70 | Reformer (https://github.com/lucidrains/reformer-pytorch) 71 | 72 | Informer (https://github.com/zhouhaoyi/Informer2020) 73 | 74 | Autoformer (https://github.com/thuml/Autoformer) 75 | 76 | Stationary (https://github.com/thuml/Nonstationary_Transformers) 77 | 78 | Time-Series-Library (https://github.com/thuml/Time-Series-Library) 79 | 80 | ## Concat 81 | 82 | If you have any questions or want to use the code, please contact wangyuxu22@mails.tsinghua.edu.cn 83 | -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \ 2 | MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, Dataset_Meteorology 3 | from data_provider.uea import collate_fn 4 | from torch.utils.data import DataLoader 5 | 6 | data_dict = { 7 | 'ETTh1': Dataset_ETT_hour, 8 | 'ETTh2': Dataset_ETT_hour, 9 | 'ETTm1': Dataset_ETT_minute, 10 | 'ETTm2': Dataset_ETT_minute, 11 | 'custom': Dataset_Custom, 12 | 'm4': Dataset_M4, 13 | 'PSM': PSMSegLoader, 14 
| 'MSL': MSLSegLoader, 15 | 'SMAP': SMAPSegLoader, 16 | 'SMD': SMDSegLoader, 17 | 'SWAT': SWATSegLoader, 18 | 'UEA': UEAloader, 19 | 'Meteorology' : Dataset_Meteorology 20 | } 21 | 22 | 23 | def data_provider(args, flag): 24 | Data = data_dict[args.data] 25 | timeenc = 0 if args.embed != 'timeF' else 1 26 | 27 | shuffle_flag = False if (flag == 'test' or flag == 'TEST') else True 28 | drop_last = False 29 | batch_size = args.batch_size 30 | freq = args.freq 31 | 32 | if args.task_name == 'anomaly_detection': 33 | drop_last = False 34 | data_set = Data( 35 | args = args, 36 | root_path=args.root_path, 37 | win_size=args.seq_len, 38 | flag=flag, 39 | ) 40 | print(flag, len(data_set)) 41 | data_loader = DataLoader( 42 | data_set, 43 | batch_size=batch_size, 44 | shuffle=shuffle_flag, 45 | num_workers=args.num_workers, 46 | drop_last=drop_last) 47 | return data_set, data_loader 48 | elif args.task_name == 'classification': 49 | drop_last = False 50 | data_set = Data( 51 | args = args, 52 | root_path=args.root_path, 53 | flag=flag, 54 | ) 55 | 56 | data_loader = DataLoader( 57 | data_set, 58 | batch_size=batch_size, 59 | shuffle=shuffle_flag, 60 | num_workers=args.num_workers, 61 | drop_last=drop_last, 62 | collate_fn=lambda x: collate_fn(x, max_len=args.seq_len) 63 | ) 64 | return data_set, data_loader 65 | else: 66 | if args.data == 'm4': 67 | drop_last = False 68 | data_set = Data( 69 | args = args, 70 | root_path=args.root_path, 71 | data_path=args.data_path, 72 | flag=flag, 73 | size=[args.seq_len, args.label_len, args.pred_len], 74 | features=args.features, 75 | target=args.target, 76 | timeenc=timeenc, 77 | freq=freq, 78 | seasonal_patterns=args.seasonal_patterns 79 | ) 80 | print(flag, len(data_set)) 81 | data_loader = DataLoader( 82 | data_set, 83 | batch_size=batch_size, 84 | shuffle=shuffle_flag, 85 | num_workers=args.num_workers, 86 | drop_last=drop_last) 87 | return data_set, data_loader 
-------------------------------------------------------------------------------- /data_provider/m4.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | M4 Dataset 17 | """ 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | from dataclasses import dataclass 22 | from glob import glob 23 | 24 | import numpy as np 25 | import pandas as pd 26 | import patoolib 27 | from tqdm import tqdm 28 | import logging 29 | import os 30 | import pathlib 31 | import sys 32 | from urllib import request 33 | 34 | 35 | def url_file_name(url: str) -> str: 36 | """ 37 | Extract file name from url. 38 | 39 | :param url: URL to extract file name from. 40 | :return: File name. 41 | """ 42 | return url.split('/')[-1] if len(url) > 0 else '' 43 | 44 | 45 | def download(url: str, file_path: str) -> None: 46 | """ 47 | Download a file to the given path. 48 | 49 | :param url: URL to download 50 | :param file_path: Where to download the content. 
51 | """ 52 | 53 | def progress(count, block_size, total_size): 54 | progress_pct = float(count * block_size) / float(total_size) * 100.0 55 | sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct)) 56 | sys.stdout.flush() 57 | 58 | if not os.path.isfile(file_path): 59 | opener = request.build_opener() 60 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 61 | request.install_opener(opener) 62 | pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) 63 | f, _ = request.urlretrieve(url, file_path, progress) 64 | sys.stdout.write('\n') 65 | sys.stdout.flush() 66 | file_info = os.stat(f) 67 | logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.') 68 | else: 69 | file_info = os.stat(file_path) 70 | logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.') 71 | 72 | 73 | @dataclass() 74 | class M4Dataset: 75 | ids: np.ndarray 76 | groups: np.ndarray 77 | frequencies: np.ndarray 78 | horizons: np.ndarray 79 | values: np.ndarray 80 | 81 | @staticmethod 82 | def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset': 83 | """ 84 | Load cached dataset. 85 | 86 | :param training: Load training part if_inverted training is True, test part otherwise. 
87 | """ 88 | info_file = os.path.join(dataset_file, 'M4-info.csv') 89 | train_cache_file = os.path.join(dataset_file, 'training.npz') 90 | test_cache_file = os.path.join(dataset_file, 'test.npz') 91 | m4_info = pd.read_csv(info_file) 92 | return M4Dataset(ids=m4_info.M4id.values, 93 | groups=m4_info.SP.values, 94 | frequencies=m4_info.Frequency.values, 95 | horizons=m4_info.Horizon.values, 96 | values=np.load( 97 | train_cache_file if training else test_cache_file, 98 | allow_pickle=True)) 99 | 100 | 101 | @dataclass() 102 | class M4Meta: 103 | seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly'] 104 | horizons = [6, 8, 18, 13, 14, 48] 105 | frequencies = [1, 4, 12, 1, 1, 24] 106 | horizons_map = { 107 | 'Yearly': 6, 108 | 'Quarterly': 8, 109 | 'Monthly': 18, 110 | 'Weekly': 13, 111 | 'Daily': 14, 112 | 'Hourly': 48 113 | } # different predict length 114 | frequency_map = { 115 | 'Yearly': 1, 116 | 'Quarterly': 4, 117 | 'Monthly': 12, 118 | 'Weekly': 1, 119 | 'Daily': 1, 120 | 'Hourly': 24 121 | } 122 | history_size = { 123 | 'Yearly': 1.5, 124 | 'Quarterly': 1.5, 125 | 'Monthly': 1.5, 126 | 'Weekly': 10, 127 | 'Daily': 10, 128 | 'Hourly': 10 129 | } # from interpretable.gin 130 | 131 | 132 | def load_m4_info() -> pd.DataFrame: 133 | """ 134 | Load M4Info file. 135 | 136 | :return: Pandas DataFrame of M4Info. 137 | """ 138 | return pd.read_csv(INFO_FILE_PATH) 139 | -------------------------------------------------------------------------------- /data_provider/uea.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | 6 | 7 | def collate_fn(data, max_len=None): 8 | """Build mini-batch tensors from a list of (X, mask) tuples. Mask input. Create 9 | Args: 10 | data: len(batch_size) list of tuples (X, y). 11 | - X: torch tensor of shape (seq_length, feat_dim); variable seq_length. 
12 | - y: torch tensor of shape (num_labels,) : class indices or numerical targets 13 | (for classification or regression, respectively). num_labels > 1 for multi-task models 14 | max_len: global fixed sequence length. Used for architectures requiring fixed length input, 15 | where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s 16 | Returns: 17 | X: (batch_size, padded_length, feat_dim) torch tensor of masked features (input) 18 | targets: (batch_size, padded_length, feat_dim) torch tensor of unmasked features (output) 19 | target_masks: (batch_size, padded_length, feat_dim) boolean torch tensor 20 | 0 indicates masked values to be predicted, 1 indicates unaffected/"active" feature values 21 | padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding 22 | """ 23 | 24 | batch_size = len(data) 25 | features, labels = zip(*data) 26 | 27 | # Stack and pad features and masks (convert 2D to 3D tensors, i.e. 
add batch dimension) 28 | lengths = [X.shape[0] for X in features] # original sequence length for each time series 29 | if max_len is None: 30 | max_len = max(lengths) 31 | 32 | X = torch.zeros(batch_size, max_len, features[0].shape[-1]) # (batch_size, padded_length, feat_dim) 33 | for i in range(batch_size): 34 | end = min(lengths[i], max_len) 35 | X[i, :end, :] = features[i][:end, :] 36 | 37 | targets = torch.stack(labels, dim=0) # (batch_size, num_labels) 38 | 39 | padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16), 40 | max_len=max_len) # (batch_size, padded_length) boolean tensor, "1" means keep 41 | 42 | return X, targets, padding_masks 43 | 44 | 45 | def padding_mask(lengths, max_len=None): 46 | """ 47 | Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths, 48 | where 1 means keep element at this position (time step) 49 | """ 50 | batch_size = lengths.numel() 51 | max_len = max_len or lengths.max_val() # trick works because of overloading of 'or' operator for non-boolean types 52 | return (torch.arange(0, max_len, device=lengths.device) 53 | .type_as(lengths) 54 | .repeat(batch_size, 1) 55 | .lt(lengths.unsqueeze(1))) 56 | 57 | 58 | class Normalizer(object): 59 | """ 60 | Normalizes dataframe across ALL contained rows (time steps). Different from per-sample normalization. 61 | """ 62 | 63 | def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None): 64 | """ 65 | Args: 66 | norm_type: choose from: 67 | "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps) 68 | "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. 
across only its own rows) 69 | mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values 70 | """ 71 | 72 | self.norm_type = norm_type 73 | self.mean = mean 74 | self.std = std 75 | self.min_val = min_val 76 | self.max_val = max_val 77 | 78 | def normalize(self, df): 79 | """ 80 | Args: 81 | df: input dataframe 82 | Returns: 83 | df: normalized dataframe 84 | """ 85 | if self.norm_type == "standardization": 86 | if self.mean is None: 87 | self.mean = df.mean() 88 | self.std = df.std() 89 | return (df - self.mean) / (self.std + np.finfo(float).eps) 90 | 91 | elif self.norm_type == "minmax": 92 | if self.max_val is None: 93 | self.max_val = df.max() 94 | self.min_val = df.min() 95 | return (df - self.min_val) / (self.max_val - self.min_val + np.finfo(float).eps) 96 | 97 | elif self.norm_type == "per_sample_std": 98 | grouped = df.groupby(by=df.index) 99 | return (df - grouped.transform('mean')) / grouped.transform('std') 100 | 101 | elif self.norm_type == "per_sample_minmax": 102 | grouped = df.groupby(by=df.index) 103 | min_vals = grouped.transform('min') 104 | return (df - min_vals) / (grouped.transform('max') - min_vals + np.finfo(float).eps) 105 | 106 | else: 107 | raise (NameError(f'Normalize method "{self.norm_type}" not implemented')) 108 | 109 | 110 | def interpolate_missing(y): 111 | """ 112 | Replaces NaN values in pd.Series `y` using linear interpolation 113 | """ 114 | if y.isna().any(): 115 | y = y.interpolate(method='linear', limit_direction='both') 116 | return y 117 | 118 | 119 | def subsample(y, limit=256, factor=2): 120 | """ 121 | If a given Series is longer than `limit`, returns subsampled sequence by the specified integer factor 122 | """ 123 | if len(y) > limit: 124 | return y[::factor].reset_index(drop=True) 125 | return y 126 | -------------------------------------------------------------------------------- /exp/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/exp/__init__.py -------------------------------------------------------------------------------- /exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \ 4 | Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \ 5 | Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, TimeXer 6 | 7 | 8 | class Exp_Basic(object): 9 | def __init__(self, args): 10 | self.args = args 11 | self.model_dict = { 12 | 'TimesNet': TimesNet, 13 | 'Autoformer': Autoformer, 14 | 'Transformer': Transformer, 15 | 'Nonstationary_Transformer': Nonstationary_Transformer, 16 | 'DLinear': DLinear, 17 | 'FEDformer': FEDformer, 18 | 'Informer': Informer, 19 | 'LightTS': LightTS, 20 | 'Reformer': Reformer, 21 | 'ETSformer': ETSformer, 22 | 'PatchTST': PatchTST, 23 | 'Pyraformer': Pyraformer, 24 | 'MICN': MICN, 25 | 'Crossformer': Crossformer, 26 | 'FiLM': FiLM, 27 | 'iTransformer': iTransformer, 28 | 'Koopa': Koopa, 29 | 'TiDE': TiDE, 30 | 'FreTS': FreTS, 31 | 'MambaSimple': MambaSimple, 32 | 'TimeMixer': TimeMixer, 33 | 'TSMixer': TSMixer, 34 | 'SegRNN': SegRNN, 35 | 'TemporalFusionTransformer': TemporalFusionTransformer, 36 | "SCINet": SCINet, 37 | 'TimeXer': TimeXer 38 | } 39 | if args.model == 'Mamba': 40 | print('Please make sure you have successfully installed mamba_ssm') 41 | from models import Mamba 42 | self.model_dict['Mamba'] = Mamba 43 | 44 | self.device = self._acquire_device() 45 | self.model = self._build_model().to(self.device) 46 | 47 | def _build_model(self): 48 | raise NotImplementedError 49 | return None 50 | 51 | def _acquire_device(self): 52 | if self.args.use_gpu: 53 | os.environ["CUDA_VISIBLE_DEVICES"] = str( 54 | 
self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 55 | device = torch.device('cuda:{}'.format(self.args.gpu)) 56 | print('Use GPU: cuda:{}'.format(self.args.gpu)) 57 | else: 58 | device = torch.device('cpu') 59 | print('Use CPU') 60 | return device 61 | 62 | def _get_data(self): 63 | pass 64 | 65 | def vali(self): 66 | pass 67 | 68 | def train(self): 69 | pass 70 | 71 | def test(self): 72 | pass 73 | -------------------------------------------------------------------------------- /figures/ERA5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/figures/ERA5.png -------------------------------------------------------------------------------- /figures/Introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/figures/Introduction.png -------------------------------------------------------------------------------- /figures/Result_EPF.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/figures/Result_EPF.png -------------------------------------------------------------------------------- /figures/Result_Multivariate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/figures/Result_Multivariate.png -------------------------------------------------------------------------------- /figures/TimeXer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/figures/TimeXer.png 
-------------------------------------------------------------------------------- /layers/AutoCorrelation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import math 7 | from math import sqrt 8 | import os 9 | 10 | 11 | class AutoCorrelation(nn.Module): 12 | """ 13 | AutoCorrelation Mechanism with the following two phases: 14 | (1) period-based dependencies discovery 15 | (2) time delay aggregation 16 | This block can replace the self-attention family mechanism seamlessly. 17 | """ 18 | 19 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): 20 | super(AutoCorrelation, self).__init__() 21 | self.factor = factor 22 | self.scale = scale 23 | self.mask_flag = mask_flag 24 | self.output_attention = output_attention 25 | self.dropout = nn.Dropout(attention_dropout) 26 | 27 | def time_delay_agg_training(self, values, corr): 28 | """ 29 | SpeedUp version of Autocorrelation (a batch-normalization style design) 30 | This is for the training phase. 
31 | """ 32 | head = values.shape[1] 33 | channel = values.shape[2] 34 | length = values.shape[3] 35 | # find top k 36 | top_k = int(self.factor * math.log(length)) 37 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 38 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] 39 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) 40 | # update corr 41 | tmp_corr = torch.softmax(weights, dim=-1) 42 | # aggregation 43 | tmp_values = values 44 | delays_agg = torch.zeros_like(values).float() 45 | for i in range(top_k): 46 | pattern = torch.roll(tmp_values, -int(index[i]), -1) 47 | delays_agg = delays_agg + pattern * \ 48 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 49 | return delays_agg 50 | 51 | def time_delay_agg_inference(self, values, corr): 52 | """ 53 | SpeedUp version of Autocorrelation (a batch-normalization style design) 54 | This is for the inference phase. 55 | """ 56 | batch = values.shape[0] 57 | head = values.shape[1] 58 | channel = values.shape[2] 59 | length = values.shape[3] 60 | # index init 61 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 62 | # find top k 63 | top_k = int(self.factor * math.log(length)) 64 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 65 | weights, delay = torch.topk(mean_value, top_k, dim=-1) 66 | # update corr 67 | tmp_corr = torch.softmax(weights, dim=-1) 68 | # aggregation 69 | tmp_values = values.repeat(1, 1, 1, 2) 70 | delays_agg = torch.zeros_like(values).float() 71 | for i in range(top_k): 72 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 73 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 74 | delays_agg = delays_agg + pattern * \ 75 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 76 | return delays_agg 77 | 78 | def 
time_delay_agg_full(self, values, corr): 79 | """ 80 | Standard version of Autocorrelation 81 | """ 82 | batch = values.shape[0] 83 | head = values.shape[1] 84 | channel = values.shape[2] 85 | length = values.shape[3] 86 | # index init 87 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 88 | # find top k 89 | top_k = int(self.factor * math.log(length)) 90 | weights, delay = torch.topk(corr, top_k, dim=-1) 91 | # update corr 92 | tmp_corr = torch.softmax(weights, dim=-1) 93 | # aggregation 94 | tmp_values = values.repeat(1, 1, 1, 2) 95 | delays_agg = torch.zeros_like(values).float() 96 | for i in range(top_k): 97 | tmp_delay = init_index + delay[..., i].unsqueeze(-1) 98 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 99 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) 100 | return delays_agg 101 | 102 | def forward(self, queries, keys, values, attn_mask): 103 | B, L, H, E = queries.shape 104 | _, S, _, D = values.shape 105 | if L > S: 106 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float() 107 | values = torch.cat([values, zeros], dim=1) 108 | keys = torch.cat([keys, zeros], dim=1) 109 | else: 110 | values = values[:, :L, :, :] 111 | keys = keys[:, :L, :, :] 112 | 113 | # period-based dependencies 114 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) 115 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) 116 | res = q_fft * torch.conj(k_fft) 117 | corr = torch.fft.irfft(res, dim=-1) 118 | 119 | # time delay agg 120 | if self.training: 121 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 122 | else: 123 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 124 | 125 | if self.output_attention: 126 | return (V.contiguous(), corr.permute(0, 3, 1, 2)) 127 | else: 128 | return (V.contiguous(), None) 129 | 130 | 131 | 
class AutoCorrelationLayer(nn.Module): 132 | def __init__(self, correlation, d_model, n_heads, d_keys=None, 133 | d_values=None): 134 | super(AutoCorrelationLayer, self).__init__() 135 | 136 | d_keys = d_keys or (d_model // n_heads) 137 | d_values = d_values or (d_model // n_heads) 138 | 139 | self.inner_correlation = correlation 140 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 141 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 142 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 143 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 144 | self.n_heads = n_heads 145 | 146 | def forward(self, queries, keys, values, attn_mask): 147 | B, L, _ = queries.shape 148 | _, S, _ = keys.shape 149 | H = self.n_heads 150 | 151 | queries = self.query_projection(queries).view(B, L, H, -1) 152 | keys = self.key_projection(keys).view(B, S, H, -1) 153 | values = self.value_projection(values).view(B, S, H, -1) 154 | 155 | out, attn = self.inner_correlation( 156 | queries, 157 | keys, 158 | values, 159 | attn_mask 160 | ) 161 | out = out.view(B, L, -1) 162 | 163 | return self.out_projection(out), attn 164 | -------------------------------------------------------------------------------- /layers/Autoformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class my_Layernorm(nn.Module): 7 | """ 8 | Special designed layernorm for the seasonal part 9 | """ 10 | 11 | def __init__(self, channels): 12 | super(my_Layernorm, self).__init__() 13 | self.layernorm = nn.LayerNorm(channels) 14 | 15 | def forward(self, x): 16 | x_hat = self.layernorm(x) 17 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) 18 | return x_hat - bias 19 | 20 | 21 | class moving_avg(nn.Module): 22 | """ 23 | Moving average block to highlight the trend of time series 24 | """ 25 | 26 | def __init__(self, 
kernel_size, stride): 27 | super(moving_avg, self).__init__() 28 | self.kernel_size = kernel_size 29 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 30 | 31 | def forward(self, x): 32 | # padding on the both ends of time series 33 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 34 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 35 | x = torch.cat([front, x, end], dim=1) 36 | x = self.avg(x.permute(0, 2, 1)) 37 | x = x.permute(0, 2, 1) 38 | return x 39 | 40 | 41 | class series_decomp(nn.Module): 42 | """ 43 | Series decomposition block 44 | """ 45 | 46 | def __init__(self, kernel_size): 47 | super(series_decomp, self).__init__() 48 | self.moving_avg = moving_avg(kernel_size, stride=1) 49 | 50 | def forward(self, x): 51 | moving_mean = self.moving_avg(x) 52 | res = x - moving_mean 53 | return res, moving_mean 54 | 55 | 56 | class series_decomp_multi(nn.Module): 57 | """ 58 | Multiple Series decomposition block from FEDformer 59 | """ 60 | 61 | def __init__(self, kernel_size): 62 | super(series_decomp_multi, self).__init__() 63 | self.kernel_size = kernel_size 64 | self.series_decomp = [series_decomp(kernel) for kernel in kernel_size] 65 | 66 | def forward(self, x): 67 | moving_mean = [] 68 | res = [] 69 | for func in self.series_decomp: 70 | sea, moving_avg = func(x) 71 | moving_mean.append(moving_avg) 72 | res.append(sea) 73 | 74 | sea = sum(res) / len(res) 75 | moving_mean = sum(moving_mean) / len(moving_mean) 76 | return sea, moving_mean 77 | 78 | 79 | class EncoderLayer(nn.Module): 80 | """ 81 | Autoformer encoder layer with the progressive decomposition architecture 82 | """ 83 | 84 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): 85 | super(EncoderLayer, self).__init__() 86 | d_ff = d_ff or 4 * d_model 87 | self.attention = attention 88 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 89 | self.conv2 = 
nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 90 | self.decomp1 = series_decomp(moving_avg) 91 | self.decomp2 = series_decomp(moving_avg) 92 | self.dropout = nn.Dropout(dropout) 93 | self.activation = F.relu if activation == "relu" else F.gelu 94 | 95 | def forward(self, x, attn_mask=None): 96 | new_x, attn = self.attention( 97 | x, x, x, 98 | attn_mask=attn_mask 99 | ) 100 | x = x + self.dropout(new_x) 101 | x, _ = self.decomp1(x) 102 | y = x 103 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 104 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 105 | res, _ = self.decomp2(x + y) 106 | return res, attn 107 | 108 | 109 | class Encoder(nn.Module): 110 | """ 111 | Autoformer encoder 112 | """ 113 | 114 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 115 | super(Encoder, self).__init__() 116 | self.attn_layers = nn.ModuleList(attn_layers) 117 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 118 | self.norm = norm_layer 119 | 120 | def forward(self, x, attn_mask=None): 121 | attns = [] 122 | if self.conv_layers is not None: 123 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 124 | x, attn = attn_layer(x, attn_mask=attn_mask) 125 | x = conv_layer(x) 126 | attns.append(attn) 127 | x, attn = self.attn_layers[-1](x) 128 | attns.append(attn) 129 | else: 130 | for attn_layer in self.attn_layers: 131 | x, attn = attn_layer(x, attn_mask=attn_mask) 132 | attns.append(attn) 133 | 134 | if self.norm is not None: 135 | x = self.norm(x) 136 | 137 | return x, attns 138 | 139 | 140 | class DecoderLayer(nn.Module): 141 | """ 142 | Autoformer decoder layer with the progressive decomposition architecture 143 | """ 144 | 145 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None, 146 | moving_avg=25, dropout=0.1, activation="relu"): 147 | super(DecoderLayer, self).__init__() 148 | d_ff = d_ff or 4 * d_model 149 | 
self.self_attention = self_attention 150 | self.cross_attention = cross_attention 151 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 152 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 153 | self.decomp1 = series_decomp(moving_avg) 154 | self.decomp2 = series_decomp(moving_avg) 155 | self.decomp3 = series_decomp(moving_avg) 156 | self.dropout = nn.Dropout(dropout) 157 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1, 158 | padding_mode='circular', bias=False) 159 | self.activation = F.relu if activation == "relu" else F.gelu 160 | 161 | def forward(self, x, cross, x_mask=None, cross_mask=None): 162 | x = x + self.dropout(self.self_attention( 163 | x, x, x, 164 | attn_mask=x_mask 165 | )[0]) 166 | x, trend1 = self.decomp1(x) 167 | x = x + self.dropout(self.cross_attention( 168 | x, cross, cross, 169 | attn_mask=cross_mask 170 | )[0]) 171 | x, trend2 = self.decomp2(x) 172 | y = x 173 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 174 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 175 | x, trend3 = self.decomp3(x + y) 176 | 177 | residual_trend = trend1 + trend2 + trend3 178 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) 179 | return x, residual_trend 180 | 181 | 182 | class Decoder(nn.Module): 183 | """ 184 | Autoformer encoder 185 | """ 186 | 187 | def __init__(self, layers, norm_layer=None, projection=None): 188 | super(Decoder, self).__init__() 189 | self.layers = nn.ModuleList(layers) 190 | self.norm = norm_layer 191 | self.projection = projection 192 | 193 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): 194 | for layer in self.layers: 195 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 196 | trend = trend + residual_trend 197 | 198 | if self.norm is not None: 199 | x = self.norm(x) 200 | 201 | if 
self.projection is not None: 202 | x = self.projection(x) 203 | return x, trend 204 | -------------------------------------------------------------------------------- /layers/Conv_Blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Inception_Block_V1(nn.Module): 6 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): 7 | super(Inception_Block_V1, self).__init__() 8 | self.in_channels = in_channels 9 | self.out_channels = out_channels 10 | self.num_kernels = num_kernels 11 | kernels = [] 12 | for i in range(self.num_kernels): 13 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)) 14 | self.kernels = nn.ModuleList(kernels) 15 | if init_weight: 16 | self._initialize_weights() 17 | 18 | def _initialize_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 22 | if m.bias is not None: 23 | nn.init.constant_(m.bias, 0) 24 | 25 | def forward(self, x): 26 | res_list = [] 27 | for i in range(self.num_kernels): 28 | res_list.append(self.kernels[i](x)) 29 | res = torch.stack(res_list, dim=-1).mean(-1) 30 | return res 31 | 32 | 33 | class Inception_Block_V2(nn.Module): 34 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): 35 | super(Inception_Block_V2, self).__init__() 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | self.num_kernels = num_kernels 39 | kernels = [] 40 | for i in range(self.num_kernels // 2): 41 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1])) 42 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0])) 43 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) 44 | self.kernels = nn.ModuleList(kernels) 45 | if init_weight: 46 | 
self._initialize_weights() 47 | 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | 55 | def forward(self, x): 56 | res_list = [] 57 | for i in range(self.num_kernels // 2 * 2 + 1): 58 | res_list.append(self.kernels[i](x)) 59 | res = torch.stack(res_list, dim=-1).mean(-1) 60 | return res 61 | -------------------------------------------------------------------------------- /layers/Crossformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange, repeat 4 | from layers.SelfAttention_Family import TwoStageAttentionLayer 5 | 6 | 7 | class SegMerging(nn.Module): 8 | def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm): 9 | super().__init__() 10 | self.d_model = d_model 11 | self.win_size = win_size 12 | self.linear_trans = nn.Linear(win_size * d_model, d_model) 13 | self.norm = norm_layer(win_size * d_model) 14 | 15 | def forward(self, x): 16 | batch_size, ts_d, seg_num, d_model = x.shape 17 | pad_num = seg_num % self.win_size 18 | if pad_num != 0: 19 | pad_num = self.win_size - pad_num 20 | x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2) 21 | 22 | seg_to_merge = [] 23 | for i in range(self.win_size): 24 | seg_to_merge.append(x[:, :, i::self.win_size, :]) 25 | x = torch.cat(seg_to_merge, -1) 26 | 27 | x = self.norm(x) 28 | x = self.linear_trans(x) 29 | 30 | return x 31 | 32 | 33 | class scale_block(nn.Module): 34 | def __init__(self, configs, win_size, d_model, n_heads, d_ff, depth, dropout, \ 35 | seg_num=10, factor=10): 36 | super(scale_block, self).__init__() 37 | 38 | if win_size > 1: 39 | self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm) 40 | else: 41 | self.merge_layer = None 42 | 43 | self.encode_layers = nn.ModuleList() 44 | 45 | for i 
in range(depth): 46 | self.encode_layers.append(TwoStageAttentionLayer(configs, seg_num, factor, d_model, n_heads, \ 47 | d_ff, dropout)) 48 | 49 | def forward(self, x, attn_mask=None, tau=None, delta=None): 50 | _, ts_dim, _, _ = x.shape 51 | 52 | if self.merge_layer is not None: 53 | x = self.merge_layer(x) 54 | 55 | for layer in self.encode_layers: 56 | x = layer(x) 57 | 58 | return x, None 59 | 60 | 61 | class Encoder(nn.Module): 62 | def __init__(self, attn_layers): 63 | super(Encoder, self).__init__() 64 | self.encode_blocks = nn.ModuleList(attn_layers) 65 | 66 | def forward(self, x): 67 | encode_x = [] 68 | encode_x.append(x) 69 | 70 | for block in self.encode_blocks: 71 | x, attns = block(x) 72 | encode_x.append(x) 73 | 74 | return encode_x, None 75 | 76 | 77 | class DecoderLayer(nn.Module): 78 | def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1): 79 | super(DecoderLayer, self).__init__() 80 | self.self_attention = self_attention 81 | self.cross_attention = cross_attention 82 | self.norm1 = nn.LayerNorm(d_model) 83 | self.norm2 = nn.LayerNorm(d_model) 84 | self.dropout = nn.Dropout(dropout) 85 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model), 86 | nn.GELU(), 87 | nn.Linear(d_model, d_model)) 88 | self.linear_pred = nn.Linear(d_model, seg_len) 89 | 90 | def forward(self, x, cross): 91 | batch = x.shape[0] 92 | x = self.self_attention(x) 93 | x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model') 94 | 95 | cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model') 96 | tmp, attn = self.cross_attention(x, cross, cross, None, None, None,) 97 | x = x + self.dropout(tmp) 98 | y = x = self.norm1(x) 99 | y = self.MLP1(y) 100 | dec_output = self.norm2(x + y) 101 | 102 | dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch) 103 | layer_predict = self.linear_pred(dec_output) 104 | layer_predict = 
rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len') 105 | 106 | return dec_output, layer_predict 107 | 108 | 109 | class Decoder(nn.Module): 110 | def __init__(self, layers): 111 | super(Decoder, self).__init__() 112 | self.decode_layers = nn.ModuleList(layers) 113 | 114 | 115 | def forward(self, x, cross): 116 | final_predict = None 117 | i = 0 118 | 119 | ts_d = x.shape[1] 120 | for layer in self.decode_layers: 121 | cross_enc = cross[i] 122 | x, layer_predict = layer(x, cross_enc) 123 | if final_predict is None: 124 | final_predict = layer_predict 125 | else: 126 | final_predict = final_predict + layer_predict 127 | i += 1 128 | 129 | final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d) 130 | 131 | return final_predict 132 | -------------------------------------------------------------------------------- /layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm 5 | import math 6 | 7 | 8 | class PositionalEmbedding(nn.Module): 9 | def __init__(self, d_model, max_len=5000): 10 | super(PositionalEmbedding, self).__init__() 11 | # Compute the positional encodings once in log space. 
12 | pe = torch.zeros(max_len, d_model).float() 13 | pe.require_grad = False 14 | 15 | position = torch.arange(0, max_len).float().unsqueeze(1) 16 | div_term = (torch.arange(0, d_model, 2).float() 17 | * -(math.log(10000.0) / d_model)).exp() 18 | 19 | pe[:, 0::2] = torch.sin(position * div_term) 20 | pe[:, 1::2] = torch.cos(position * div_term) 21 | 22 | pe = pe.unsqueeze(0) 23 | self.register_buffer('pe', pe) 24 | 25 | def forward(self, x): 26 | return self.pe[:, :x.size(1)] 27 | 28 | 29 | class TokenEmbedding(nn.Module): 30 | def __init__(self, c_in, d_model): 31 | super(TokenEmbedding, self).__init__() 32 | padding = 1 if torch.__version__ >= '1.5.0' else 2 33 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 34 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv1d): 37 | nn.init.kaiming_normal_( 38 | m.weight, mode='fan_in', nonlinearity='leaky_relu') 39 | 40 | def forward(self, x): 41 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 42 | return x 43 | 44 | 45 | class FixedEmbedding(nn.Module): 46 | def __init__(self, c_in, d_model): 47 | super(FixedEmbedding, self).__init__() 48 | 49 | w = torch.zeros(c_in, d_model).float() 50 | w.require_grad = False 51 | 52 | position = torch.arange(0, c_in).float().unsqueeze(1) 53 | div_term = (torch.arange(0, d_model, 2).float() 54 | * -(math.log(10000.0) / d_model)).exp() 55 | 56 | w[:, 0::2] = torch.sin(position * div_term) 57 | w[:, 1::2] = torch.cos(position * div_term) 58 | 59 | self.emb = nn.Embedding(c_in, d_model) 60 | self.emb.weight = nn.Parameter(w, requires_grad=False) 61 | 62 | def forward(self, x): 63 | return self.emb(x).detach() 64 | 65 | 66 | class TemporalEmbedding(nn.Module): 67 | def __init__(self, d_model, embed_type='fixed', freq='h'): 68 | super(TemporalEmbedding, self).__init__() 69 | 70 | minute_size = 4 71 | hour_size = 24 72 | weekday_size = 7 73 | day_size = 32 74 | month_size = 13 75 | 76 | 
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding 77 | if freq == 't': 78 | self.minute_embed = Embed(minute_size, d_model) 79 | self.hour_embed = Embed(hour_size, d_model) 80 | self.weekday_embed = Embed(weekday_size, d_model) 81 | self.day_embed = Embed(day_size, d_model) 82 | self.month_embed = Embed(month_size, d_model) 83 | 84 | def forward(self, x): 85 | x = x.long() 86 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr( 87 | self, 'minute_embed') else 0. 88 | hour_x = self.hour_embed(x[:, :, 3]) 89 | weekday_x = self.weekday_embed(x[:, :, 2]) 90 | day_x = self.day_embed(x[:, :, 1]) 91 | month_x = self.month_embed(x[:, :, 0]) 92 | 93 | return hour_x + weekday_x + day_x + month_x + minute_x 94 | 95 | 96 | class TimeFeatureEmbedding(nn.Module): 97 | def __init__(self, d_model, embed_type='timeF', freq='h'): 98 | super(TimeFeatureEmbedding, self).__init__() 99 | 100 | freq_map = {'h': 4, 't': 5, 's': 6, 101 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 102 | d_inp = freq_map[freq] 103 | self.embed = nn.Linear(d_inp, d_model, bias=False) 104 | 105 | def forward(self, x): 106 | return self.embed(x) 107 | 108 | 109 | class DataEmbedding(nn.Module): 110 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 111 | super(DataEmbedding, self).__init__() 112 | 113 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 114 | self.position_embedding = PositionalEmbedding(d_model=d_model) 115 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 116 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 117 | d_model=d_model, embed_type=embed_type, freq=freq) 118 | self.dropout = nn.Dropout(p=dropout) 119 | 120 | def forward(self, x, x_mark): 121 | if x_mark is None: 122 | x = self.value_embedding(x) + self.position_embedding(x) 123 | else: 124 | x = self.value_embedding( 125 | x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 126 | return self.dropout(x) 127 | 
128 | 129 | class DataEmbedding_inverted(nn.Module): 130 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 131 | super(DataEmbedding_inverted, self).__init__() 132 | self.value_embedding = nn.Linear(c_in, d_model) 133 | self.dropout = nn.Dropout(p=dropout) 134 | 135 | def forward(self, x, x_mark): 136 | x = x.permute(0, 2, 1) 137 | # x: [Batch Variate Time] 138 | if x_mark is None: 139 | x = self.value_embedding(x) 140 | else: 141 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) 142 | # x: [Batch Variate d_model] 143 | return self.dropout(x) 144 | 145 | 146 | class DataEmbedding_wo_pos(nn.Module): 147 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 148 | super(DataEmbedding_wo_pos, self).__init__() 149 | 150 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 151 | self.position_embedding = PositionalEmbedding(d_model=d_model) 152 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 153 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 154 | d_model=d_model, embed_type=embed_type, freq=freq) 155 | self.dropout = nn.Dropout(p=dropout) 156 | 157 | def forward(self, x, x_mark): 158 | if x_mark is None: 159 | x = self.value_embedding(x) 160 | else: 161 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 162 | return self.dropout(x) 163 | 164 | 165 | class PatchEmbedding(nn.Module): 166 | def __init__(self, d_model, patch_len, stride, padding, dropout): 167 | super(PatchEmbedding, self).__init__() 168 | # Patching 169 | self.patch_len = patch_len 170 | self.stride = stride 171 | self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) 172 | 173 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space 174 | self.value_embedding = nn.Linear(patch_len, d_model, bias=False) 175 | 176 | # Positional embedding 177 | self.position_embedding = PositionalEmbedding(d_model) 178 | 179 | # 
Residual dropout 180 | self.dropout = nn.Dropout(dropout) 181 | 182 | def forward(self, x): 183 | # do patching 184 | n_vars = x.shape[1] 185 | x = self.padding_patch_layer(x) 186 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 187 | x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) 188 | # Input encoding 189 | x = self.value_embedding(x) + self.position_embedding(x) 190 | return self.dropout(x), n_vars 191 | -------------------------------------------------------------------------------- /layers/FourierCorrelation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # author=maziqing 3 | # email=maziqing.mzq@alibaba-inc.com 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def get_frequency_modes(seq_len, modes=64, mode_select_method='random'): 11 | """ 12 | get modes on frequency domain: 13 | 'random' means sampling randomly; 14 | 'else' means sampling the lowest modes; 15 | """ 16 | modes = min(modes, seq_len // 2) 17 | if mode_select_method == 'random': 18 | index = list(range(0, seq_len // 2)) 19 | np.random.shuffle(index) 20 | index = index[:modes] 21 | else: 22 | index = list(range(0, modes)) 23 | index.sort() 24 | return index 25 | 26 | 27 | # ########## fourier layer ############# 28 | class FourierBlock(nn.Module): 29 | def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'): 30 | super(FourierBlock, self).__init__() 31 | print('fourier enhanced block used!') 32 | """ 33 | 1D Fourier block. It performs representation learning on frequency domain, 34 | it does FFT, linear transform, and Inverse FFT. 
35 | """ 36 | # get modes on frequency domain 37 | self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method) 38 | print('modes={}, index={}'.format(modes, self.index)) 39 | 40 | self.scale = (1 / (in_channels * out_channels)) 41 | self.weights1 = nn.Parameter( 42 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) 43 | self.weights2 = nn.Parameter( 44 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) 45 | 46 | # Complex multiplication 47 | def compl_mul1d(self, order, x, weights): 48 | x_flag = True 49 | w_flag = True 50 | if not torch.is_complex(x): 51 | x_flag = False 52 | x = torch.complex(x, torch.zeros_like(x).to(x.device)) 53 | if not torch.is_complex(weights): 54 | w_flag = False 55 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) 56 | if x_flag or w_flag: 57 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), 58 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) 59 | else: 60 | return torch.einsum(order, x.real, weights.real) 61 | 62 | def forward(self, q, k, v, mask): 63 | # size = [B, L, H, E] 64 | B, L, H, E = q.shape 65 | x = q.permute(0, 2, 3, 1) 66 | # Compute Fourier coefficients 67 | x_ft = torch.fft.rfft(x, dim=-1) 68 | # Perform Fourier neural operations 69 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) 70 | for wi, i in enumerate(self.index): 71 | if i >= x_ft.shape[3] or wi >= out_ft.shape[3]: 72 | continue 73 | out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i], 74 | torch.complex(self.weights1, self.weights2)[:, :, :, wi]) 75 | # Return to time domain 76 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 77 | return (x, None) 78 | 79 | 80 | # ########## Fourier Cross Former #################### 81 | class 
FourierCrossAttention(nn.Module): 82 | def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random', 83 | activation='tanh', policy=0, num_heads=8): 84 | super(FourierCrossAttention, self).__init__() 85 | print(' fourier enhanced cross attention used!') 86 | """ 87 | 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT. 88 | """ 89 | self.activation = activation 90 | self.in_channels = in_channels 91 | self.out_channels = out_channels 92 | # get modes for queries and keys (& values) on frequency domain 93 | self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method) 94 | self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method) 95 | 96 | print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q)) 97 | print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv)) 98 | 99 | self.scale = (1 / (in_channels * out_channels)) 100 | self.weights1 = nn.Parameter( 101 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) 102 | self.weights2 = nn.Parameter( 103 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) 104 | 105 | # Complex multiplication 106 | def compl_mul1d(self, order, x, weights): 107 | x_flag = True 108 | w_flag = True 109 | if not torch.is_complex(x): 110 | x_flag = False 111 | x = torch.complex(x, torch.zeros_like(x).to(x.device)) 112 | if not torch.is_complex(weights): 113 | w_flag = False 114 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) 115 | if x_flag or w_flag: 116 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), 117 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) 118 | 
else: 119 | return torch.einsum(order, x.real, weights.real) 120 | 121 | def forward(self, q, k, v, mask): 122 | # size = [B, L, H, E] 123 | B, L, H, E = q.shape 124 | xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] 125 | xk = k.permute(0, 2, 3, 1) 126 | xv = v.permute(0, 2, 3, 1) 127 | 128 | # Compute Fourier coefficients 129 | xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) 130 | xq_ft = torch.fft.rfft(xq, dim=-1) 131 | for i, j in enumerate(self.index_q): 132 | if j >= xq_ft.shape[3]: 133 | continue 134 | xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] 135 | xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat) 136 | xk_ft = torch.fft.rfft(xk, dim=-1) 137 | for i, j in enumerate(self.index_kv): 138 | if j >= xk_ft.shape[3]: 139 | continue 140 | xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] 141 | 142 | # perform attention mechanism on frequency domain 143 | xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)) 144 | if self.activation == 'tanh': 145 | xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh()) 146 | elif self.activation == 'softmax': 147 | xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) 148 | xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) 149 | else: 150 | raise Exception('{} actiation function is not implemented'.format(self.activation)) 151 | xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) 152 | xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)) 153 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) 154 | for i, j in enumerate(self.index_q): 155 | if i >= xqkvw.shape[3] or j >= out_ft.shape[3]: 156 | continue 157 | out_ft[:, :, :, j] = xqkvw[:, :, :, i] 158 | # Return to time domain 159 | out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) 160 | return (out, None) 161 | 
-------------------------------------------------------------------------------- /layers/StandardNorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Normalize(nn.Module): 6 | def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False): 7 | """ 8 | :param num_features: the number of features or channels 9 | :param eps: a value added for numerical stability 10 | :param affine: if True, RevIN has learnable affine parameters 11 | """ 12 | super(Normalize, self).__init__() 13 | self.num_features = num_features 14 | self.eps = eps 15 | self.affine = affine 16 | self.subtract_last = subtract_last 17 | self.non_norm = non_norm 18 | if self.affine: 19 | self._init_params() 20 | 21 | def forward(self, x, mode: str): 22 | if mode == 'norm': 23 | self._get_statistics(x) 24 | x = self._normalize(x) 25 | elif mode == 'denorm': 26 | x = self._denormalize(x) 27 | else: 28 | raise NotImplementedError 29 | return x 30 | 31 | def _init_params(self): 32 | # initialize RevIN params: (C,) 33 | self.affine_weight = nn.Parameter(torch.ones(self.num_features)) 34 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) 35 | 36 | def _get_statistics(self, x): 37 | dim2reduce = tuple(range(1, x.ndim - 1)) 38 | if self.subtract_last: 39 | self.last = x[:, -1, :].unsqueeze(1) 40 | else: 41 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() 42 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() 43 | 44 | def _normalize(self, x): 45 | if self.non_norm: 46 | return x 47 | if self.subtract_last: 48 | x = x - self.last 49 | else: 50 | x = x - self.mean 51 | x = x / self.stdev 52 | if self.affine: 53 | x = x * self.affine_weight 54 | x = x + self.affine_bias 55 | return x 56 | 57 | def _denormalize(self, x): 58 | if self.non_norm: 59 | return x 60 | if self.affine: 61 | x = x - 
self.affine_bias 62 | x = x / (self.affine_weight + self.eps * self.eps) 63 | x = x * self.stdev 64 | if self.subtract_last: 65 | x = x + self.last 66 | else: 67 | x = x + self.mean 68 | return x 69 | -------------------------------------------------------------------------------- /layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, c_in): 8 | super(ConvLayer, self).__init__() 9 | self.downConv = nn.Conv1d(in_channels=c_in, 10 | out_channels=c_in, 11 | kernel_size=3, 12 | padding=2, 13 | padding_mode='circular') 14 | self.norm = nn.BatchNorm1d(c_in) 15 | self.activation = nn.ELU() 16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 17 | 18 | def forward(self, x): 19 | x = self.downConv(x.permute(0, 2, 1)) 20 | x = self.norm(x) 21 | x = self.activation(x) 22 | x = self.maxPool(x) 23 | x = x.transpose(1, 2) 24 | return x 25 | 26 | 27 | class EncoderLayer(nn.Module): 28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 29 | super(EncoderLayer, self).__init__() 30 | d_ff = d_ff or 4 * d_model 31 | self.attention = attention 32 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 33 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 34 | self.norm1 = nn.LayerNorm(d_model) 35 | self.norm2 = nn.LayerNorm(d_model) 36 | self.dropout = nn.Dropout(dropout) 37 | self.activation = F.relu if activation == "relu" else F.gelu 38 | 39 | def forward(self, x, attn_mask=None, tau=None, delta=None): 40 | new_x, attn = self.attention( 41 | x, x, x, 42 | attn_mask=attn_mask, 43 | tau=tau, delta=delta 44 | ) 45 | x = x + self.dropout(new_x) 46 | 47 | y = x = self.norm1(x) 48 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 49 | y = 
self.dropout(self.conv2(y).transpose(-1, 1)) 50 | 51 | return self.norm2(x + y), attn  # second residual + LayerNorm; attention weights returned alongside 52 | 53 | 54 | class Encoder(nn.Module):  # stack of attention layers, optionally interleaved with conv layers, plus an optional final norm 55 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 56 | super(Encoder, self).__init__() 57 | self.attn_layers = nn.ModuleList(attn_layers) 58 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 59 | self.norm = norm_layer 60 | 61 | def forward(self, x, attn_mask=None, tau=None, delta=None): 62 | # x [B, L, D] 63 | attns = [] 64 | if self.conv_layers is not None: 65 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)):  # zip stops at the shorter list; callers presumably pass len(attn_layers) - 1 conv layers 66 | delta = delta if i == 0 else None  # delta only reaches the first layer 67 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 68 | x = conv_layer(x) 69 | attns.append(attn) 70 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None)  # NOTE(review): attn_mask is not forwarded to this final layer — confirm intended 71 | attns.append(attn) 72 | else: 73 | for attn_layer in self.attn_layers: 74 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 75 | attns.append(attn) 76 | 77 | if self.norm is not None: 78 | x = self.norm(x) 79 | 80 | return x, attns 81 | 82 | 83 | class DecoderLayer(nn.Module):  # post-norm layer: self-attention, cross-attention over `cross`, then a conv FFN, each with residual + LayerNorm 84 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 85 | dropout=0.1, activation="relu"): 86 | super(DecoderLayer, self).__init__() 87 | d_ff = d_ff or 4 * d_model 88 | self.self_attention = self_attention 89 | self.cross_attention = cross_attention 90 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 91 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 92 | self.norm1 = nn.LayerNorm(d_model) 93 | self.norm2 = nn.LayerNorm(d_model) 94 | self.norm3 = nn.LayerNorm(d_model) 95 | self.dropout = nn.Dropout(dropout) 96 | self.activation = F.relu if activation == "relu" else F.gelu 97 | 98 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 99 | x = x + self.dropout(self.self_attention( 100 | x, x, x, 101 |
attn_mask=x_mask, 102 | tau=tau, delta=None  # delta is only given to the cross-attention below 103 | )[0])  # [0]: keep the attention output, drop the attention weights 104 | x = self.norm1(x) 105 | 106 | x = x + self.dropout(self.cross_attention( 107 | x, cross, cross, 108 | attn_mask=cross_mask, 109 | tau=tau, delta=delta 110 | )[0]) 111 | 112 | y = x = self.norm2(x) 113 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 114 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 115 | 116 | return self.norm3(x + y)  # unlike EncoderLayer, no attention weights are returned 117 | 118 | 119 | class Decoder(nn.Module):  # applies every DecoderLayer to (x, cross), then the optional norm and projection 120 | def __init__(self, layers, norm_layer=None, projection=None): 121 | super(Decoder, self).__init__() 122 | self.layers = nn.ModuleList(layers) 123 | self.norm = norm_layer 124 | self.projection = projection 125 | 126 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 127 | for layer in self.layers: 128 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 129 | 130 | if self.norm is not None: 131 | x = self.norm(x) 132 | 133 | if self.projection is not None: 134 | x = self.projection(x) 135 | return x 136 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/layers/__init__.py -------------------------------------------------------------------------------- /models/Autoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Embed import DataEmbedding, DataEmbedding_wo_pos 5 | from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer 6 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp 7 | import math 8 | import numpy as np 9 | 10 | 11 | class Model(nn.Module): 12 | """ 13 | Autoformer is the first method to achieve the
series-wise connection, 14 | with inherent O(LlogL) complexity 15 | Paper link: https://openreview.net/pdf?id=I55UqU-M11y 16 | """ 17 | 18 | def __init__(self, configs): 19 | super(Model, self).__init__() 20 | self.task_name = configs.task_name 21 | self.seq_len = configs.seq_len 22 | self.label_len = configs.label_len 23 | self.pred_len = configs.pred_len 24 | 25 | # Decomp 26 | kernel_size = configs.moving_avg 27 | self.decomp = series_decomp(kernel_size)  # forecast() uses it to split x_enc into seasonal and trend parts 28 | 29 | # Embedding 30 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 31 | configs.dropout) 32 | # Encoder 33 | self.encoder = Encoder( 34 | [ 35 | EncoderLayer( 36 | AutoCorrelationLayer( 37 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,  # NOTE(review): first arg is presumably a mask flag (False here, True for the decoder self-correlation) — confirm in layers/AutoCorrelation.py 38 | output_attention=False), 39 | configs.d_model, configs.n_heads), 40 | configs.d_model, 41 | configs.d_ff, 42 | moving_avg=configs.moving_avg, 43 | dropout=configs.dropout, 44 | activation=configs.activation 45 | ) for l in range(configs.e_layers) 46 | ], 47 | norm_layer=my_Layernorm(configs.d_model) 48 | ) 49 | # Decoder 50 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':  # the decoder is only built for forecasting tasks 51 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 52 | configs.dropout) 53 | self.decoder = Decoder( 54 | [ 55 | DecoderLayer( 56 | AutoCorrelationLayer( 57 | AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout, 58 | output_attention=False), 59 | configs.d_model, configs.n_heads), 60 | AutoCorrelationLayer( 61 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 62 | output_attention=False), 63 | configs.d_model, configs.n_heads), 64 | configs.d_model, 65 | configs.c_out, 66 | configs.d_ff, 67 | moving_avg=configs.moving_avg, 68 | dropout=configs.dropout, 69 | activation=configs.activation, 70 | ) 71 | for l in range(configs.d_layers) 72 | ], 73 |
norm_layer=my_Layernorm(configs.d_model), 74 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 75 | ) 76 | if self.task_name == 'imputation': 77 | self.projection = nn.Linear( 78 | configs.d_model, configs.c_out, bias=True) 79 | if self.task_name == 'anomaly_detection': 80 | self.projection = nn.Linear( 81 | configs.d_model, configs.c_out, bias=True) 82 | if self.task_name == 'classification': 83 | self.act = F.gelu 84 | self.dropout = nn.Dropout(configs.dropout) 85 | self.projection = nn.Linear( 86 | configs.d_model * configs.seq_len, configs.num_class) 87 | 88 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 89 | # decomp init 90 | mean = torch.mean(x_enc, dim=1).unsqueeze( 91 | 1).repeat(1, self.pred_len, 1)  # trend placeholder for the horizon: mean of x_enc over dim 1, repeated pred_len times 92 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, 93 | x_dec.shape[2]], device=x_enc.device)  # seasonal placeholder for the horizon 94 | seasonal_init, trend_init = self.decomp(x_enc) 95 | # decoder input 96 | trend_init = torch.cat( 97 | [trend_init[:, -self.label_len:, :], mean], dim=1)  # last label_len decomposed steps followed by the placeholder 98 | seasonal_init = torch.cat( 99 | [seasonal_init[:, -self.label_len:, :], zeros], dim=1) 100 | # enc 101 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 102 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 103 | # dec 104 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 105 | seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, 106 | trend=trend_init)  # this Decoder returns a (seasonal, trend) pair 107 | # final 108 | dec_out = trend_part + seasonal_part 109 | return dec_out 110 | 111 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):  # NOTE: x_dec, x_mark_dec and mask are accepted but unused 112 | # enc 113 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 114 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 115 | # final 116 | dec_out = self.projection(enc_out) 117 | return dec_out 118 | 119 | def anomaly_detection(self, x_enc): 120 | # enc 121 | enc_out = self.enc_embedding(x_enc, None) 122 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 123 | # final 124 | dec_out =
self.projection(enc_out) 125 | return dec_out 126 | 127 | def classification(self, x_enc, x_mark_enc): 128 | # enc 129 | enc_out = self.enc_embedding(x_enc, None) 130 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 131 | 132 | # Output 133 | # the output transformer encoder/decoder embeddings don't include non-linearity 134 | output = self.act(enc_out) 135 | output = self.dropout(output) 136 | # zero-out padding embeddings 137 | output = output * x_mark_enc.unsqueeze(-1)  # x_mark_enc acts as a per-step mask here (assumed 0/1 — confirm with the classification loader) 138 | # (batch_size, seq_length * d_model) 139 | output = output.reshape(output.shape[0], -1) 140 | output = self.projection(output) # (batch_size, num_classes) 141 | return output 142 | 143 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):  # dispatch on self.task_name; returns None for an unrecognised task 144 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 145 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 146 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 147 | if self.task_name == 'imputation': 148 | dec_out = self.imputation( 149 | x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 150 | return dec_out # [B, L, D] 151 | if self.task_name == 'anomaly_detection': 152 | dec_out = self.anomaly_detection(x_enc) 153 | return dec_out # [B, L, D] 154 | if self.task_name == 'classification': 155 | dec_out = self.classification(x_enc, x_mark_enc) 156 | return dec_out # [B, N] 157 | return None 158 | -------------------------------------------------------------------------------- /models/Crossformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Crossformer_EncDec import scale_block, Encoder, Decoder, DecoderLayer 5 | from layers.Embed import PatchEmbedding 6 | from layers.SelfAttention_Family import AttentionLayer, FullAttention, TwoStageAttentionLayer 7 | from models.PatchTST import FlattenHead 8 | 9 | 10 | from
math import ceil 12 | 13 | 14 | class Model(nn.Module): 15 | """ 16 | Paper link: https://openreview.net/pdf?id=vSVLM2j9eie 17 | """ 18 | def __init__(self, configs): 19 | super(Model, self).__init__() 20 | self.enc_in = configs.enc_in 21 | self.seq_len = configs.seq_len 22 | self.pred_len = configs.pred_len 23 | self.seg_len = 12 24 | self.win_size = 2 25 | self.task_name = configs.task_name 26 | 27 | # The padding operation to handle invisible sgemnet length 28 | self.pad_in_len = ceil(1.0 * configs.seq_len / self.seg_len) * self.seg_len 29 | self.pad_out_len = ceil(1.0 * configs.pred_len / self.seg_len) * self.seg_len 30 | self.in_seg_num = self.pad_in_len // self.seg_len 31 | self.out_seg_num = ceil(self.in_seg_num / (self.win_size ** (configs.e_layers - 1))) 32 | self.head_nf = configs.d_model * self.out_seg_num 33 | 34 | # Embedding 35 | self.enc_value_embedding = PatchEmbedding(configs.d_model, self.seg_len, self.seg_len, self.pad_in_len - configs.seq_len, 0) 36 | self.enc_pos_embedding = nn.Parameter( 37 | torch.randn(1, configs.enc_in, self.in_seg_num, configs.d_model)) 38 | self.pre_norm = nn.LayerNorm(configs.d_model) 39 | 40 | # Encoder 41 | self.encoder = Encoder( 42 | [ 43 | scale_block(configs, 1 if l is 0 else self.win_size, configs.d_model, configs.n_heads, configs.d_ff, 44 | 1, configs.dropout, 45 | self.in_seg_num if l is 0 else ceil(self.in_seg_num / self.win_size ** l), configs.factor 46 | ) for l in range(configs.e_layers) 47 | ] 48 | ) 49 | # Decoder 50 | self.dec_pos_embedding = nn.Parameter( 51 | torch.randn(1, configs.enc_in, (self.pad_out_len // self.seg_len), configs.d_model)) 52 | 53 | self.decoder = Decoder( 54 | [ 55 | DecoderLayer( 56 | TwoStageAttentionLayer(configs, (self.pad_out_len // self.seg_len), configs.factor, configs.d_model, configs.n_heads, 57 | configs.d_ff, configs.dropout), 58 | AttentionLayer( 59 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 60 | output_attention=False), 61 | 
configs.d_model, configs.n_heads), 62 | self.seg_len, 63 | configs.d_model, 64 | configs.d_ff, 65 | dropout=configs.dropout, 66 | # activation=configs.activation, 67 | ) 68 | for l in range(configs.e_layers + 1)  # e_layers + 1 decoder layers — presumably one per encoder output; confirm in layers/Crossformer_EncDec.py 69 | ], 70 | ) 71 | if self.task_name == 'imputation' or self.task_name == 'anomaly_detection': 72 | self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len, 73 | head_dropout=configs.dropout) 74 | elif self.task_name == 'classification': 75 | self.flatten = nn.Flatten(start_dim=-2) 76 | self.dropout = nn.Dropout(configs.dropout) 77 | self.projection = nn.Linear( 78 | self.head_nf * configs.enc_in, configs.num_class) 79 | 80 | 81 | 82 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 83 | # embedding 84 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1))  # patch each variable separately; result is (B * n_vars, seg_num, d_model) per the rearrange below 85 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d = n_vars) 86 | x_enc += self.enc_pos_embedding  # broadcasts over the batch dim (the parameter's leading dim is 1) 87 | x_enc = self.pre_norm(x_enc) 88 | enc_out, attns = self.encoder(x_enc) 89 | 90 | dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat=x_enc.shape[0])  # decoder input: the learned positional embedding tiled over the batch 91 | dec_out = self.decoder(dec_in, enc_out) 92 | return dec_out 93 | 94 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 95 | # embedding 96 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1)) 97 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars) 98 | x_enc += self.enc_pos_embedding 99 | x_enc = self.pre_norm(x_enc) 100 | enc_out, attns = self.encoder(x_enc) 101 | 102 | dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1)  # the head reads only the last entry of enc_out 103 | 104 | return dec_out 105 | 106 | def anomaly_detection(self, x_enc): 107 | # embedding 108 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1)) 109 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars) 110 | x_enc += self.enc_pos_embedding 111 | x_enc = self.pre_norm(x_enc) 112 | enc_out, attns =
self.encoder(x_enc) 113 | 114 | dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1) 115 | return dec_out 116 | 117 | def classification(self, x_enc, x_mark_enc): 118 | # embedding 119 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1)) 120 | 121 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars) 122 | x_enc += self.enc_pos_embedding 123 | x_enc = self.pre_norm(x_enc) 124 | enc_out, attns = self.encoder(x_enc) 125 | # Output head: flatten the last encoder output, apply dropout, project to num_class 126 | output = self.flatten(enc_out[-1].permute(0, 1, 3, 2)) 127 | output = self.dropout(output) 128 | output = output.reshape(output.shape[0], -1) 129 | output = self.projection(output) 130 | return output 131 | 132 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 133 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 134 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 135 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 136 | if self.task_name == 'imputation': 137 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 138 | return dec_out # [B, L, D] 139 | if self.task_name == 'anomaly_detection': 140 | dec_out = self.anomaly_detection(x_enc) 141 | return dec_out # [B, L, D] 142 | if self.task_name == 'classification': 143 | dec_out = self.classification(x_enc, x_mark_enc) 144 | return dec_out # [B, N] 145 | return None -------------------------------------------------------------------------------- /models/DLinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Autoformer_EncDec import series_decomp 5 | 6 | 7 | class Model(nn.Module): 8 | """ 9 | Paper link: https://arxiv.org/pdf/2205.13504.pdf 10 | """ 11 | 12 | def __init__(self, configs, individual=False): 13 | """ 14 | individual: Bool, if True use separate Seasonal/Trend linear layers per channel instead of a pair shared among
different variates. 15 | """ 16 | super(Model, self).__init__() 17 | self.task_name = configs.task_name 18 | self.seq_len = configs.seq_len 19 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 20 | self.pred_len = configs.seq_len 21 | else: 22 | self.pred_len = configs.pred_len 23 | # Series decomposition block from Autoformer 24 | self.decompsition = series_decomp(configs.moving_avg) 25 | self.individual = individual 26 | self.channels = configs.enc_in 27 | 28 | if self.individual: 29 | self.Linear_Seasonal = nn.ModuleList() 30 | self.Linear_Trend = nn.ModuleList() 31 | 32 | for i in range(self.channels): 33 | self.Linear_Seasonal.append( 34 | nn.Linear(self.seq_len, self.pred_len)) 35 | self.Linear_Trend.append( 36 | nn.Linear(self.seq_len, self.pred_len)) 37 | 38 | self.Linear_Seasonal[i].weight = nn.Parameter( 39 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 40 | self.Linear_Trend[i].weight = nn.Parameter( 41 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 42 | else: 43 | self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len) 44 | self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len) 45 | 46 | self.Linear_Seasonal.weight = nn.Parameter( 47 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 48 | self.Linear_Trend.weight = nn.Parameter( 49 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 50 | 51 | if self.task_name == 'classification': 52 | self.projection = nn.Linear( 53 | configs.enc_in * configs.seq_len, configs.num_class) 54 | 55 | def encoder(self, x): 56 | seasonal_init, trend_init = self.decompsition(x) 57 | seasonal_init, trend_init = seasonal_init.permute( 58 | 0, 2, 1), trend_init.permute(0, 2, 1) 59 | if self.individual: 60 | seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len], 61 | dtype=seasonal_init.dtype).to(seasonal_init.device) 62 | trend_output = 
torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len], 63 | dtype=trend_init.dtype).to(trend_init.device) 64 | for i in range(self.channels): 65 | seasonal_output[:, i, :] = self.Linear_Seasonal[i]( 66 | seasonal_init[:, i, :]) 67 | trend_output[:, i, :] = self.Linear_Trend[i]( 68 | trend_init[:, i, :]) 69 | else: 70 | seasonal_output = self.Linear_Seasonal(seasonal_init) 71 | trend_output = self.Linear_Trend(trend_init) 72 | x = seasonal_output + trend_output 73 | return x.permute(0, 2, 1) 74 | 75 | def forecast(self, x_enc): 76 | # Encoder 77 | return self.encoder(x_enc) 78 | 79 | def imputation(self, x_enc): 80 | # Encoder 81 | return self.encoder(x_enc) 82 | 83 | def anomaly_detection(self, x_enc): 84 | # Encoder 85 | return self.encoder(x_enc) 86 | 87 | def classification(self, x_enc): 88 | # Encoder 89 | enc_out = self.encoder(x_enc) 90 | # Output 91 | # (batch_size, seq_length * d_model) 92 | output = enc_out.reshape(enc_out.shape[0], -1) 93 | # (batch_size, num_classes) 94 | output = self.projection(output) 95 | return output 96 | 97 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 98 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 99 | dec_out = self.forecast(x_enc) 100 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 101 | if self.task_name == 'imputation': 102 | dec_out = self.imputation(x_enc) 103 | return dec_out # [B, L, D] 104 | if self.task_name == 'anomaly_detection': 105 | dec_out = self.anomaly_detection(x_enc) 106 | return dec_out # [B, L, D] 107 | if self.task_name == 'classification': 108 | dec_out = self.classification(x_enc) 109 | return dec_out # [B, N] 110 | return None 111 | -------------------------------------------------------------------------------- /models/ETSformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Embed import DataEmbedding 4 | from 
layers.ETSformer_EncDec import EncoderLayer, Encoder, DecoderLayer, Decoder, Transform 5 | 6 | 7 | class Model(nn.Module): 8 | """ 9 | Paper link: https://arxiv.org/abs/2202.01381 10 | """ 11 | 12 | def __init__(self, configs): 13 | super(Model, self).__init__() 14 | self.task_name = configs.task_name 15 | self.seq_len = configs.seq_len 16 | self.label_len = configs.label_len 17 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 18 | self.pred_len = configs.seq_len 19 | else: 20 | self.pred_len = configs.pred_len 21 | 22 | assert configs.e_layers == configs.d_layers, "Encoder and decoder layers must be equal" 23 | 24 | # Embedding 25 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 26 | configs.dropout) 27 | 28 | # Encoder 29 | self.encoder = Encoder( 30 | [ 31 | EncoderLayer( 32 | configs.d_model, configs.n_heads, configs.enc_in, configs.seq_len, self.pred_len, configs.top_k, 33 | dim_feedforward=configs.d_ff, 34 | dropout=configs.dropout, 35 | activation=configs.activation, 36 | ) for _ in range(configs.e_layers) 37 | ] 38 | ) 39 | # Decoder 40 | self.decoder = Decoder( 41 | [ 42 | DecoderLayer( 43 | configs.d_model, configs.n_heads, configs.c_out, self.pred_len, 44 | dropout=configs.dropout, 45 | ) for _ in range(configs.d_layers) 46 | ], 47 | ) 48 | self.transform = Transform(sigma=0.2) 49 | 50 | if self.task_name == 'classification': 51 | self.act = torch.nn.functional.gelu 52 | self.dropout = nn.Dropout(configs.dropout) 53 | self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class) 54 | 55 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 56 | with torch.no_grad(): 57 | if self.training: 58 | x_enc = self.transform.transform(x_enc) 59 | res = self.enc_embedding(x_enc, x_mark_enc) 60 | level, growths, seasons = self.encoder(res, x_enc, attn_mask=None) 61 | 62 | growth, season = self.decoder(growths, seasons) 63 | 
preds = level[:, -1:] + growth + season 64 | return preds 65 | 66 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 67 | res = self.enc_embedding(x_enc, x_mark_enc) 68 | level, growths, seasons = self.encoder(res, x_enc, attn_mask=None) 69 | growth, season = self.decoder(growths, seasons) 70 | preds = level[:, -1:] + growth + season 71 | return preds 72 | 73 | def anomaly_detection(self, x_enc): 74 | res = self.enc_embedding(x_enc, None) 75 | level, growths, seasons = self.encoder(res, x_enc, attn_mask=None) 76 | growth, season = self.decoder(growths, seasons) 77 | preds = level[:, -1:] + growth + season 78 | return preds 79 | 80 | def classification(self, x_enc, x_mark_enc): 81 | res = self.enc_embedding(x_enc, None) 82 | _, growths, seasons = self.encoder(res, x_enc, attn_mask=None) 83 | 84 | growths = torch.sum(torch.stack(growths, 0), 0)[:, :self.seq_len, :] 85 | seasons = torch.sum(torch.stack(seasons, 0), 0)[:, :self.seq_len, :] 86 | 87 | enc_out = growths + seasons 88 | output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity 89 | output = self.dropout(output) 90 | 91 | # Output 92 | output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings 93 | output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model) 94 | output = self.projection(output) # (batch_size, num_classes) 95 | return output 96 | 97 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 98 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 99 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 100 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 101 | if self.task_name == 'imputation': 102 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 103 | return dec_out # [B, L, D] 104 | if self.task_name == 'anomaly_detection': 105 | dec_out = self.anomaly_detection(x_enc) 106 | return dec_out # [B, L, D] 107 | if 
self.task_name == 'classification': 108 | dec_out = self.classification(x_enc, x_mark_enc) 109 | return dec_out # [B, N] 110 | return None 111 | -------------------------------------------------------------------------------- /models/FreTS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class Model(nn.Module): 8 | """ 9 | Paper link: https://arxiv.org/pdf/2311.06184.pdf 10 | """ 11 | 12 | def __init__(self, configs): 13 | super(Model, self).__init__() 14 | self.task_name = configs.task_name 15 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 16 | self.pred_len = configs.seq_len 17 | else: 18 | self.pred_len = configs.pred_len 19 | self.embed_size = 128 # embed_size 20 | self.hidden_size = 256 # hidden_size 21 | self.pred_len = configs.pred_len 22 | self.feature_size = configs.enc_in # channels 23 | self.seq_len = configs.seq_len 24 | self.channel_independence = configs.channel_independence 25 | self.sparsity_threshold = 0.01 26 | self.scale = 0.02 27 | self.embeddings = nn.Parameter(torch.randn(1, self.embed_size)) 28 | self.r1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) 29 | self.i1 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) 30 | self.rb1 = nn.Parameter(self.scale * torch.randn(self.embed_size)) 31 | self.ib1 = nn.Parameter(self.scale * torch.randn(self.embed_size)) 32 | self.r2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) 33 | self.i2 = nn.Parameter(self.scale * torch.randn(self.embed_size, self.embed_size)) 34 | self.rb2 = nn.Parameter(self.scale * torch.randn(self.embed_size)) 35 | self.ib2 = nn.Parameter(self.scale * torch.randn(self.embed_size)) 36 | 37 | self.fc = nn.Sequential( 38 | nn.Linear(self.seq_len * self.embed_size, self.hidden_size), 39 | 
nn.LeakyReLU(), 40 | nn.Linear(self.hidden_size, self.pred_len) 41 | ) 42 | 43 | # dimension extension 44 | def tokenEmb(self, x): 45 | # x: [Batch, Input length, Channel] 46 | x = x.permute(0, 2, 1) 47 | x = x.unsqueeze(3) 48 | # N*T*1 x 1*D = N*T*D 49 | y = self.embeddings 50 | return x * y 51 | 52 | # frequency temporal learner 53 | def MLP_temporal(self, x, B, N, L): 54 | # [B, N, T, D] 55 | x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on L dimension 56 | y = self.FreMLP(B, N, L, x, self.r2, self.i2, self.rb2, self.ib2) 57 | x = torch.fft.irfft(y, n=self.seq_len, dim=2, norm="ortho") 58 | return x 59 | 60 | # frequency channel learner 61 | def MLP_channel(self, x, B, N, L): 62 | # [B, N, T, D] 63 | x = x.permute(0, 2, 1, 3) 64 | # [B, T, N, D] 65 | x = torch.fft.rfft(x, dim=2, norm='ortho') # FFT on N dimension 66 | y = self.FreMLP(B, L, N, x, self.r1, self.i1, self.rb1, self.ib1) 67 | x = torch.fft.irfft(y, n=self.feature_size, dim=2, norm="ortho") 68 | x = x.permute(0, 2, 1, 3) 69 | # [B, N, T, D] 70 | return x 71 | 72 | # frequency-domain MLPs 73 | # dimension: FFT along the dimension, r: the real part of weights, i: the imaginary part of weights 74 | # rb: the real part of bias, ib: the imaginary part of bias 75 | def FreMLP(self, B, nd, dimension, x, r, i, rb, ib): 76 | o1_real = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size], 77 | device=x.device) 78 | o1_imag = torch.zeros([B, nd, dimension // 2 + 1, self.embed_size], 79 | device=x.device) 80 | 81 | o1_real = F.relu( 82 | torch.einsum('bijd,dd->bijd', x.real, r) - \ 83 | torch.einsum('bijd,dd->bijd', x.imag, i) + \ 84 | rb 85 | ) 86 | 87 | o1_imag = F.relu( 88 | torch.einsum('bijd,dd->bijd', x.imag, r) + \ 89 | torch.einsum('bijd,dd->bijd', x.real, i) + \ 90 | ib 91 | ) 92 | 93 | y = torch.stack([o1_real, o1_imag], dim=-1) 94 | y = F.softshrink(y, lambd=self.sparsity_threshold) 95 | y = torch.view_as_complex(y) 96 | return y 97 | 98 | def forecast(self, x_enc): 99 | # x: [Batch, Input 
length, Channel] 100 | B, T, N = x_enc.shape 101 | # embedding x: [B, N, T, D] 102 | x = self.tokenEmb(x_enc) 103 | bias = x 104 | # [B, N, T, D] 105 | if self.channel_independence == '0':  # NOTE(review): compared against the string '0'; an int 0 from the config would skip channel mixing — confirm the config type 106 | x = self.MLP_channel(x, B, N, T) 107 | # [B, N, T, D] 108 | x = self.MLP_temporal(x, B, N, T) 109 | x = x + bias 110 | x = self.fc(x.reshape(B, N, -1)).permute(0, 2, 1) 111 | return x 112 | 113 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 114 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 115 | dec_out = self.forecast(x_enc) 116 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 117 | else: 118 | raise ValueError('Only forecast tasks implemented yet') 119 | -------------------------------------------------------------------------------- /models/Informer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 5 | from layers.SelfAttention_Family import ProbAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding 7 | 8 | 9 | class Model(nn.Module): 10 | """ 11 | Informer with ProbSparse attention in O(LlogL) complexity 12 | Paper link: https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132 13 | """ 14 | 15 | def __init__(self, configs): 16 | super(Model, self).__init__() 17 | self.task_name = configs.task_name 18 | self.pred_len = configs.pred_len 19 | self.label_len = configs.label_len 20 | 21 | # Embedding 22 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 23 | configs.dropout) 24 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 25 | configs.dropout) 26 | 27 | # Encoder 28 | self.encoder = Encoder( 29 | [ 30 | EncoderLayer( 31 | AttentionLayer( 32 | ProbAttention(False, configs.factor,
attention_dropout=configs.dropout, 33 | output_attention=False), 34 | configs.d_model, configs.n_heads), 35 | configs.d_model, 36 | configs.d_ff, 37 | dropout=configs.dropout, 38 | activation=configs.activation 39 | ) for l in range(configs.e_layers) 40 | ], 41 | [ 42 | ConvLayer( 43 | configs.d_model 44 | ) for l in range(configs.e_layers - 1) 45 | ] if configs.distil and ('forecast' in configs.task_name) else None,  # e_layers - 1 conv layers between attention layers, only when distil is set and the task is a forecast 46 | norm_layer=torch.nn.LayerNorm(configs.d_model) 47 | ) 48 | # Decoder 49 | self.decoder = Decoder(  # built for every task, but only the forecast paths call it 50 | [ 51 | DecoderLayer( 52 | AttentionLayer( 53 | ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 54 | configs.d_model, configs.n_heads), 55 | AttentionLayer( 56 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 57 | configs.d_model, configs.n_heads), 58 | configs.d_model, 59 | configs.d_ff, 60 | dropout=configs.dropout, 61 | activation=configs.activation, 62 | ) 63 | for l in range(configs.d_layers) 64 | ], 65 | norm_layer=torch.nn.LayerNorm(configs.d_model), 66 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 67 | ) 68 | if self.task_name == 'imputation': 69 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 70 | if self.task_name == 'anomaly_detection': 71 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 72 | if self.task_name == 'classification': 73 | self.act = F.gelu 74 | self.dropout = nn.Dropout(configs.dropout) 75 | self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class) 76 | 77 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 78 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 79 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 80 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 81 | 82 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None) 83 | 84 | return dec_out # [B, L, D] 85 | 86 | def short_forecast(self,
x_enc, x_mark_enc, x_dec, x_mark_dec): 87 | # Normalization 88 | mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E 89 | x_enc = x_enc - mean_enc 90 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E 91 | x_enc = x_enc / std_enc 92 | 93 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 94 | dec_out = self.dec_embedding(x_dec, x_mark_dec)  # NOTE(review): x_dec is not normalized with mean_enc/std_enc like x_enc above — confirm intended 95 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 96 | 97 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None) 98 | 99 | dec_out = dec_out * std_enc + mean_enc  # undo the normalization using the encoder-input statistics 100 | return dec_out # [B, L, D] 101 | 102 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 103 | # enc 104 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 105 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 106 | # final 107 | dec_out = self.projection(enc_out) 108 | return dec_out 109 | 110 | def anomaly_detection(self, x_enc): 111 | # enc 112 | enc_out = self.enc_embedding(x_enc, None) 113 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 114 | # final 115 | dec_out = self.projection(enc_out) 116 | return dec_out 117 | 118 | def classification(self, x_enc, x_mark_enc): 119 | # enc 120 | enc_out = self.enc_embedding(x_enc, None) 121 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 122 | 123 | # Output 124 | output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity 125 | output = self.dropout(output) 126 | output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings 127 | output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model) 128 | output = self.projection(output) # (batch_size, num_classes) 129 | return output 130 | 131 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 132 | if self.task_name == 'long_term_forecast': 133 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 134 | return dec_out[:, -self.pred_len:,
:] # [B, L, D] 135 | if self.task_name == 'short_term_forecast': 136 | dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 137 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 138 | if self.task_name == 'imputation': 139 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 140 | return dec_out # [B, L, D] 141 | if self.task_name == 'anomaly_detection': 142 | dec_out = self.anomaly_detection(x_enc) 143 | return dec_out # [B, L, D] 144 | if self.task_name == 'classification': 145 | dec_out = self.classification(x_enc, x_mark_enc) 146 | return dec_out # [B, N] 147 | return None 148 | -------------------------------------------------------------------------------- /models/LightTS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class IEBlock(nn.Module): 7 | def __init__(self, input_dim, hid_dim, output_dim, num_node): 8 | super(IEBlock, self).__init__() 9 | 10 | self.input_dim = input_dim 11 | self.hid_dim = hid_dim 12 | self.output_dim = output_dim 13 | self.num_node = num_node 14 | 15 | self._build() 16 | 17 | def _build(self): 18 | self.spatial_proj = nn.Sequential( 19 | nn.Linear(self.input_dim, self.hid_dim), 20 | nn.LeakyReLU(), 21 | nn.Linear(self.hid_dim, self.hid_dim // 4) 22 | ) 23 | 24 | self.channel_proj = nn.Linear(self.num_node, self.num_node) 25 | torch.nn.init.eye_(self.channel_proj.weight) 26 | 27 | self.output_proj = nn.Linear(self.hid_dim // 4, self.output_dim) 28 | 29 | def forward(self, x): 30 | x = self.spatial_proj(x.permute(0, 2, 1)) 31 | x = x.permute(0, 2, 1) + self.channel_proj(x.permute(0, 2, 1)) 32 | x = self.output_proj(x.permute(0, 2, 1)) 33 | 34 | x = x.permute(0, 2, 1) 35 | 36 | return x 37 | 38 | 39 | class Model(nn.Module): 40 | """ 41 | Paper link: https://arxiv.org/abs/2207.01186 42 | """ 43 | 44 | def __init__(self, configs, chunk_size=24): 45 | """ 46 | chunk_size: 
int, reshape T into [num_chunks, chunk_size] 47 | """ 48 | super(Model, self).__init__() 49 | self.task_name = configs.task_name 50 | self.seq_len = configs.seq_len 51 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 52 | self.pred_len = configs.seq_len 53 | else: 54 | self.pred_len = configs.pred_len 55 | 56 | if configs.task_name == 'long_term_forecast' or configs.task_name == 'short_term_forecast': 57 | self.chunk_size = min(configs.pred_len, configs.seq_len, chunk_size) 58 | else: 59 | self.chunk_size = min(configs.seq_len, chunk_size) 60 | # assert (self.seq_len % self.chunk_size == 0) 61 | if self.seq_len % self.chunk_size != 0: 62 | self.seq_len += (self.chunk_size - self.seq_len % self.chunk_size) # padding in order to ensure complete division 63 | self.num_chunks = self.seq_len // self.chunk_size 64 | 65 | self.d_model = configs.d_model 66 | self.enc_in = configs.enc_in 67 | self.dropout = configs.dropout 68 | if self.task_name == 'classification': 69 | self.act = F.gelu 70 | self.dropout = nn.Dropout(configs.dropout) 71 | self.projection = nn.Linear(configs.enc_in * configs.seq_len, configs.num_class) 72 | self._build() 73 | 74 | def _build(self): 75 | self.layer_1 = IEBlock( 76 | input_dim=self.chunk_size, 77 | hid_dim=self.d_model // 4, 78 | output_dim=self.d_model // 4, 79 | num_node=self.num_chunks 80 | ) 81 | 82 | self.chunk_proj_1 = nn.Linear(self.num_chunks, 1) 83 | 84 | self.layer_2 = IEBlock( 85 | input_dim=self.chunk_size, 86 | hid_dim=self.d_model // 4, 87 | output_dim=self.d_model // 4, 88 | num_node=self.num_chunks 89 | ) 90 | 91 | self.chunk_proj_2 = nn.Linear(self.num_chunks, 1) 92 | 93 | self.layer_3 = IEBlock( 94 | input_dim=self.d_model // 2, 95 | hid_dim=self.d_model // 2, 96 | output_dim=self.pred_len, 97 | num_node=self.enc_in 98 | ) 99 | 100 | self.ar = nn.Linear(self.seq_len, self.pred_len) 101 | 102 | def encoder(self, x): 103 | B, T, N = x.size() 104 | 105 | 
highway = self.ar(x.permute(0, 2, 1)) 106 | highway = highway.permute(0, 2, 1) 107 | 108 | # continuous sampling 109 | x1 = x.reshape(B, self.num_chunks, self.chunk_size, N) 110 | x1 = x1.permute(0, 3, 2, 1) 111 | x1 = x1.reshape(-1, self.chunk_size, self.num_chunks) 112 | x1 = self.layer_1(x1) 113 | x1 = self.chunk_proj_1(x1).squeeze(dim=-1) 114 | 115 | # interval sampling 116 | x2 = x.reshape(B, self.chunk_size, self.num_chunks, N) 117 | x2 = x2.permute(0, 3, 1, 2) 118 | x2 = x2.reshape(-1, self.chunk_size, self.num_chunks) 119 | x2 = self.layer_2(x2) 120 | x2 = self.chunk_proj_2(x2).squeeze(dim=-1) 121 | 122 | x3 = torch.cat([x1, x2], dim=-1) 123 | 124 | x3 = x3.reshape(B, N, -1) 125 | x3 = x3.permute(0, 2, 1) 126 | 127 | out = self.layer_3(x3) 128 | 129 | out = out + highway 130 | return out 131 | 132 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 133 | return self.encoder(x_enc) 134 | 135 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 136 | return self.encoder(x_enc) 137 | 138 | def anomaly_detection(self, x_enc): 139 | return self.encoder(x_enc) 140 | 141 | def classification(self, x_enc, x_mark_enc): 142 | # padding 143 | x_enc = torch.cat([x_enc, torch.zeros((x_enc.shape[0], self.seq_len-x_enc.shape[1], x_enc.shape[2])).to(x_enc.device)], dim=1) 144 | 145 | enc_out = self.encoder(x_enc) 146 | 147 | # Output 148 | output = enc_out.reshape(enc_out.shape[0], -1) # (batch_size, seq_length * d_model) 149 | output = self.projection(output) # (batch_size, num_classes) 150 | return output 151 | 152 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 153 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 154 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 155 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 156 | if self.task_name == 'imputation': 157 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 158 | return dec_out # [B, L, D] 159 | 
if self.task_name == 'anomaly_detection': 160 | dec_out = self.anomaly_detection(x_enc) 161 | return dec_out # [B, L, D] 162 | if self.task_name == 'classification': 163 | dec_out = self.classification(x_enc, x_mark_enc) 164 | return dec_out # [B, N] 165 | return None 166 | -------------------------------------------------------------------------------- /models/Mamba.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from mamba_ssm import Mamba 8 | 9 | from layers.Embed import DataEmbedding 10 | 11 | class Model(nn.Module): 12 | 13 | def __init__(self, configs): 14 | super(Model, self).__init__() 15 | self.task_name = configs.task_name 16 | self.pred_len = configs.pred_len 17 | 18 | self.d_inner = configs.d_model * configs.expand 19 | self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto" 20 | 21 | self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout) 22 | 23 | self.mamba = Mamba( 24 | d_model = configs.d_model, 25 | d_state = configs.d_ff, 26 | d_conv = configs.d_conv, 27 | expand = configs.expand, 28 | ) 29 | 30 | self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False) 31 | 32 | def forecast(self, x_enc, x_mark_enc): 33 | mean_enc = x_enc.mean(1, keepdim=True).detach() 34 | x_enc = x_enc - mean_enc 35 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() 36 | x_enc = x_enc / std_enc 37 | 38 | x = self.embedding(x_enc, x_mark_enc) 39 | x = self.mamba(x) 40 | x_out = self.out_layer(x) 41 | 42 | x_out = x_out * std_enc + mean_enc 43 | return x_out 44 | 45 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 46 | if self.task_name in ['short_term_forecast', 'long_term_forecast']: 47 | x_out = self.forecast(x_enc, x_mark_enc) 48 | return x_out[:, -self.pred_len:, :] 49 | 50 | # other 
tasks not implemented -------------------------------------------------------------------------------- /models/MambaSimple.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat, einsum 7 | 8 | from layers.Embed import DataEmbedding 9 | 10 | 11 | class Model(nn.Module): 12 | """ 13 | Mamba, linear-time sequence modeling with selective state spaces O(L) 14 | Paper link: https://arxiv.org/abs/2312.00752 15 | Implementation refernce: https://github.com/johnma2006/mamba-minimal/ 16 | """ 17 | 18 | def __init__(self, configs): 19 | super(Model, self).__init__() 20 | self.task_name = configs.task_name 21 | self.pred_len = configs.pred_len 22 | 23 | self.d_inner = configs.d_model * configs.expand 24 | self.dt_rank = math.ceil(configs.d_model / 16) 25 | 26 | self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout) 27 | 28 | self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)]) 29 | self.norm = RMSNorm(configs.d_model) 30 | 31 | self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False) 32 | 33 | def forecast(self, x_enc, x_mark_enc): 34 | mean_enc = x_enc.mean(1, keepdim=True).detach() 35 | x_enc = x_enc - mean_enc 36 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() 37 | x_enc = x_enc / std_enc 38 | 39 | x = self.embedding(x_enc, x_mark_enc) 40 | for layer in self.layers: 41 | x = layer(x) 42 | 43 | x = self.norm(x) 44 | x_out = self.out_layer(x) 45 | 46 | x_out = x_out * std_enc + mean_enc 47 | return x_out 48 | 49 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 50 | if self.task_name in ['short_term_forecast', 'long_term_forecast']: 51 | x_out = self.forecast(x_enc, x_mark_enc) 52 | return x_out[:, -self.pred_len:, 
:] 53 | 54 | 55 | class ResidualBlock(nn.Module): 56 | def __init__(self, configs, d_inner, dt_rank): 57 | super(ResidualBlock, self).__init__() 58 | 59 | self.mixer = MambaBlock(configs, d_inner, dt_rank) 60 | self.norm = RMSNorm(configs.d_model) 61 | 62 | def forward(self, x): 63 | output = self.mixer(self.norm(x)) + x 64 | return output 65 | 66 | class MambaBlock(nn.Module): 67 | def __init__(self, configs, d_inner, dt_rank): 68 | super(MambaBlock, self).__init__() 69 | self.d_inner = d_inner 70 | self.dt_rank = dt_rank 71 | 72 | self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False) 73 | 74 | self.conv1d = nn.Conv1d( 75 | in_channels = self.d_inner, 76 | out_channels = self.d_inner, 77 | bias = True, 78 | kernel_size = configs.d_conv, 79 | padding = configs.d_conv - 1, 80 | groups = self.d_inner 81 | ) 82 | 83 | # takes in x and outputs the input-specific delta, B, C 84 | self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False) 85 | 86 | # projects delta 87 | self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) 88 | 89 | A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner) 90 | self.A_log = nn.Parameter(torch.log(A)) 91 | self.D = nn.Parameter(torch.ones(self.d_inner)) 92 | 93 | self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False) 94 | 95 | def forward(self, x): 96 | """ 97 | Figure 3 in Section 3.4 in the paper 98 | """ 99 | (b, l, d) = x.shape 100 | 101 | x_and_res = self.in_proj(x) # [B, L, 2 * d_inner] 102 | (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1) 103 | 104 | x = rearrange(x, "b l d -> b d l") 105 | x = self.conv1d(x)[:, :, :l] 106 | x = rearrange(x, "b d l -> b l d") 107 | 108 | x = F.silu(x) 109 | 110 | y = self.ssm(x) 111 | y = y * F.silu(res) 112 | 113 | output = self.out_proj(y) 114 | return output 115 | 116 | 117 | def ssm(self, x): 118 | """ 119 | Algorithm 2 in Section 3.2 in the paper 120 | """ 121 | 122 | (d_in, n) = 
self.A_log.shape 123 | 124 | A = -torch.exp(self.A_log.float()) # [d_in, n] 125 | D = self.D.float() # [d_in] 126 | 127 | x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff] 128 | (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n] 129 | delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in] 130 | y = self.selective_scan(x, delta, A, B, C, D) 131 | 132 | return y 133 | 134 | def selective_scan(self, u, delta, A, B, C, D): 135 | (b, l, d_in) = u.shape 136 | n = A.shape[1] 137 | 138 | deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization 139 | deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B" 140 | 141 | # selective scan, sequential instead of parallel 142 | x = torch.zeros((b, d_in, n), device=deltaA.device) 143 | ys = [] 144 | for i in range(l): 145 | x = deltaA[:, i] * x + deltaB_u[:, i] 146 | y = einsum(x, C[:, i, :], "b d n, b n -> b d") 147 | ys.append(y) 148 | 149 | y = torch.stack(ys, dim=1) # [B, L, d_in] 150 | y = y + u * D 151 | 152 | return y 153 | 154 | class RMSNorm(nn.Module): 155 | def __init__(self, d_model, eps=1e-5): 156 | super(RMSNorm, self).__init__() 157 | self.eps = eps 158 | self.weight = nn.Parameter(torch.ones(d_model)) 159 | 160 | def forward(self, x): 161 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight 162 | return output 163 | -------------------------------------------------------------------------------- /models/Pyraformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Pyraformer_EncDec import Encoder 4 | 5 | 6 | class Model(nn.Module): 7 | """ 8 | Pyraformer: 
Pyramidal attention to reduce complexity 9 | Paper link: https://openreview.net/pdf?id=0EXmFzUn5I 10 | """ 11 | 12 | def __init__(self, configs, window_size=[4,4], inner_size=5): 13 | """ 14 | window_size: list, the downsample window size in pyramidal attention. 15 | inner_size: int, the size of neighbour attention 16 | """ 17 | super().__init__() 18 | self.task_name = configs.task_name 19 | self.pred_len = configs.pred_len 20 | self.d_model = configs.d_model 21 | 22 | if self.task_name == 'short_term_forecast': 23 | window_size = [2,2] 24 | self.encoder = Encoder(configs, window_size, inner_size) 25 | 26 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 27 | self.projection = nn.Linear( 28 | (len(window_size)+1)*self.d_model, self.pred_len * configs.enc_in) 29 | elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection': 30 | self.projection = nn.Linear( 31 | (len(window_size)+1)*self.d_model, configs.enc_in, bias=True) 32 | elif self.task_name == 'classification': 33 | self.act = torch.nn.functional.gelu 34 | self.dropout = nn.Dropout(configs.dropout) 35 | self.projection = nn.Linear( 36 | (len(window_size)+1)*self.d_model * configs.seq_len, configs.num_class) 37 | 38 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 39 | enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :] 40 | dec_out = self.projection(enc_out).view( 41 | enc_out.size(0), self.pred_len, -1) 42 | return dec_out 43 | 44 | def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 45 | # Normalization 46 | mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E 47 | x_enc = x_enc - mean_enc 48 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E 49 | x_enc = x_enc / std_enc 50 | 51 | enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :] 52 | dec_out = self.projection(enc_out).view( 53 | enc_out.size(0), self.pred_len, -1) 54 | 55 | dec_out = 
dec_out * std_enc + mean_enc 56 | return dec_out 57 | 58 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 59 | enc_out = self.encoder(x_enc, x_mark_enc) 60 | dec_out = self.projection(enc_out) 61 | return dec_out 62 | 63 | def anomaly_detection(self, x_enc, x_mark_enc): 64 | enc_out = self.encoder(x_enc, x_mark_enc) 65 | dec_out = self.projection(enc_out) 66 | return dec_out 67 | 68 | def classification(self, x_enc, x_mark_enc): 69 | # enc 70 | enc_out = self.encoder(x_enc, x_mark_enc=None) 71 | 72 | # Output 73 | # the output transformer encoder/decoder embeddings don't include non-linearity 74 | output = self.act(enc_out) 75 | output = self.dropout(output) 76 | # zero-out padding embeddings 77 | output = output * x_mark_enc.unsqueeze(-1) 78 | # (batch_size, seq_length * d_model) 79 | output = output.reshape(output.shape[0], -1) 80 | output = self.projection(output) # (batch_size, num_classes) 81 | 82 | return output 83 | 84 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 85 | if self.task_name == 'long_term_forecast': 86 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 87 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 88 | if self.task_name == 'short_term_forecast': 89 | dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 90 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 91 | if self.task_name == 'imputation': 92 | dec_out = self.imputation( 93 | x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 94 | return dec_out # [B, L, D] 95 | if self.task_name == 'anomaly_detection': 96 | dec_out = self.anomaly_detection(x_enc, x_mark_enc) 97 | return dec_out # [B, L, D] 98 | if self.task_name == 'classification': 99 | dec_out = self.classification(x_enc, x_mark_enc) 100 | return dec_out # [B, N] 101 | return None 102 | -------------------------------------------------------------------------------- /models/Reformer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import ReformerLayer 6 | from layers.Embed import DataEmbedding 7 | 8 | 9 | class Model(nn.Module): 10 | """ 11 | Reformer with O(LlogL) complexity 12 | Paper link: https://openreview.net/forum?id=rkgNKkHtvB 13 | """ 14 | 15 | def __init__(self, configs, bucket_size=4, n_hashes=4): 16 | """ 17 | bucket_size: int, 18 | n_hashes: int, 19 | """ 20 | super(Model, self).__init__() 21 | self.task_name = configs.task_name 22 | self.pred_len = configs.pred_len 23 | self.seq_len = configs.seq_len 24 | 25 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 26 | configs.dropout) 27 | # Encoder 28 | self.encoder = Encoder( 29 | [ 30 | EncoderLayer( 31 | ReformerLayer(None, configs.d_model, configs.n_heads, 32 | bucket_size=bucket_size, n_hashes=n_hashes), 33 | configs.d_model, 34 | configs.d_ff, 35 | dropout=configs.dropout, 36 | activation=configs.activation 37 | ) for l in range(configs.e_layers) 38 | ], 39 | norm_layer=torch.nn.LayerNorm(configs.d_model) 40 | ) 41 | 42 | if self.task_name == 'classification': 43 | self.act = F.gelu 44 | self.dropout = nn.Dropout(configs.dropout) 45 | self.projection = nn.Linear( 46 | configs.d_model * configs.seq_len, configs.num_class) 47 | else: 48 | self.projection = nn.Linear( 49 | configs.d_model, configs.c_out, bias=True) 50 | 51 | def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 52 | # add placeholder 53 | x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1) 54 | if x_mark_enc is not None: 55 | x_mark_enc = torch.cat( 56 | [x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1) 57 | 58 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C] 59 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 60 
| dec_out = self.projection(enc_out) 61 | 62 | return dec_out # [B, L, D] 63 | 64 | def short_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 65 | # Normalization 66 | mean_enc = x_enc.mean(1, keepdim=True).detach() # B x 1 x E 67 | x_enc = x_enc - mean_enc 68 | std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() # B x 1 x E 69 | x_enc = x_enc / std_enc 70 | 71 | # add placeholder 72 | x_enc = torch.cat([x_enc, x_dec[:, -self.pred_len:, :]], dim=1) 73 | if x_mark_enc is not None: 74 | x_mark_enc = torch.cat( 75 | [x_mark_enc, x_mark_dec[:, -self.pred_len:, :]], dim=1) 76 | 77 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C] 78 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 79 | dec_out = self.projection(enc_out) 80 | 81 | dec_out = dec_out * std_enc + mean_enc 82 | return dec_out # [B, L, D] 83 | 84 | def imputation(self, x_enc, x_mark_enc): 85 | enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C] 86 | 87 | enc_out, attns = self.encoder(enc_out) 88 | enc_out = self.projection(enc_out) 89 | 90 | return enc_out # [B, L, D] 91 | 92 | def anomaly_detection(self, x_enc): 93 | enc_out = self.enc_embedding(x_enc, None) # [B,T,C] 94 | 95 | enc_out, attns = self.encoder(enc_out) 96 | enc_out = self.projection(enc_out) 97 | 98 | return enc_out # [B, L, D] 99 | 100 | def classification(self, x_enc, x_mark_enc): 101 | # enc 102 | enc_out = self.enc_embedding(x_enc, None) 103 | enc_out, attns = self.encoder(enc_out) 104 | 105 | # Output 106 | # the output transformer encoder/decoder embeddings don't include non-linearity 107 | output = self.act(enc_out) 108 | output = self.dropout(output) 109 | # zero-out padding embeddings 110 | output = output * x_mark_enc.unsqueeze(-1) 111 | # (batch_size, seq_length * d_model) 112 | output = output.reshape(output.shape[0], -1) 113 | output = self.projection(output) # (batch_size, num_classes) 114 | return output 115 | 116 | def forward(self, x_enc, x_mark_enc, x_dec, 
x_mark_dec, mask=None): 117 | if self.task_name == 'long_term_forecast': 118 | dec_out = self.long_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 119 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 120 | if self.task_name == 'short_term_forecast': 121 | dec_out = self.short_forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 122 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 123 | if self.task_name == 'imputation': 124 | dec_out = self.imputation(x_enc, x_mark_enc) 125 | return dec_out # [B, L, D] 126 | if self.task_name == 'anomaly_detection': 127 | dec_out = self.anomaly_detection(x_enc) 128 | return dec_out # [B, L, D] 129 | if self.task_name == 'classification': 130 | dec_out = self.classification(x_enc, x_mark_enc) 131 | return dec_out # [B, N] 132 | return None 133 | -------------------------------------------------------------------------------- /models/SCINet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class Splitting(nn.Module): 7 | def __init__(self): 8 | super(Splitting, self).__init__() 9 | 10 | def even(self, x): 11 | return x[:, ::2, :] 12 | 13 | def odd(self, x): 14 | return x[:, 1::2, :] 15 | 16 | def forward(self, x): 17 | # return the odd and even part 18 | return self.even(x), self.odd(x) 19 | 20 | 21 | class CausalConvBlock(nn.Module): 22 | def __init__(self, d_model, kernel_size=5, dropout=0.0): 23 | super(CausalConvBlock, self).__init__() 24 | module_list = [ 25 | nn.ReplicationPad1d((kernel_size - 1, kernel_size - 1)), 26 | 27 | nn.Conv1d(d_model, d_model, 28 | kernel_size=kernel_size), 29 | nn.LeakyReLU(negative_slope=0.01, inplace=True), 30 | 31 | nn.Dropout(dropout), 32 | nn.Conv1d(d_model, d_model, 33 | kernel_size=kernel_size), 34 | nn.Tanh() 35 | ] 36 | self.causal_conv = nn.Sequential(*module_list) 37 | 38 | def forward(self, x): 39 | return self.causal_conv(x) # return value is the same as 
input dimension 40 | 41 | 42 | class SCIBlock(nn.Module): 43 | def __init__(self, d_model, kernel_size=5, dropout=0.0): 44 | super(SCIBlock, self).__init__() 45 | self.splitting = Splitting() 46 | self.modules_even, self.modules_odd, self.interactor_even, self.interactor_odd = [CausalConvBlock(d_model) for _ in range(4)] 47 | 48 | def forward(self, x): 49 | x_even, x_odd = self.splitting(x) 50 | x_even = x_even.permute(0, 2, 1) 51 | x_odd = x_odd.permute(0, 2, 1) 52 | 53 | x_even_temp = x_even.mul(torch.exp(self.modules_even(x_odd))) 54 | x_odd_temp = x_odd.mul(torch.exp(self.modules_odd(x_even))) 55 | 56 | x_even_update = x_even_temp + self.interactor_even(x_odd_temp) 57 | x_odd_update = x_odd_temp - self.interactor_odd(x_even_temp) 58 | 59 | return x_even_update.permute(0, 2, 1), x_odd_update.permute(0, 2, 1) 60 | 61 | 62 | class SCINet(nn.Module): 63 | def __init__(self, d_model, current_level=3, kernel_size=5, dropout=0.0): 64 | super(SCINet, self).__init__() 65 | self.current_level = current_level 66 | self.working_block = SCIBlock(d_model, kernel_size, dropout) 67 | 68 | if current_level != 0: 69 | self.SCINet_Tree_odd = SCINet(d_model, current_level-1, kernel_size, dropout) 70 | self.SCINet_Tree_even = SCINet(d_model, current_level-1, kernel_size, dropout) 71 | 72 | def forward(self, x): 73 | odd_flag = False 74 | if x.shape[1] % 2 == 1: 75 | odd_flag = True 76 | x = torch.cat((x, x[:, -1:, :]), dim=1) 77 | x_even_update, x_odd_update = self.working_block(x) 78 | if odd_flag: 79 | x_odd_update = x_odd_update[:, :-1] 80 | 81 | if self.current_level == 0: 82 | return self.zip_up_the_pants(x_even_update, x_odd_update) 83 | else: 84 | return self.zip_up_the_pants(self.SCINet_Tree_even(x_even_update), self.SCINet_Tree_odd(x_odd_update)) 85 | 86 | def zip_up_the_pants(self, even, odd): 87 | even = even.permute(1, 0, 2) 88 | odd = odd.permute(1, 0, 2) 89 | even_len = even.shape[0] 90 | odd_len = odd.shape[0] 91 | min_len = min(even_len, odd_len) 92 | 93 | 
zipped_data = [] 94 | for i in range(min_len): 95 | zipped_data.append(even[i].unsqueeze(0)) 96 | zipped_data.append(odd[i].unsqueeze(0)) 97 | if even_len > odd_len: 98 | zipped_data.append(even[-1].unsqueeze(0)) 99 | return torch.cat(zipped_data,0).permute(1, 0, 2) 100 | 101 | 102 | class Model(nn.Module): 103 | def __init__(self, configs): 104 | super(Model, self).__init__() 105 | self.task_name = configs.task_name 106 | self.seq_len = configs.seq_len 107 | self.label_len = configs.label_len 108 | self.pred_len = configs.pred_len 109 | 110 | # You can set the number of SCINet stacks by argument "d_layers", but should choose 1 or 2. 111 | self.num_stacks = configs.d_layers 112 | if self.num_stacks == 1: 113 | self.sci_net_1 = SCINet(configs.enc_in, dropout=configs.dropout) 114 | self.projection_1 = nn.Conv1d(self.seq_len, self.seq_len + self.pred_len, kernel_size=1, stride=1, bias=False) 115 | else: 116 | self.sci_net_1, self.sci_net_2 = [SCINet(configs.enc_in, dropout=configs.dropout) for _ in range(2)] 117 | self.projection_1 = nn.Conv1d(self.seq_len, self.pred_len, kernel_size=1, stride=1, bias=False) 118 | self.projection_2 = nn.Conv1d(self.seq_len+self.pred_len, self.seq_len+self.pred_len, 119 | kernel_size = 1, bias = False) 120 | 121 | # For positional encoding 122 | self.pe_hidden_size = configs.enc_in 123 | if self.pe_hidden_size % 2 == 1: 124 | self.pe_hidden_size += 1 125 | 126 | num_timescales = self.pe_hidden_size // 2 127 | max_timescale = 10000.0 128 | min_timescale = 1.0 129 | 130 | log_timescale_increment = ( 131 | math.log(float(max_timescale) / float(min_timescale)) / 132 | max(num_timescales - 1, 1)) 133 | inv_timescales = min_timescale * torch.exp( 134 | torch.arange(num_timescales, dtype=torch.float32) * 135 | -log_timescale_increment) 136 | self.register_buffer('inv_timescales', inv_timescales) 137 | 138 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 139 | if self.task_name == 'long_term_forecast' or self.task_name == 
'short_term_forecast': 140 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) # [B,pred_len,C] 141 | dec_out = torch.cat([torch.zeros_like(x_enc), dec_out], dim=1) 142 | return dec_out # [B, T, D] 143 | return None 144 | 145 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 146 | # Normalization from Non-stationary Transformer 147 | means = x_enc.mean(1, keepdim=True).detach() 148 | x_enc = x_enc - means 149 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 150 | x_enc /= stdev 151 | 152 | # position-encoding 153 | pe = self.get_position_encoding(x_enc) 154 | if pe.shape[2] > x_enc.shape[2]: 155 | x_enc += pe[:, :, :-1] 156 | else: 157 | x_enc += self.get_position_encoding(x_enc) 158 | 159 | # SCINet 160 | dec_out = self.sci_net_1(x_enc) 161 | dec_out += x_enc 162 | dec_out = self.projection_1(dec_out) 163 | if self.num_stacks != 1: 164 | dec_out = torch.cat((x_enc, dec_out), dim=1) 165 | temp = dec_out 166 | dec_out = self.sci_net_2(dec_out) 167 | dec_out += temp 168 | dec_out = self.projection_2(dec_out) 169 | 170 | # De-Normalization from Non-stationary Transformer 171 | dec_out = dec_out * \ 172 | (stdev[:, 0, :].unsqueeze(1).repeat( 173 | 1, self.pred_len + self.seq_len, 1)) 174 | dec_out = dec_out + \ 175 | (means[:, 0, :].unsqueeze(1).repeat( 176 | 1, self.pred_len + self.seq_len, 1)) 177 | return dec_out 178 | 179 | def get_position_encoding(self, x): 180 | max_length = x.size()[1] 181 | position = torch.arange(max_length, dtype=torch.float32, 182 | device=x.device) # tensor([0., 1., 2., 3., 4.], device='cuda:0') 183 | scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0) # 5 256 184 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) # [T, C] 185 | signal = F.pad(signal, (0, 0, 0, self.pe_hidden_size % 2)) 186 | signal = signal.view(1, max_length, self.pe_hidden_size) 187 | 188 | return signal 
-------------------------------------------------------------------------------- /models/SegRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Autoformer_EncDec import series_decomp 5 | 6 | 7 | class Model(nn.Module): 8 | """ 9 | Paper link: https://arxiv.org/abs/2308.11200.pdf 10 | """ 11 | 12 | def __init__(self, configs): 13 | super(Model, self).__init__() 14 | 15 | # get parameters 16 | self.seq_len = configs.seq_len 17 | self.enc_in = configs.enc_in 18 | self.d_model = configs.d_model 19 | self.dropout = configs.dropout 20 | 21 | self.task_name = configs.task_name 22 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 23 | self.pred_len = configs.seq_len 24 | else: 25 | self.pred_len = configs.pred_len 26 | 27 | self.seg_len = configs.seg_len 28 | self.seg_num_x = self.seq_len // self.seg_len 29 | self.seg_num_y = self.pred_len // self.seg_len 30 | 31 | # building model 32 | self.valueEmbedding = nn.Sequential( 33 | nn.Linear(self.seg_len, self.d_model), 34 | nn.ReLU() 35 | ) 36 | self.rnn = nn.GRU(input_size=self.d_model, hidden_size=self.d_model, num_layers=1, bias=True, 37 | batch_first=True, bidirectional=False) 38 | self.pos_emb = nn.Parameter(torch.randn(self.seg_num_y, self.d_model // 2)) 39 | self.channel_emb = nn.Parameter(torch.randn(self.enc_in, self.d_model // 2)) 40 | 41 | self.predict = nn.Sequential( 42 | nn.Dropout(self.dropout), 43 | nn.Linear(self.d_model, self.seg_len) 44 | ) 45 | 46 | if self.task_name == 'classification': 47 | self.act = F.gelu 48 | self.dropout = nn.Dropout(configs.dropout) 49 | self.projection = nn.Linear( 50 | configs.enc_in * configs.seq_len, configs.num_class) 51 | 52 | def encoder(self, x): 53 | # b:batch_size c:channel_size s:seq_len s:seq_len 54 | # d:d_model w:seg_len n:seg_num_x m:seg_num_y 55 | batch_size = x.size(0) 
56 | 57 | # normalization and permute b,s,c -> b,c,s 58 | seq_last = x[:, -1:, :].detach() 59 | x = (x - seq_last).permute(0, 2, 1) # b,c,s 60 | 61 | # segment and embedding b,c,s -> bc,n,w -> bc,n,d 62 | x = self.valueEmbedding(x.reshape(-1, self.seg_num_x, self.seg_len)) 63 | 64 | # encoding 65 | _, hn = self.rnn(x) # bc,n,d 1,bc,d 66 | 67 | # m,d//2 -> 1,m,d//2 -> c,m,d//2 68 | # c,d//2 -> c,1,d//2 -> c,m,d//2 69 | # c,m,d -> cm,1,d -> bcm, 1, d 70 | pos_emb = torch.cat([ 71 | self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1), 72 | self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1) 73 | ], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1) 74 | 75 | _, hy = self.rnn(pos_emb, hn.repeat(1, 1, self.seg_num_y).view(1, -1, self.d_model)) # bcm,1,d 1,bcm,d 76 | 77 | # 1,bcm,d -> 1,bcm,w -> b,c,s 78 | y = self.predict(hy).view(-1, self.enc_in, self.pred_len) 79 | 80 | # permute and denorm 81 | y = y.permute(0, 2, 1) + seq_last 82 | return y 83 | 84 | def forecast(self, x_enc): 85 | # Encoder 86 | return self.encoder(x_enc) 87 | 88 | def imputation(self, x_enc): 89 | # Encoder 90 | return self.encoder(x_enc) 91 | 92 | def anomaly_detection(self, x_enc): 93 | # Encoder 94 | return self.encoder(x_enc) 95 | 96 | def classification(self, x_enc): 97 | # Encoder 98 | enc_out = self.encoder(x_enc) 99 | # Output 100 | # (batch_size, seq_length * d_model) 101 | output = enc_out.reshape(enc_out.shape[0], -1) 102 | # (batch_size, num_classes) 103 | output = self.projection(output) 104 | return output 105 | 106 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 107 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 108 | dec_out = self.forecast(x_enc) 109 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 110 | if self.task_name == 'imputation': 111 | dec_out = self.imputation(x_enc) 112 | return dec_out # [B, L, D] 113 | if self.task_name == 'anomaly_detection': 114 | dec_out = self.anomaly_detection(x_enc) 
115 | return dec_out # [B, L, D] 116 | if self.task_name == 'classification': 117 | dec_out = self.classification(x_enc) 118 | return dec_out # [B, N] 119 | return None 120 | -------------------------------------------------------------------------------- /models/TSMixer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ResBlock(nn.Module): 5 | def __init__(self, configs): 6 | super(ResBlock, self).__init__() 7 | 8 | self.temporal = nn.Sequential( 9 | nn.Linear(configs.seq_len, configs.d_model), 10 | nn.ReLU(), 11 | nn.Linear(configs.d_model, configs.seq_len), 12 | nn.Dropout(configs.dropout) 13 | ) 14 | 15 | self.channel = nn.Sequential( 16 | nn.Linear(configs.enc_in, configs.d_model), 17 | nn.ReLU(), 18 | nn.Linear(configs.d_model, configs.enc_in), 19 | nn.Dropout(configs.dropout) 20 | ) 21 | 22 | def forward(self, x): 23 | # x: [B, L, D] 24 | x = x + self.temporal(x.transpose(1, 2)).transpose(1, 2) 25 | x = x + self.channel(x) 26 | 27 | return x 28 | 29 | 30 | class Model(nn.Module): 31 | def __init__(self, configs): 32 | super(Model, self).__init__() 33 | self.task_name = configs.task_name 34 | self.layer = configs.e_layers 35 | self.model = nn.ModuleList([ResBlock(configs) 36 | for _ in range(configs.e_layers)]) 37 | self.pred_len = configs.pred_len 38 | self.projection = nn.Linear(configs.seq_len, configs.pred_len) 39 | 40 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 41 | 42 | # x: [B, L, D] 43 | for i in range(self.layer): 44 | x_enc = self.model[i](x_enc) 45 | enc_out = self.projection(x_enc.transpose(1, 2)).transpose(1, 2) 46 | 47 | return enc_out 48 | 49 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 50 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 51 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 52 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 53 | else: 54 | 
raise ValueError('Only forecast tasks implemented yet') 55 | -------------------------------------------------------------------------------- /models/TiDE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LayerNorm(nn.Module): 7 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 8 | 9 | def __init__(self, ndim, bias): 10 | super().__init__() 11 | self.weight = nn.Parameter(torch.ones(ndim)) 12 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 13 | 14 | def forward(self, input): 15 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 16 | 17 | 18 | 19 | class ResBlock(nn.Module): 20 | def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True): 21 | super().__init__() 22 | 23 | self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) 24 | self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias) 25 | self.fc3 = nn.Linear(input_dim, output_dim, bias=bias) 26 | self.dropout = nn.Dropout(dropout) 27 | self.relu = nn.ReLU() 28 | self.ln = LayerNorm(output_dim, bias=bias) 29 | 30 | def forward(self, x): 31 | 32 | out = self.fc1(x) 33 | out = self.relu(out) 34 | out = self.fc2(out) 35 | out = self.dropout(out) 36 | out = out + self.fc3(x) 37 | out = self.ln(out) 38 | return out 39 | 40 | 41 | #TiDE 42 | class Model(nn.Module): 43 | """ 44 | paper: https://arxiv.org/pdf/2304.08424.pdf 45 | """ 46 | def __init__(self, configs, bias=True, feature_encode_dim=2): 47 | super(Model, self).__init__() 48 | self.configs = configs 49 | self.task_name = configs.task_name 50 | self.seq_len = configs.seq_len #L 51 | self.label_len = configs.label_len 52 | self.pred_len = configs.pred_len #H 53 | self.hidden_dim=configs.d_model 54 | self.res_hidden=configs.d_model 55 | self.encoder_num=configs.e_layers 56 | self.decoder_num=configs.d_layers 57 | self.freq=configs.freq 
58 | self.feature_encode_dim=feature_encode_dim 59 | self.decode_dim = configs.c_out 60 | self.temporalDecoderHidden=configs.d_ff 61 | dropout=configs.dropout 62 | 63 | 64 | freq_map = {'h': 4, 't': 5, 's': 6, 65 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 66 | 67 | self.feature_dim=freq_map[self.freq] 68 | 69 | 70 | flatten_dim = self.seq_len + (self.seq_len + self.pred_len) * self.feature_encode_dim 71 | 72 | self.feature_encoder = ResBlock(self.feature_dim, self.res_hidden, self.feature_encode_dim, dropout, bias) 73 | self.encoders = nn.Sequential(ResBlock(flatten_dim, self.res_hidden, self.hidden_dim, dropout, bias),*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.encoder_num-1))) 74 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 75 | self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.pred_len, dropout, bias)) 76 | self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias) 77 | self.residual_proj = nn.Linear(self.seq_len, self.pred_len, bias=bias) 78 | if self.task_name == 'imputation': 79 | self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.seq_len, dropout, bias)) 80 | self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias) 81 | self.residual_proj = nn.Linear(self.seq_len, self.seq_len, bias=bias) 82 | if self.task_name == 'anomaly_detection': 83 | self.decoders = nn.Sequential(*([ ResBlock(self.hidden_dim, self.res_hidden, self.hidden_dim, dropout, bias)]*(self.decoder_num-1)),ResBlock(self.hidden_dim, self.res_hidden, self.decode_dim * self.seq_len, dropout, bias)) 84 | 
self.temporalDecoder = ResBlock(self.decode_dim + self.feature_encode_dim, self.temporalDecoderHidden, 1, dropout, bias) 85 | self.residual_proj = nn.Linear(self.seq_len, self.seq_len, bias=bias) 86 | 87 | 88 | def forecast(self, x_enc, x_mark_enc, x_dec, batch_y_mark): 89 | # Normalization 90 | means = x_enc.mean(1, keepdim=True).detach() 91 | x_enc = x_enc - means 92 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 93 | x_enc /= stdev 94 | 95 | feature = self.feature_encoder(batch_y_mark) 96 | hidden = self.encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1)) 97 | decoded = self.decoders(hidden).reshape(hidden.shape[0], self.pred_len, self.decode_dim) 98 | dec_out = self.temporalDecoder(torch.cat([feature[:,self.seq_len:], decoded], dim=-1)).squeeze(-1) + self.residual_proj(x_enc) 99 | 100 | 101 | # De-Normalization 102 | dec_out = dec_out * (stdev[:, 0].unsqueeze(1).repeat(1, self.pred_len)) 103 | dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.pred_len)) 104 | return dec_out 105 | 106 | def imputation(self, x_enc, x_mark_enc, x_dec, batch_y_mark, mask): 107 | # Normalization 108 | means = x_enc.mean(1, keepdim=True).detach() 109 | x_enc = x_enc - means 110 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 111 | x_enc /= stdev 112 | 113 | feature = self.feature_encoder(x_mark_enc) 114 | hidden = self.encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1)) 115 | decoded = self.decoders(hidden).reshape(hidden.shape[0], self.seq_len, self.decode_dim) 116 | dec_out = self.temporalDecoder(torch.cat([feature[:,:self.seq_len], decoded], dim=-1)).squeeze(-1) + self.residual_proj(x_enc) 117 | 118 | # De-Normalization 119 | dec_out = dec_out * (stdev[:, 0].unsqueeze(1).repeat(1, self.seq_len)) 120 | dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.seq_len)) 121 | return dec_out 122 | 123 | 124 | def forward(self, x_enc, x_mark_enc, x_dec, 
batch_y_mark, mask=None): 125 | '''x_mark_enc is the exogenous dynamic feature described in the original paper''' 126 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 127 | if batch_y_mark is None: 128 | batch_y_mark = torch.zeros((x_enc.shape[0], self.seq_len+self.pred_len, self.feature_dim)).to(x_enc.device).detach() 129 | else: 130 | batch_y_mark = torch.concat([x_mark_enc, batch_y_mark[:, -self.pred_len:, :]],dim=1) 131 | dec_out = torch.stack([self.forecast(x_enc[:, :, feature], x_mark_enc, x_dec, batch_y_mark) for feature in range(x_enc.shape[-1])],dim=-1) 132 | return dec_out # [B, L, D] 133 | if self.task_name == 'imputation': 134 | dec_out = torch.stack([self.imputation(x_enc[:, :, feature], x_mark_enc, x_dec, batch_y_mark, mask) for feature in range(x_enc.shape[-1])],dim=-1) 135 | return dec_out # [B, L, D] 136 | if self.task_name == 'anomaly_detection': 137 | raise NotImplementedError("Task anomaly_detection for Tide is temporarily not supported") 138 | if self.task_name == 'classification': 139 | raise NotImplementedError("Task classification for Tide is temporarily not supported") 140 | return None 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /models/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Vanilla Transformer 13 | with O(L^2) complexity 14 | Paper link: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf 15 | """ 16 | 17 | def __init__(self, configs): 18 | super(Model, self).__init__() 19 
| self.task_name = configs.task_name 20 | self.pred_len = configs.pred_len 21 | # Embedding 22 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 23 | configs.dropout) 24 | # Encoder 25 | self.encoder = Encoder( 26 | [ 27 | EncoderLayer( 28 | AttentionLayer( 29 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 30 | output_attention=False), configs.d_model, configs.n_heads), 31 | configs.d_model, 32 | configs.d_ff, 33 | dropout=configs.dropout, 34 | activation=configs.activation 35 | ) for l in range(configs.e_layers) 36 | ], 37 | norm_layer=torch.nn.LayerNorm(configs.d_model) 38 | ) 39 | # Decoder 40 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 41 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 42 | configs.dropout) 43 | self.decoder = Decoder( 44 | [ 45 | DecoderLayer( 46 | AttentionLayer( 47 | FullAttention(True, configs.factor, attention_dropout=configs.dropout, 48 | output_attention=False), 49 | configs.d_model, configs.n_heads), 50 | AttentionLayer( 51 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 52 | output_attention=False), 53 | configs.d_model, configs.n_heads), 54 | configs.d_model, 55 | configs.d_ff, 56 | dropout=configs.dropout, 57 | activation=configs.activation, 58 | ) 59 | for l in range(configs.d_layers) 60 | ], 61 | norm_layer=torch.nn.LayerNorm(configs.d_model), 62 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 63 | ) 64 | if self.task_name == 'imputation': 65 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 66 | if self.task_name == 'anomaly_detection': 67 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 68 | if self.task_name == 'classification': 69 | self.act = F.gelu 70 | self.dropout = nn.Dropout(configs.dropout) 71 | self.projection = nn.Linear(configs.d_model * configs.seq_len, 
configs.num_class) 72 | 73 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 74 | # Embedding 75 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 76 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 77 | 78 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 79 | dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None) 80 | return dec_out 81 | 82 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 83 | # Embedding 84 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 85 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 86 | 87 | dec_out = self.projection(enc_out) 88 | return dec_out 89 | 90 | def anomaly_detection(self, x_enc): 91 | # Embedding 92 | enc_out = self.enc_embedding(x_enc, None) 93 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 94 | 95 | dec_out = self.projection(enc_out) 96 | return dec_out 97 | 98 | def classification(self, x_enc, x_mark_enc): 99 | # Embedding 100 | enc_out = self.enc_embedding(x_enc, None) 101 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 102 | 103 | # Output 104 | output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity 105 | output = self.dropout(output) 106 | output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings 107 | output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model) 108 | output = self.projection(output) # (batch_size, num_classes) 109 | return output 110 | 111 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 112 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 113 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 114 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 115 | if self.task_name == 'imputation': 116 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 117 | return dec_out # [B, L, D] 118 | if self.task_name == 'anomaly_detection': 119 | dec_out = 
self.anomaly_detection(x_enc) 120 | return dec_out # [B, L, D] 121 | if self.task_name == 'classification': 122 | dec_out = self.classification(x_enc, x_mark_enc) 123 | return dec_out # [B, N] 124 | return None 125 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/models/__init__.py -------------------------------------------------------------------------------- /models/iTransformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Transformer_EncDec import Encoder, EncoderLayer 5 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 6 | from layers.Embed import DataEmbedding_inverted 7 | import numpy as np 8 | 9 | 10 | class Model(nn.Module): 11 | """ 12 | Paper link: https://arxiv.org/abs/2310.06625 13 | """ 14 | 15 | def __init__(self, configs): 16 | super(Model, self).__init__() 17 | self.task_name = configs.task_name 18 | self.seq_len = configs.seq_len 19 | self.pred_len = configs.pred_len 20 | # Embedding 21 | self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq, 22 | configs.dropout) 23 | # Encoder 24 | self.encoder = Encoder( 25 | [ 26 | EncoderLayer( 27 | AttentionLayer( 28 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 29 | output_attention=False), configs.d_model, configs.n_heads), 30 | configs.d_model, 31 | configs.d_ff, 32 | dropout=configs.dropout, 33 | activation=configs.activation 34 | ) for l in range(configs.e_layers) 35 | ], 36 | norm_layer=torch.nn.LayerNorm(configs.d_model) 37 | ) 38 | # Decoder 39 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 40 | 
self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True) 41 | if self.task_name == 'imputation': 42 | self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True) 43 | if self.task_name == 'anomaly_detection': 44 | self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True) 45 | if self.task_name == 'classification': 46 | self.act = F.gelu 47 | self.dropout = nn.Dropout(configs.dropout) 48 | self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class) 49 | 50 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 51 | # Normalization from Non-stationary Transformer 52 | means = x_enc.mean(1, keepdim=True).detach() 53 | x_enc = x_enc - means 54 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 55 | x_enc /= stdev 56 | 57 | _, _, N = x_enc.shape 58 | 59 | # Embedding 60 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 61 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 62 | 63 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 64 | # De-Normalization from Non-stationary Transformer 65 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 66 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 67 | return dec_out 68 | 69 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 70 | # Normalization from Non-stationary Transformer 71 | means = x_enc.mean(1, keepdim=True).detach() 72 | x_enc = x_enc - means 73 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 74 | x_enc /= stdev 75 | 76 | _, L, N = x_enc.shape 77 | 78 | # Embedding 79 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 80 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 81 | 82 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 83 | # De-Normalization from Non-stationary Transformer 84 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 85 | dec_out = 
dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 86 | return dec_out 87 | 88 | def anomaly_detection(self, x_enc): 89 | # Normalization from Non-stationary Transformer 90 | means = x_enc.mean(1, keepdim=True).detach() 91 | x_enc = x_enc - means 92 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 93 | x_enc /= stdev 94 | 95 | _, L, N = x_enc.shape 96 | 97 | # Embedding 98 | enc_out = self.enc_embedding(x_enc, None) 99 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 100 | 101 | dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N] 102 | # De-Normalization from Non-stationary Transformer 103 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 104 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1)) 105 | return dec_out 106 | 107 | def classification(self, x_enc, x_mark_enc): 108 | # Embedding 109 | enc_out = self.enc_embedding(x_enc, None) 110 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 111 | 112 | # Output 113 | output = self.act(enc_out) # the output transformer encoder/decoder embeddings don't include non-linearity 114 | output = self.dropout(output) 115 | output = output.reshape(output.shape[0], -1) # (batch_size, c_in * d_model) 116 | output = self.projection(output) # (batch_size, num_classes) 117 | return output 118 | 119 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 120 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 121 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 122 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 123 | if self.task_name == 'imputation': 124 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 125 | return dec_out # [B, L, D] 126 | if self.task_name == 'anomaly_detection': 127 | dec_out = self.anomaly_detection(x_enc) 128 | return dec_out # [B, L, D] 129 | if self.task_name == 'classification': 130 | dec_out = self.classification(x_enc, 
x_mark_enc) 131 | return dec_out # [B, N] 132 | return None 133 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.4.0 2 | matplotlib==3.7.0 3 | numpy==1.23.5 4 | pandas==1.5.3 5 | patool==1.12 6 | reformer-pytorch==1.4.4 7 | scikit-learn==1.2.2 8 | scipy==1.10.1 9 | sktime==0.16.1 10 | sympy==1.11.1 11 | torch==2.0.0 12 | tqdm==4.64.1 13 | -------------------------------------------------------------------------------- /scripts/forecast_exogenous/ECL/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | 6 | 7 | python3 -u run.py \ 8 | --task_name long_term_forecast \ 9 | --is_training 1 \ 10 | --root_path ./dataset/electricity/ \ 11 | --data_path electricity.csv \ 12 | --model_id ECL_96_96 \ 13 | --model $model_name \ 14 | --data custom \ 15 | --features MS \ 16 | --seq_len 96 \ 17 | --label_len 48 \ 18 | --pred_len 96 \ 19 | --e_layers 1 \ 20 | --factor 3 \ 21 | --enc_in 321 \ 22 | --dec_in 321 \ 23 | --c_out 321 \ 24 | --des $des \ 25 | --batch_size 4 \ 26 | --itr 1 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/electricity/ \ 32 | --data_path electricity.csv \ 33 | --model_id ECL_96_192 \ 34 | --model $model_name \ 35 | --data custom \ 36 | --features MS \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 1 \ 41 | --factor 3 \ 42 | --enc_in 321 \ 43 | --dec_in 321 \ 44 | --c_out 321 \ 45 | --des $des \ 46 | --batch_size 32 \ 47 | --itr 1 48 | 49 | python3 -u run.py \ 50 | --task_name long_term_forecast \ 51 | --is_training 1 \ 52 | --root_path ./dataset/electricity/ \ 53 | --data_path electricity.csv \ 54 | --model_id ECL_96_336 \ 55 | --model $model_name \ 56 | --data custom \ 57 | --features MS 
\ 58 | --seq_len 96 \ 59 | --label_len 48 \ 60 | --pred_len 336 \ 61 | --e_layers 1 \ 62 | --factor 3 \ 63 | --enc_in 321 \ 64 | --dec_in 321 \ 65 | --c_out 321 \ 66 | --des $des \ 67 | --batch_size 32 \ 68 | --itr 1 69 | 70 | python3 -u run.py \ 71 | --task_name long_term_forecast \ 72 | --is_training 1 \ 73 | --root_path ./dataset/electricity/ \ 74 | --data_path electricity.csv \ 75 | --model_id ECL_96_720 \ 76 | --model $model_name \ 77 | --data custom \ 78 | --features MS \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 3 \ 83 | --factor 3 \ 84 | --enc_in 321 \ 85 | --dec_in 321 \ 86 | --c_out 321 \ 87 | --des $des \ 88 | --d_model 512 \ 89 | --itr 1 90 | -------------------------------------------------------------------------------- /scripts/forecast_exogenous/EPF/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | patch_len=24 6 | 7 | 8 | python3 -u run.py \ 9 | --is_training 1 \ 10 | --task_name long_term_forecast \ 11 | --root_path ./dataset/EPF/ \ 12 | --data_path NP.csv \ 13 | --model_id NP_168_24 \ 14 | --model $model_name \ 15 | --data custom \ 16 | --features MS \ 17 | --seq_len 168 \ 18 | --pred_len 24 \ 19 | --e_layers 3 \ 20 | --enc_in 3 \ 21 | --dec_in 3 \ 22 | --c_out 1 \ 23 | --des $des \ 24 | --patch_len $patch_len \ 25 | --d_model 512 \ 26 | --d_ff 512 \ 27 | --batch_size 4 \ 28 | --itr 1 29 | 30 | python3 -u run.py \ 31 | --is_training 1 \ 32 | --task_name long_term_forecast \ 33 | --root_path ./dataset/EPF/ \ 34 | --data_path PJM.csv \ 35 | --model_id PJM_168_24 \ 36 | --model $model_name \ 37 | --data custom \ 38 | --features MS \ 39 | --seq_len 168 \ 40 | --pred_len 24 \ 41 | --e_layers 3 \ 42 | --enc_in 3 \ 43 | --dec_in 3 \ 44 | --c_out 1 \ 45 | --des $des \ 46 | --patch_len $patch_len \ 47 | --d_model 512 \ 48 | --batch_size 16 \ 49 | --itr 1 50 | 51 | python3 -u run.py \ 52 | 
--is_training 1 \ 53 | --task_name long_term_forecast \ 54 | --root_path ./dataset/EPF/ \ 55 | --data_path BE.csv \ 56 | --model_id BE_168_24 \ 57 | --model $model_name \ 58 | --data custom \ 59 | --features MS \ 60 | --seq_len 168 \ 61 | --pred_len 24 \ 62 | --e_layers 2 \ 63 | --enc_in 3 \ 64 | --dec_in 3 \ 65 | --c_out 1 \ 66 | --des $des \ 67 | --patch_len $patch_len \ 68 | --d_model 512 \ 69 | --d_ff 512 \ 70 | --batch_size 16 \ 71 | --itr 1 72 | 73 | 74 | python3 -u run.py \ 75 | --is_training 1 \ 76 | --task_name long_term_forecast \ 77 | --root_path ./dataset/EPF/ \ 78 | --data_path FR.csv \ 79 | --model_id FR_168_24 \ 80 | --model $model_name \ 81 | --data custom \ 82 | --features MS \ 83 | --seq_len 168 \ 84 | --pred_len 24 \ 85 | --e_layers 2 \ 86 | --enc_in 3 \ 87 | --dec_in 3 \ 88 | --c_out 1 \ 89 | --des $des \ 90 | --patch_len $patch_len \ 91 | --batch_size 16 \ 92 | --d_model 512 \ 93 | --itr 1 94 | 95 | python3 -u run.py \ 96 | --is_training 1 \ 97 | --task_name long_term_forecast \ 98 | --root_path ./dataset/EPF/ \ 99 | --data_path DE.csv \ 100 | --model_id DE_168_24 \ 101 | --model $model_name \ 102 | --data custom \ 103 | --features MS \ 104 | --seq_len 168 \ 105 | --pred_len 24 \ 106 | --e_layers 1 \ 107 | --enc_in 3 \ 108 | --dec_in 3 \ 109 | --c_out 1 \ 110 | --des $des \ 111 | --patch_len $patch_len \ 112 | --batch_size 4 \ 113 | --d_model 512 \ 114 | --itr 1 115 | -------------------------------------------------------------------------------- /scripts/forecast_exogenous/ETTh1/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | 6 | python3 -u run.py \ 7 | --task_name long_term_forecast \ 8 | --is_training 1 \ 9 | --root_path ./dataset/ETT-small/ \ 10 | --data_path ETTh1.csv \ 11 | --model_id ETTh1_96_96 \ 12 | --model $model_name \ 13 | --data ETTh1 \ 14 | --features MS \ 15 | --seq_len 96 \ 16 | --label_len 48 \ 17 
| --pred_len 96 \ 18 | --e_layers 2 \ 19 | --factor 3 \ 20 | --enc_in 7 \ 21 | --dec_in 7 \ 22 | --c_out 7 \ 23 | --d_model 512 \ 24 | --d_ff 512 \ 25 | --des $des \ 26 | --itr 1 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/ETT-small/ \ 32 | --data_path ETTh1.csv \ 33 | --model_id ETTh1_96_192 \ 34 | --model $model_name \ 35 | --data ETTh1 \ 36 | --features MS \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 2 \ 41 | --factor 3 \ 42 | --enc_in 7 \ 43 | --dec_in 7 \ 44 | --c_out 7 \ 45 | --d_model 128 \ 46 | --d_ff 128 \ 47 | --batch_size 4 \ 48 | --des $des \ 49 | --itr 1 50 | 51 | python3 -u run.py \ 52 | --task_name long_term_forecast \ 53 | --is_training 1 \ 54 | --root_path ./dataset/ETT-small/ \ 55 | --data_path ETTh1.csv \ 56 | --model_id ETTh1_96_336 \ 57 | --model $model_name \ 58 | --data ETTh1 \ 59 | --features MS \ 60 | --seq_len 96 \ 61 | --label_len 48 \ 62 | --pred_len 336 \ 63 | --e_layers 2 \ 64 | --factor 3 \ 65 | --enc_in 7 \ 66 | --dec_in 7 \ 67 | --c_out 7 \ 68 | --d_model 512 \ 69 | --d_ff 512 \ 70 | --batch_size 32 \ 71 | --des $des \ 72 | --itr 1 73 | 74 | python3 -u run.py \ 75 | --task_name long_term_forecast \ 76 | --is_training 1 \ 77 | --root_path ./dataset/ETT-small/ \ 78 | --data_path ETTh1.csv \ 79 | --model_id ETTh1_96_720 \ 80 | --model $model_name \ 81 | --data ETTh1 \ 82 | --features MS \ 83 | --seq_len 96 \ 84 | --label_len 48 \ 85 | --pred_len 720 \ 86 | --e_layers 2 \ 87 | --factor 3 \ 88 | --enc_in 7 \ 89 | --dec_in 7 \ 90 | --c_out 7 \ 91 | --d_model 512 \ 92 | --batch_size 128 \ 93 | --des $des \ 94 | --itr 1 95 | -------------------------------------------------------------------------------- /scripts/forecast_exogenous/ETTh2/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | 6 | python3 -u 
run.py \ 7 | --task_name long_term_forecast \ 8 | --is_training 1 \ 9 | --root_path ./dataset/ETT-small/ \ 10 | --data_path ETTh2.csv \ 11 | --model_id ETTh2_96_96 \ 12 | --model $model_name \ 13 | --data ETTh2 \ 14 | --features MS \ 15 | --seq_len 96 \ 16 | --label_len 48 \ 17 | --pred_len 96 \ 18 | --e_layers 1 \ 19 | --factor 3 \ 20 | --enc_in 7 \ 21 | --dec_in 7 \ 22 | --c_out 7 \ 23 | --d_model 128 \ 24 | --d_ff 128 \ 25 | --batch_size 128 \ 26 | --des $des \ 27 | --itr 1 28 | 29 | python3 -u run.py \ 30 | --task_name long_term_forecast \ 31 | --is_training 1 \ 32 | --root_path ./dataset/ETT-small/ \ 33 | --data_path ETTh2.csv \ 34 | --model_id ETTh2_96_192 \ 35 | --model $model_name \ 36 | --data ETTh2 \ 37 | --features MS \ 38 | --seq_len 96 \ 39 | --label_len 48 \ 40 | --pred_len 192 \ 41 | --e_layers 1 \ 42 | --factor 3 \ 43 | --enc_in 7 \ 44 | --dec_in 7 \ 45 | --c_out 7 \ 46 | --d_model 128 \ 47 | --d_ff 512 \ 48 | --batch_size 128 \ 49 | --des $des \ 50 | --itr 1 51 | 52 | python3 -u run.py \ 53 | --task_name long_term_forecast \ 54 | --is_training 1 \ 55 | --root_path ./dataset/ETT-small/ \ 56 | --data_path ETTh2.csv \ 57 | --model_id ETTh2_96_336 \ 58 | --model $model_name \ 59 | --data ETTh2 \ 60 | --features MS \ 61 | --seq_len 96 \ 62 | --label_len 48 \ 63 | --pred_len 336 \ 64 | --e_layers 2 \ 65 | --factor 3 \ 66 | --enc_in 7 \ 67 | --dec_in 7 \ 68 | --c_out 7 \ 69 | --d_model 128 \ 70 | --d_ff 256 \ 71 | --batch_size 16 \ 72 | --des $des \ 73 | --itr 1 74 | 75 | python3 -u run.py \ 76 | --task_name long_term_forecast \ 77 | --is_training 1 \ 78 | --root_path ./dataset/ETT-small/ \ 79 | --data_path ETTh2.csv \ 80 | --model_id ETTh2_96_720 \ 81 | --model $model_name \ 82 | --data ETTh2 \ 83 | --features MS \ 84 | --seq_len 96 \ 85 | --label_len 48 \ 86 | --pred_len 720 \ 87 | --e_layers 1 \ 88 | --factor 3 \ 89 | --enc_in 7 \ 90 | --dec_in 7 \ 91 | --c_out 7 \ 92 | --d_model 256 \ 93 | --d_ff 512 \ 94 | --des $des \ 95 | --itr 1 
-------------------------------------------------------------------------------- /scripts/forecast_exogenous/ETTm1/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | 6 | python3 -u run.py \ 7 | --task_name long_term_forecast \ 8 | --is_training 1 \ 9 | --root_path ./dataset/ETT-small/ \ 10 | --data_path ETTm1.csv \ 11 | --model_id ETTm1_96_96 \ 12 | --model $model_name \ 13 | --data ETTm1 \ 14 | --features MS \ 15 | --seq_len 96 \ 16 | --label_len 48 \ 17 | --pred_len 96 \ 18 | --e_layers 1 \ 19 | --factor 3 \ 20 | --enc_in 7 \ 21 | --dec_in 7 \ 22 | --c_out 7 \ 23 | --d_model 256 \ 24 | --batch_size 128 \ 25 | --des $des \ 26 | --itr 1 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/ETT-small/ \ 32 | --data_path ETTm1.csv \ 33 | --model_id ETTm1_96_192 \ 34 | --model $model_name \ 35 | --data ETTm1 \ 36 | --features MS \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 1 \ 41 | --factor 3 \ 42 | --enc_in 7 \ 43 | --dec_in 7 \ 44 | --c_out 7 \ 45 | --d_model 128 \ 46 | --batch_size 128 \ 47 | --des $des \ 48 | --itr 1 49 | 50 | python3 -u run.py \ 51 | --task_name long_term_forecast \ 52 | --is_training 1 \ 53 | --root_path ./dataset/ETT-small/ \ 54 | --data_path ETTm1.csv \ 55 | --model_id ETTm1_96_336 \ 56 | --model $model_name \ 57 | --data ETTm1 \ 58 | --features MS \ 59 | --seq_len 96 \ 60 | --label_len 48 \ 61 | --pred_len 336 \ 62 | --e_layers 1 \ 63 | --factor 3 \ 64 | --enc_in 7 \ 65 | --dec_in 7 \ 66 | --c_out 7 \ 67 | --d_model 128 \ 68 | --batch_size 128 \ 69 | --des $des \ 70 | --itr 1 71 | 72 | python3 -u run.py \ 73 | --task_name long_term_forecast \ 74 | --is_training 1 \ 75 | --root_path ./dataset/ETT-small/ \ 76 | --data_path ETTm1.csv \ 77 | --model_id ETTm1_96_720 \ 78 | --model $model_name \ 79 | --data ETTm1 \ 80 | 
--features MS \ 81 | --seq_len 96 \ 82 | --label_len 48 \ 83 | --pred_len 720 \ 84 | --e_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 7 \ 87 | --dec_in 7 \ 88 | --c_out 7 \ 89 | --d_model 128 \ 90 | --batch_size 128 \ 91 | --des $des \ 92 | --itr 1 -------------------------------------------------------------------------------- /scripts/forecast_exogenous/ETTm2/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | 6 | python3 -u run.py \ 7 | --task_name long_term_forecast \ 8 | --is_training 1 \ 9 | --root_path ./dataset/ETT-small/ \ 10 | --data_path ETTm2.csv \ 11 | --model_id ETTm2_96_96 \ 12 | --model $model_name \ 13 | --data ETTm2 \ 14 | --features MS \ 15 | --seq_len 96 \ 16 | --label_len 48 \ 17 | --pred_len 96 \ 18 | --e_layers 1 \ 19 | --factor 3 \ 20 | --enc_in 7 \ 21 | --dec_in 7 \ 22 | --c_out 7 \ 23 | --d_model 512 \ 24 | --batch_size 16 \ 25 | --des $des \ 26 | --itr 1 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/ETT-small/ \ 32 | --data_path ETTm2.csv \ 33 | --model_id ETTm2_96_192 \ 34 | --model $model_name \ 35 | --data ETTm2 \ 36 | --features MS \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 1 \ 41 | --factor 3 \ 42 | --enc_in 7 \ 43 | --dec_in 7 \ 44 | --c_out 7 \ 45 | --d_model 256 \ 46 | --batch_size 4 \ 47 | --des $des \ 48 | --itr 1 49 | 50 | python3 -u run.py \ 51 | --task_name long_term_forecast \ 52 | --is_training 1 \ 53 | --root_path ./dataset/ETT-small/ \ 54 | --data_path ETTm2.csv \ 55 | --model_id ETTm2_96_336 \ 56 | --model $model_name \ 57 | --data ETTm2 \ 58 | --features MS \ 59 | --seq_len 96 \ 60 | --label_len 48 \ 61 | --pred_len 336 \ 62 | --e_layers 1 \ 63 | --factor 3 \ 64 | --enc_in 7 \ 65 | --dec_in 7 \ 66 | --c_out 7 \ 67 | --d_model 128 \ 68 | --batch_size 128 \ 69 | --des $des \ 70 | --itr 1 71 
| 72 | python3 -u run.py \ 73 | --task_name long_term_forecast \ 74 | --is_training 1 \ 75 | --root_path ./dataset/ETT-small/ \ 76 | --data_path ETTm2.csv \ 77 | --model_id ETTm2_96_720 \ 78 | --model $model_name \ 79 | --data ETTm2 \ 80 | --features MS \ 81 | --seq_len 96 \ 82 | --label_len 48 \ 83 | --pred_len 720 \ 84 | --e_layers 1 \ 85 | --factor 3 \ 86 | --enc_in 7 \ 87 | --dec_in 7 \ 88 | --c_out 7 \ 89 | --d_model 128 \ 90 | --batch_size 128 \ 91 | --des $des \ 92 | --itr 1 93 | -------------------------------------------------------------------------------- /scripts/forecast_exogenous/Traffic/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | 6 | python3 -u run.py \ 7 | --task_name long_term_forecast \ 8 | --is_training 1 \ 9 | --root_path ./dataset/traffic/ \ 10 | --data_path traffic.csv \ 11 | --model_id traffic_96_96 \ 12 | --model $model_name \ 13 | --data custom \ 14 | --features MS \ 15 | --seq_len 96 \ 16 | --label_len 48 \ 17 | --pred_len 96 \ 18 | --e_layers 1 \ 19 | --d_layers 1 \ 20 | --factor 3 \ 21 | --enc_in 862 \ 22 | --dec_in 862 \ 23 | --c_out 862 \ 24 | --d_model 512 \ 25 | --des $des \ 26 | --batch_size 4 \ 27 | --itr 1 28 | 29 | python3 -u run.py \ 30 | --task_name long_term_forecast \ 31 | --is_training 1 \ 32 | --root_path ./dataset/traffic/ \ 33 | --data_path traffic.csv \ 34 | --model_id traffic_96_192 \ 35 | --model $model_name \ 36 | --data custom \ 37 | --features MS \ 38 | --seq_len 96 \ 39 | --label_len 48 \ 40 | --pred_len 192 \ 41 | --e_layers 1 \ 42 | --d_layers 1 \ 43 | --factor 3 \ 44 | --enc_in 862 \ 45 | --dec_in 862 \ 46 | --c_out 862 \ 47 | --d_model 512 \ 48 | --des $des \ 49 | --batch_size 4 \ 50 | --itr 1 51 | 52 | python3 -u run.py \ 53 | --task_name long_term_forecast \ 54 | --is_training 1 \ 55 | --root_path ./dataset/traffic/ \ 56 | --data_path traffic.csv \ 57 | --model_id
traffic_96_336 \ 58 | --model $model_name \ 59 | --data custom \ 60 | --features MS \ 61 | --seq_len 96 \ 62 | --label_len 48 \ 63 | --pred_len 336 \ 64 | --e_layers 1 \ 65 | --d_layers 1 \ 66 | --factor 3 \ 67 | --enc_in 862 \ 68 | --dec_in 862 \ 69 | --c_out 862 \ 70 | --d_model 512 \ 71 | --des $des \ 72 | --batch_size 4 \ 73 | --itr 1 74 | 75 | python3 -u run.py \ 76 | --task_name long_term_forecast \ 77 | --is_training 1 \ 78 | --root_path ./dataset/traffic/ \ 79 | --data_path traffic.csv \ 80 | --model_id traffic_96_720 \ 81 | --model $model_name \ 82 | --data custom \ 83 | --features MS \ 84 | --seq_len 96 \ 85 | --label_len 48 \ 86 | --pred_len 720 \ 87 | --e_layers 1 \ 88 | --d_layers 1 \ 89 | --factor 3 \ 90 | --enc_in 862 \ 91 | --dec_in 862 \ 92 | --c_out 862 \ 93 | --d_model 512 \ 94 | --des $des \ 95 | --batch_size 4 \ 96 | --itr 1 97 | -------------------------------------------------------------------------------- /scripts/forecast_exogenous/Weather/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | model_name=TimeXer 4 | des='Timexer-MS' 5 | 6 | 7 | python3 -u run.py \ 8 | --task_name long_term_forecast \ 9 | --is_training 1 \ 10 | --root_path ./dataset/weather/ \ 11 | --data_path weather.csv \ 12 | --model_id weather_96_96 \ 13 | --model $model_name \ 14 | --data custom \ 15 | --features MS \ 16 | --seq_len 96 \ 17 | --label_len 48 \ 18 | --pred_len 96 \ 19 | --e_layers 1 \ 20 | --factor 3 \ 21 | --enc_in 21 \ 22 | --dec_in 21 \ 23 | --c_out 21 \ 24 | --des $des \ 25 | --d_model 128 \ 26 | --itr 1 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/weather/ \ 32 | --data_path weather.csv \ 33 | --model_id weather_96_192 \ 34 | --model $model_name \ 35 | --data custom \ 36 | --features MS \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 1 \ 41 | --factor 3 \ 42 | 
--enc_in 21 \ 43 | --dec_in 21 \ 44 | --c_out 21 \ 45 | --des $des \ 46 | --d_model 128 \ 47 | --itr 1 48 | 49 | python3 -u run.py \ 50 | --task_name long_term_forecast \ 51 | --is_training 1 \ 52 | --root_path ./dataset/weather/ \ 53 | --data_path weather.csv \ 54 | --model_id weather_96_336 \ 55 | --model $model_name \ 56 | --data custom \ 57 | --features MS \ 58 | --seq_len 96 \ 59 | --label_len 48 \ 60 | --pred_len 336 \ 61 | --e_layers 1 \ 62 | --factor 3 \ 63 | --enc_in 21 \ 64 | --dec_in 21 \ 65 | --c_out 21 \ 66 | --des $des \ 67 | --d_model 128 \ 68 | --itr 1 69 | 70 | python3 -u run.py \ 71 | --task_name long_term_forecast \ 72 | --is_training 1 \ 73 | --root_path ./dataset/weather/ \ 74 | --data_path weather.csv \ 75 | --model_id weather_96_720 \ 76 | --model $model_name \ 77 | --data custom \ 78 | --features MS \ 79 | --seq_len 96 \ 80 | --label_len 48 \ 81 | --pred_len 720 \ 82 | --e_layers 1 \ 83 | --factor 3 \ 84 | --enc_in 21 \ 85 | --dec_in 21 \ 86 | --c_out 21 \ 87 | --des $des \ 88 | --d_model 128 \ 89 | --itr 1 -------------------------------------------------------------------------------- /scripts/forecast_exogenous/meteorology/temp.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | model_name=TimeXer 3 | 4 | python3 -u run.py \ 5 | --task_name long_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/meteorology \ 8 | --data_path temp.npy \ 9 | --model_id temp \ 10 | --model $model_name \ 11 | --data Meteorology \ 12 | --features MS \ 13 | --seq_len 168 \ 14 | --label_len 1 \ 15 | --pred_len 72 \ 16 | --patch_len 8 \ 17 | --e_layers 2 \ 18 | --enc_in 37 \ 19 | --d_model 512 \ 20 | --d_ff 512 \ 21 | --des 'global_temp' \ 22 | --learning_rate 0.0001 \ 23 | --batch_size 4096 -------------------------------------------------------------------------------- /scripts/forecast_exogenous/meteorology/wind.sh: 
-------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | model_name=TimeXer 3 | 4 | python3 -u run.py \ 5 | --task_name long_term_forecast \ 6 | --is_training 1 \ 7 | --root_path ./dataset/meteorology \ 8 | --data_path wind.npy \ 9 | --model_id wind \ 10 | --model $model_name \ 11 | --data Meteorology \ 12 | --features MS \ 13 | --seq_len 168 \ 14 | --label_len 1 \ 15 | --pred_len 72 \ 16 | --patch_len 8 \ 17 | --e_layers 2 \ 18 | --enc_in 37 \ 19 | --d_model 512 \ 20 | --d_ff 512 \ 21 | --des 'global_wind' \ 22 | --learning_rate 0.0001 \ 23 | --batch_size 4096 -------------------------------------------------------------------------------- /scripts/multivariate/ECL/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=TimeXer 4 | 5 | python3 -u run.py \ 6 | --task_name long_term_forecast \ 7 | --is_training 1 \ 8 | --root_path ./dataset/electricity/ \ 9 | --data_path electricity.csv \ 10 | --model_id ECL_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --label_len 48 \ 16 | --pred_len 96 \ 17 | --e_layers 4 \ 18 | --factor 3 \ 19 | --enc_in 321 \ 20 | --dec_in 321 \ 21 | --c_out 321 \ 22 | --des 'Exp' \ 23 | --d_ff 512 \ 24 | --batch_size 4 \ 25 | --itr 1 26 | 27 | python3 -u run.py \ 28 | --task_name long_term_forecast \ 29 | --is_training 1 \ 30 | --root_path ./dataset/electricity/ \ 31 | --data_path electricity.csv \ 32 | --model_id ECL_96_192 \ 33 | --model $model_name \ 34 | --data custom \ 35 | --features M \ 36 | --seq_len 96 \ 37 | --label_len 48 \ 38 | --pred_len 192 \ 39 | --e_layers 3 \ 40 | --factor 3 \ 41 | --enc_in 321 \ 42 | --dec_in 321 \ 43 | --c_out 321 \ 44 | --des 'Exp' \ 45 | --batch_size 4 \ 46 | --itr 1 47 | 48 | python3 -u run.py \ 49 | --task_name long_term_forecast \ 50 | --is_training 1 \ 51 | --root_path ./dataset/electricity/ \ 52 | 
--data_path electricity.csv \ 53 | --model_id ECL_96_336 \ 54 | --model $model_name \ 55 | --data custom \ 56 | --features M \ 57 | --seq_len 96 \ 58 | --label_len 48 \ 59 | --pred_len 336 \ 60 | --e_layers 4 \ 61 | --factor 3 \ 62 | --enc_in 321 \ 63 | --dec_in 321 \ 64 | --c_out 321 \ 65 | --des 'Exp' \ 66 | --batch_size 4 \ 67 | --itr 1 68 | 69 | python3 -u run.py \ 70 | --task_name long_term_forecast \ 71 | --is_training 1 \ 72 | --root_path ./dataset/electricity/ \ 73 | --data_path electricity.csv \ 74 | --model_id ECL_96_720 \ 75 | --model $model_name \ 76 | --data custom \ 77 | --features M \ 78 | --seq_len 96 \ 79 | --label_len 48 \ 80 | --pred_len 720 \ 81 | --e_layers 3 \ 82 | --factor 3 \ 83 | --enc_in 321 \ 84 | --dec_in 321 \ 85 | --c_out 321 \ 86 | --des 'Exp' \ 87 | --batch_size 4 \ 88 | --itr 1 89 | -------------------------------------------------------------------------------- /scripts/multivariate/ETT/TimeXer_ETTh1.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | model_name=TimeXer 4 | 5 | python3 -u run.py \ 6 | --task_name long_term_forecast \ 7 | --is_training 1 \ 8 | --root_path ./dataset/ETT-small/ \ 9 | --data_path ETTh1.csv \ 10 | --model_id ETTh1_96_96 \ 11 | --model $model_name \ 12 | --data ETTh1 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --label_len 48 \ 16 | --pred_len 96 \ 17 | --e_layers 1 \ 18 | --factor 3 \ 19 | --enc_in 7 \ 20 | --dec_in 7 \ 21 | --c_out 7 \ 22 | --d_model 256 \ 23 | --batch_size 4 \ 24 | --des 'exp' \ 25 | --itr 1 26 | 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/ETT-small/ \ 32 | --data_path ETTh1.csv \ 33 | --model_id ETTh1_96_192 \ 34 | --model $model_name \ 35 | --data ETTh1 \ 36 | --features M \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 2 \ 41 | --factor 3 \ 42 | --enc_in 7 \ 43 | --dec_in 7 \ 44 | --c_out 7 \ 45 | 
--des 'Exp' \ 46 | --d_model 128 \ 47 | --batch_size 4 \ 48 | --itr 1 49 | 50 | python3 -u run.py \ 51 | --task_name long_term_forecast \ 52 | --is_training 1 \ 53 | --root_path ./dataset/ETT-small/ \ 54 | --data_path ETTh1.csv \ 55 | --model_id ETTh1_96_336 \ 56 | --model $model_name \ 57 | --data ETTh1 \ 58 | --features M \ 59 | --seq_len 96 \ 60 | --label_len 48 \ 61 | --pred_len 336 \ 62 | --e_layers 1 \ 63 | --factor 3 \ 64 | --enc_in 7 \ 65 | --dec_in 7 \ 66 | --c_out 7 \ 67 | --des 'Exp' \ 68 | --d_model 512 \ 69 | --d_ff 1024 \ 70 | --batch_size 16 \ 71 | --itr 1 72 | 73 | python3 -u run.py \ 74 | --task_name long_term_forecast \ 75 | --is_training 1 \ 76 | --root_path ./dataset/ETT-small/ \ 77 | --data_path ETTh1.csv \ 78 | --model_id ETTh1_96_720 \ 79 | --model $model_name \ 80 | --data ETTh1 \ 81 | --features M \ 82 | --seq_len 96 \ 83 | --label_len 48 \ 84 | --pred_len 720 \ 85 | --e_layers 1 \ 86 | --factor 3 \ 87 | --enc_in 7 \ 88 | --dec_in 7 \ 89 | --c_out 7 \ 90 | --des 'Exp' \ 91 | --d_model 256 \ 92 | --d_ff 1024 \ 93 | --batch_size 16 \ 94 | --itr 1 95 | -------------------------------------------------------------------------------- /scripts/multivariate/ETT/TimeXer_ETTh2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=TimeXer 4 | 5 | python3 -u run.py \ 6 | --task_name long_term_forecast \ 7 | --is_training 1 \ 8 | --root_path ./dataset/ETT-small/ \ 9 | --data_path ETTh2.csv \ 10 | --model_id ETTh2_96_96 \ 11 | --model $model_name \ 12 | --data ETTh2 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --label_len 48 \ 16 | --pred_len 96 \ 17 | --e_layers 1 \ 18 | --factor 3 \ 19 | --enc_in 7 \ 20 | --dec_in 7 \ 21 | --c_out 7 \ 22 | --des 'Exp' \ 23 | --d_model 256 \ 24 | --d_ff 1024 \ 25 | --batch_size 16 \ 26 | --itr 1 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/ETT-small/ \ 32 | 
--data_path ETTh2.csv \ 33 | --model_id ETTh2_96_192 \ 34 | --model $model_name \ 35 | --data ETTh2 \ 36 | --features M \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 1 \ 41 | --factor 3 \ 42 | --enc_in 7 \ 43 | --dec_in 7 \ 44 | --c_out 7 \ 45 | --des 'Exp' \ 46 | --d_model 256 \ 47 | --d_ff 1024 \ 48 | --itr 1 49 | 50 | python3 -u run.py \ 51 | --task_name long_term_forecast \ 52 | --is_training 1 \ 53 | --root_path ./dataset/ETT-small/ \ 54 | --data_path ETTh2.csv \ 55 | --model_id ETTh2_96_336 \ 56 | --model $model_name \ 57 | --data ETTh2 \ 58 | --features M \ 59 | --seq_len 96 \ 60 | --label_len 48 \ 61 | --pred_len 336 \ 62 | --e_layers 2 \ 63 | --factor 3 \ 64 | --enc_in 7 \ 65 | --dec_in 7 \ 66 | --c_out 7 \ 67 | --des 'Exp' \ 68 | --d_model 512 \ 69 | --d_ff 1024 \ 70 | --batch_size 4 \ 71 | --itr 1 72 | 73 | python3 -u run.py \ 74 | --task_name long_term_forecast \ 75 | --is_training 1 \ 76 | --root_path ./dataset/ETT-small/ \ 77 | --data_path ETTh2.csv \ 78 | --model_id ETTh2_96_720 \ 79 | --model $model_name \ 80 | --data ETTh2 \ 81 | --features M \ 82 | --seq_len 96 \ 83 | --label_len 48 \ 84 | --pred_len 720 \ 85 | --e_layers 2 \ 86 | --factor 3 \ 87 | --enc_in 7 \ 88 | --dec_in 7 \ 89 | --c_out 7 \ 90 | --des 'Exp' \ 91 | --d_model 256 \ 92 | --d_ff 1024 \ 93 | --batch_size 16 \ 94 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate/ETT/TimeXer_ETTm1.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=TimeXer 4 | 5 | python3 -u run.py \ 6 | --task_name long_term_forecast \ 7 | --is_training 1 \ 8 | --root_path ./dataset/ETT-small/ \ 9 | --data_path ETTm1.csv \ 10 | --model_id ETTm1_96_96 \ 11 | --model $model_name \ 12 | --data ETTm1 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --label_len 48 \ 16 | --pred_len 96 \ 17 | --e_layers 1 \ 18 | --factor 3 \ 19 | --enc_in 7 
\ 20 | --dec_in 7 \ 21 | --c_out 7 \ 22 | --d_model 256 \ 23 | --batch_size 4 \ 24 | --des 'Exp' \ 25 | --itr 1 26 | 27 | python3 -u run.py \ 28 | --task_name long_term_forecast \ 29 | --is_training 1 \ 30 | --root_path ./dataset/ETT-small/ \ 31 | --data_path ETTm1.csv \ 32 | --model_id ETTm1_96_192 \ 33 | --model $model_name \ 34 | --data ETTm1 \ 35 | --features M \ 36 | --seq_len 96 \ 37 | --label_len 48 \ 38 | --pred_len 192 \ 39 | --e_layers 1 \ 40 | --factor 3 \ 41 | --enc_in 7 \ 42 | --dec_in 7 \ 43 | --c_out 7 \ 44 | --d_model 256 \ 45 | --d_ff 256 \ 46 | --batch_size 4 \ 47 | --des 'Exp' \ 48 | --itr 1 49 | 50 | python3 -u run.py \ 51 | --task_name long_term_forecast \ 52 | --is_training 1 \ 53 | --root_path ./dataset/ETT-small/ \ 54 | --data_path ETTm1.csv \ 55 | --model_id ETTm1_96_336 \ 56 | --model $model_name \ 57 | --data ETTm1 \ 58 | --features M \ 59 | --seq_len 96 \ 60 | --label_len 48 \ 61 | --pred_len 336 \ 62 | --e_layers 1 \ 63 | --factor 3 \ 64 | --enc_in 7 \ 65 | --dec_in 7 \ 66 | --c_out 7 \ 67 | --d_model 256 \ 68 | --d_ff 1024 \ 69 | --batch_size 4 \ 70 | --des 'Exp' \ 71 | --itr 1 72 | 73 | python3 -u run.py \ 74 | --task_name long_term_forecast \ 75 | --is_training 1 \ 76 | --root_path ./dataset/ETT-small/ \ 77 | --data_path ETTm1.csv \ 78 | --model_id ETTm1_96_720 \ 79 | --model $model_name \ 80 | --data ETTm1 \ 81 | --features M \ 82 | --seq_len 96 \ 83 | --label_len 48 \ 84 | --pred_len 720 \ 85 | --e_layers 1 \ 86 | --factor 3 \ 87 | --enc_in 7 \ 88 | --dec_in 7 \ 89 | --c_out 7 \ 90 | --d_model 256 \ 91 | --d_ff 512 \ 92 | --batch_size 4 \ 93 | --des 'Exp' \ 94 | --itr 1 -------------------------------------------------------------------------------- /scripts/multivariate/ETT/TimeXer_ETTm2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=TimeXer 4 | 5 | python3 -u run.py \ 6 | --task_name long_term_forecast \ 7 | --is_training 1 \ 8 | 
--root_path ./dataset/ETT-small/ \ 9 | --data_path ETTm2.csv \ 10 | --model_id ETTm2_96_96 \ 11 | --model $model_name \ 12 | --data ETTm2 \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --label_len 48 \ 16 | --pred_len 96 \ 17 | --e_layers 1 \ 18 | --d_layers 1 \ 19 | --factor 3 \ 20 | --enc_in 7 \ 21 | --dec_in 7 \ 22 | --c_out 7 \ 23 | --d_model 256 \ 24 | --des 'Exp' \ 25 | --itr 1 26 | 27 | python3 -u run.py \ 28 | --task_name long_term_forecast \ 29 | --is_training 1 \ 30 | --root_path ./dataset/ETT-small/ \ 31 | --data_path ETTm2.csv \ 32 | --model_id ETTm2_96_192 \ 33 | --model $model_name \ 34 | --data ETTm2 \ 35 | --features M \ 36 | --seq_len 96 \ 37 | --label_len 48 \ 38 | --pred_len 192 \ 39 | --e_layers 1 \ 40 | --d_layers 1 \ 41 | --factor 3 \ 42 | --enc_in 7 \ 43 | --dec_in 7 \ 44 | --c_out 7 \ 45 | --d_model 256 \ 46 | --d_ff 1024 \ 47 | --batch_size 16 \ 48 | --des 'Exp' \ 49 | --itr 1 50 | 51 | python3 -u run.py \ 52 | --task_name long_term_forecast \ 53 | --is_training 1 \ 54 | --root_path ./dataset/ETT-small/ \ 55 | --data_path ETTm2.csv \ 56 | --model_id ETTm2_96_336 \ 57 | --model $model_name \ 58 | --data ETTm2 \ 59 | --features M \ 60 | --seq_len 96 \ 61 | --label_len 48 \ 62 | --pred_len 336 \ 63 | --e_layers 1 \ 64 | --d_layers 1 \ 65 | --factor 3 \ 66 | --enc_in 7 \ 67 | --dec_in 7 \ 68 | --c_out 7 \ 69 | --d_model 512 \ 70 | --d_ff 1024 \ 71 | --des 'Exp' \ 72 | --itr 1 73 | 74 | 75 | python3 -u run.py \ 76 | --task_name long_term_forecast \ 77 | --is_training 1 \ 78 | --root_path ./dataset/ETT-small/ \ 79 | --data_path ETTm2.csv \ 80 | --model_id ETTm2_96_720 \ 81 | --model $model_name \ 82 | --data ETTm2 \ 83 | --features M \ 84 | --seq_len 96 \ 85 | --label_len 48 \ 86 | --pred_len 720 \ 87 | --e_layers 1 \ 88 | --d_layers 1 \ 89 | --factor 3 \ 90 | --enc_in 7 \ 91 | --dec_in 7 \ 92 | --c_out 7 \ 93 | --d_model 512 \ 94 | --des 'Exp' \ 95 | --itr 1 -------------------------------------------------------------------------------- 
/scripts/multivariate/Traffic/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | model_name=TimeXer 4 | 5 | python3 -u run.py \ 6 | --task_name long_term_forecast \ 7 | --is_training 1 \ 8 | --root_path ./dataset/traffic/ \ 9 | --data_path traffic.csv \ 10 | --model_id traffic_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --label_len 48 \ 16 | --pred_len 96 \ 17 | --e_layers 3 \ 18 | --factor 3 \ 19 | --enc_in 862 \ 20 | --dec_in 862 \ 21 | --c_out 862 \ 22 | --d_model 512 \ 23 | --d_ff 512 \ 24 | --des 'Exp' \ 25 | --batch_size 16 \ 26 | --learning_rate 0.001 \ 27 | --itr 1 28 | 29 | python3 -u run.py \ 30 | --task_name long_term_forecast \ 31 | --is_training 1 \ 32 | --root_path ./dataset/traffic/ \ 33 | --data_path traffic.csv \ 34 | --model_id traffic_96_192 \ 35 | --model $model_name \ 36 | --data custom \ 37 | --features M \ 38 | --seq_len 96 \ 39 | --label_len 48 \ 40 | --pred_len 192 \ 41 | --e_layers 3 \ 42 | --factor 3 \ 43 | --enc_in 862 \ 44 | --dec_in 862 \ 45 | --c_out 862 \ 46 | --d_model 512 \ 47 | --d_ff 512 \ 48 | --des 'Exp' \ 49 | --batch_size 16 \ 50 | --learning_rate 0.001 \ 51 | --itr 1 52 | 53 | python3 -u run.py \ 54 | --task_name long_term_forecast \ 55 | --is_training 1 \ 56 | --root_path ./dataset/traffic/ \ 57 | --data_path traffic.csv \ 58 | --model_id traffic_96_336 \ 59 | --model $model_name \ 60 | --data custom \ 61 | --features M \ 62 | --seq_len 96 \ 63 | --label_len 48 \ 64 | --pred_len 336 \ 65 | --e_layers 2 \ 66 | --factor 3 \ 67 | --enc_in 862 \ 68 | --dec_in 862 \ 69 | --c_out 862 \ 70 | --d_model 512 \ 71 | --d_ff 512 \ 72 | --des 'Exp' \ 73 | --batch_size 16 \ 74 | --learning_rate 0.001 \ 75 | --itr 1 76 | 77 | python3 -u run.py \ 78 | --task_name long_term_forecast \ 79 | --is_training 1 \ 80 | --root_path ./dataset/traffic/ \ 81 | --data_path traffic.csv \ 82 | --model_id 
traffic_96_720 \ 83 | --model $model_name \ 84 | --data custom \ 85 | --features M \ 86 | --seq_len 96 \ 87 | --label_len 48 \ 88 | --pred_len 720 \ 89 | --e_layers 2 \ 90 | --factor 3 \ 91 | --enc_in 862 \ 92 | --dec_in 862 \ 93 | --c_out 862 \ 94 | --d_model 512 \ 95 | --d_ff 512 \ 96 | --des 'Exp' \ 97 | --batch_size 16 \ 98 | --learning_rate 0.001 \ 99 | --itr 1 100 | -------------------------------------------------------------------------------- /scripts/multivariate/Weather/TimeXer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=TimeXer 4 | 5 | python3 -u run.py \ 6 | --task_name long_term_forecast \ 7 | --is_training 1 \ 8 | --root_path ./dataset/weather/ \ 9 | --data_path weather.csv \ 10 | --model_id weather_96_96 \ 11 | --model $model_name \ 12 | --data custom \ 13 | --features M \ 14 | --seq_len 96 \ 15 | --label_len 48 \ 16 | --pred_len 96 \ 17 | --e_layers 1 \ 18 | --factor 3 \ 19 | --enc_in 21 \ 20 | --dec_in 21 \ 21 | --c_out 21 \ 22 | --des 'Exp' \ 23 | --d_model 256 \ 24 | --d_ff 512 \ 25 | --batch_size 4 \ 26 | --itr 1 \ 27 | 28 | python3 -u run.py \ 29 | --task_name long_term_forecast \ 30 | --is_training 1 \ 31 | --root_path ./dataset/weather/ \ 32 | --data_path weather.csv \ 33 | --model_id weather_96_192 \ 34 | --model $model_name \ 35 | --data custom \ 36 | --features M \ 37 | --seq_len 96 \ 38 | --label_len 48 \ 39 | --pred_len 192 \ 40 | --e_layers 3 \ 41 | --factor 3 \ 42 | --enc_in 21 \ 43 | --dec_in 21 \ 44 | --c_out 21 \ 45 | --des 'Exp' \ 46 | --d_model 128 \ 47 | --d_ff 1024 \ 48 | --batch_size 4 \ 49 | --itr 1 50 | 51 | python3 -u run.py \ 52 | --task_name long_term_forecast \ 53 | --is_training 1 \ 54 | --root_path ./dataset/weather/ \ 55 | --data_path weather.csv \ 56 | --model_id weather_96_336 \ 57 | --model $model_name \ 58 | --data custom \ 59 | --features M \ 60 | --seq_len 96 \ 61 | --label_len 48 \ 62 | --pred_len 336 \ 63 | --e_layers 
1 \ 64 | --factor 3 \ 65 | --enc_in 21 \ 66 | --dec_in 21 \ 67 | --c_out 21 \ 68 | --des 'Exp' \ 69 | --d_model 256 \ 70 | --batch_size 4 \ 71 | --itr 1 72 | 73 | python3 -u run.py \ 74 | --task_name long_term_forecast \ 75 | --is_training 1 \ 76 | --root_path ./dataset/weather/ \ 77 | --data_path weather.csv \ 78 | --model_id weather_96_720 \ 79 | --model $model_name \ 80 | --data custom \ 81 | --features M \ 82 | --seq_len 96 \ 83 | --label_len 48 \ 84 | --pred_len 720 \ 85 | --e_layers 1 \ 86 | --factor 3 \ 87 | --enc_in 21 \ 88 | --dec_in 21 \ 89 | --c_out 21 \ 90 | --des 'Exp' \ 91 | --d_model 128 \ 92 | --batch_size 4 \ 93 | --itr 1 94 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/TimeXer/76011909357972bd55a27adba2e1be994d81b327/utils/__init__.py -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. 
No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | Loss functions for PyTorch. 17 | """ 18 | 19 | import torch as t 20 | import torch.nn as nn 21 | import numpy as np 22 | import pdb 23 | 24 | 25 | def divide_no_nan(a, b): 26 | """ 27 | a/b where the resulted NaN or Inf are replaced by 0. 28 | """ 29 | result = a / b 30 | result[result != result] = .0 31 | result[result == np.inf] = .0 32 | return result 33 | 34 | 35 | class mape_loss(nn.Module): 36 | def __init__(self): 37 | super(mape_loss, self).__init__() 38 | 39 | def forward(self, insample: t.Tensor, freq: int, 40 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 41 | """ 42 | MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error 43 | 44 | :param forecast: Forecast values. Shape: batch, time 45 | :param target: Target values. Shape: batch, time 46 | :param mask: 0/1 mask. Shape: batch, time 47 | :return: Loss value 48 | """ 49 | weights = divide_no_nan(mask, target) 50 | return t.mean(t.abs((forecast - target) * weights)) 51 | 52 | 53 | class smape_loss(nn.Module): 54 | def __init__(self): 55 | super(smape_loss, self).__init__() 56 | 57 | def forward(self, insample: t.Tensor, freq: int, 58 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 59 | """ 60 | sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993) 61 | 62 | :param forecast: Forecast values. Shape: batch, time 63 | :param target: Target values. Shape: batch, time 64 | :param mask: 0/1 mask. 
Shape: batch, time 65 | :return: Loss value 66 | """ 67 | return 200 * t.mean(divide_no_nan(t.abs(forecast - target), 68 | t.abs(forecast.data) + t.abs(target.data)) * mask) 69 | 70 | 71 | class mase_loss(nn.Module): 72 | def __init__(self): 73 | super(mase_loss, self).__init__() 74 | 75 | def forward(self, insample: t.Tensor, freq: int, 76 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 77 | """ 78 | MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf 79 | 80 | :param insample: Insample values. Shape: batch, time_i 81 | :param freq: Frequency value 82 | :param forecast: Forecast values. Shape: batch, time_o 83 | :param target: Target values. Shape: batch, time_o 84 | :param mask: 0/1 mask. Shape: batch, time_o 85 | :return: Loss value 86 | """ 87 | masep = t.mean(t.abs(insample[:, freq:] - insample[:, :-freq]), dim=1) 88 | masked_masep_inv = divide_no_nan(mask, masep[:, None]) 89 | return t.mean(t.abs(target - forecast) * masked_masep_inv) 90 | -------------------------------------------------------------------------------- /utils/m4_summary.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. 
and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | M4 Summary 17 | """ 18 | from collections import OrderedDict 19 | 20 | import numpy as np 21 | import pandas as pd 22 | 23 | from data_provider.m4 import M4Dataset 24 | from data_provider.m4 import M4Meta 25 | import os 26 | 27 | 28 | def group_values(values, groups, group_name): 29 | return np.array([v[~np.isnan(v)] for v in values[groups == group_name]]) 30 | 31 | 32 | def mase(forecast, insample, outsample, frequency): 33 | return np.mean(np.abs(forecast - outsample)) / np.mean(np.abs(insample[:-frequency] - insample[frequency:])) 34 | 35 | 36 | def smape_2(forecast, target): 37 | denom = np.abs(target) + np.abs(forecast) 38 | # divide by 1.0 instead of 0.0, in case when denom is zero the enumerator will be 0.0 anyway. 39 | denom[denom == 0.0] = 1.0 40 | return 200 * np.abs(forecast - target) / denom 41 | 42 | 43 | def mape(forecast, target): 44 | denom = np.abs(target) 45 | # divide by 1.0 instead of 0.0, in case when denom is zero the enumerator will be 0.0 anyway. 46 | denom[denom == 0.0] = 1.0 47 | return 100 * np.abs(forecast - target) / denom 48 | 49 | 50 | class M4Summary: 51 | def __init__(self, file_path, root_path): 52 | self.file_path = file_path 53 | self.training_set = M4Dataset.load(training=True, dataset_file=root_path) 54 | self.test_set = M4Dataset.load(training=False, dataset_file=root_path) 55 | self.naive_path = os.path.join(root_path, 'submission-Naive2.csv') 56 | 57 | def evaluate(self): 58 | """ 59 | Evaluate forecasts using M4 test dataset. 60 | 61 | :param forecast: Forecasts. Shape: timeseries, time. 62 | :return: sMAPE and OWA grouped by seasonal patterns. 
63 | """ 64 | grouped_owa = OrderedDict() 65 | 66 | naive2_forecasts = pd.read_csv(self.naive_path).values[:, 1:].astype(np.float32) 67 | naive2_forecasts = np.array([v[~np.isnan(v)] for v in naive2_forecasts]) 68 | 69 | model_mases = {} 70 | naive2_smapes = {} 71 | naive2_mases = {} 72 | grouped_smapes = {} 73 | grouped_mapes = {} 74 | for group_name in M4Meta.seasonal_patterns: 75 | file_name = self.file_path + group_name + "_forecast.csv" 76 | if os.path.exists(file_name): 77 | model_forecast = pd.read_csv(file_name).values 78 | 79 | naive2_forecast = group_values(naive2_forecasts, self.test_set.groups, group_name) 80 | target = group_values(self.test_set.values, self.test_set.groups, group_name) 81 | # all timeseries within group have same frequency 82 | frequency = self.training_set.frequencies[self.test_set.groups == group_name][0] 83 | insample = group_values(self.training_set.values, self.test_set.groups, group_name) 84 | 85 | model_mases[group_name] = np.mean([mase(forecast=model_forecast[i], 86 | insample=insample[i], 87 | outsample=target[i], 88 | frequency=frequency) for i in range(len(model_forecast))]) 89 | naive2_mases[group_name] = np.mean([mase(forecast=naive2_forecast[i], 90 | insample=insample[i], 91 | outsample=target[i], 92 | frequency=frequency) for i in range(len(model_forecast))]) 93 | 94 | naive2_smapes[group_name] = np.mean(smape_2(naive2_forecast, target)) 95 | grouped_smapes[group_name] = np.mean(smape_2(forecast=model_forecast, target=target)) 96 | grouped_mapes[group_name] = np.mean(mape(forecast=model_forecast, target=target)) 97 | 98 | grouped_smapes = self.summarize_groups(grouped_smapes) 99 | grouped_mapes = self.summarize_groups(grouped_mapes) 100 | grouped_model_mases = self.summarize_groups(model_mases) 101 | grouped_naive2_smapes = self.summarize_groups(naive2_smapes) 102 | grouped_naive2_mases = self.summarize_groups(naive2_mases) 103 | for k in grouped_model_mases.keys(): 104 | grouped_owa[k] = (grouped_model_mases[k] / 
grouped_naive2_mases[k] + 105 | grouped_smapes[k] / grouped_naive2_smapes[k]) / 2 106 | 107 | def round_all(d): 108 | return dict(map(lambda kv: (kv[0], np.round(kv[1], 3)), d.items())) 109 | 110 | return round_all(grouped_smapes), round_all(grouped_owa), round_all(grouped_mapes), round_all( 111 | grouped_model_mases) 112 | 113 | def summarize_groups(self, scores): 114 | """ 115 | Re-group scores respecting M4 rules. 116 | :param scores: Scores per group. 117 | :return: Grouped scores. 118 | """ 119 | scores_summary = OrderedDict() 120 | 121 | def group_count(group_name): 122 | return len(np.where(self.test_set.groups == group_name)[0]) 123 | 124 | weighted_score = {} 125 | for g in ['Yearly', 'Quarterly', 'Monthly']: 126 | weighted_score[g] = scores[g] * group_count(g) 127 | scores_summary[g] = scores[g] 128 | 129 | others_score = 0 130 | others_count = 0 131 | for g in ['Weekly', 'Daily', 'Hourly']: 132 | others_score += scores[g] * group_count(g) 133 | others_count += group_count(g) 134 | weighted_score['Others'] = others_score 135 | scores_summary['Others'] = others_score / others_count 136 | 137 | average = np.sum(list(weighted_score.values())) / len(self.test_set.groups) 138 | scores_summary['Average'] = average 139 | 140 | return scores_summary 141 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(pred - true)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((pred - true) ** 2) 20 | 21 | 22 | def RMSE(pred, 
true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | 26 | def MAPE(pred, true): 27 | return np.mean(np.abs((pred - true) / true)) 28 | 29 | 30 | def MSPE(pred, true): 31 | return np.mean(np.square((pred - true) / true)) 32 | 33 | # 标准差 34 | def STD(pred, true): 35 | return np.std(pred - true) 36 | 37 | def metric(pred, true): 38 | mae = MAE(pred, true) 39 | mse = MSE(pred, true) 40 | rmse = RMSE(pred, true) 41 | mape = MAPE(pred, true) 42 | mspe = MSPE(pred, true) 43 | 44 | return mae, mse, rmse, mape, mspe 45 | -------------------------------------------------------------------------------- /utils/print_args.py: -------------------------------------------------------------------------------- 1 | def print_args(args): 2 | print("\033[1m" + "Basic Config" + "\033[0m") 3 | print(f' {"Task Name:":<20}{args.task_name:<20}{"Is Training:":<20}{args.is_training:<20}') 4 | print(f' {"Model ID:":<20}{args.model_id:<20}{"Model:":<20}{args.model:<20}') 5 | print() 6 | 7 | print("\033[1m" + "Data Loader" + "\033[0m") 8 | print(f' {"Data:":<20}{args.data:<20}{"Root Path:":<20}{args.root_path:<20}') 9 | print(f' {"Data Path:":<20}{args.data_path:<20}{"Features:":<20}{args.features:<20}') 10 | print(f' {"Target:":<20}{args.target:<20}{"Freq:":<20}{args.freq:<20}') 11 | print(f' {"Checkpoints:":<20}{args.checkpoints:<20}') 12 | print() 13 | 14 | if args.task_name in ['long_term_forecast', 'short_term_forecast']: 15 | print("\033[1m" + "Forecasting Task" + "\033[0m") 16 | print(f' {"Seq Len:":<20}{args.seq_len:<20}{"Label Len:":<20}{args.label_len:<20}') 17 | print(f' {"Pred Len:":<20}{args.pred_len:<20}{"Seasonal Patterns:":<20}{args.seasonal_patterns:<20}') 18 | print(f' {"Inverse:":<20}{args.inverse:<20}') 19 | print() 20 | 21 | if args.task_name == 'imputation': 22 | print("\033[1m" + "Imputation Task" + "\033[0m") 23 | print(f' {"Mask Rate:":<20}{args.mask_rate:<20}') 24 | print() 25 | 26 | if args.task_name == 'anomaly_detection': 27 | print("\033[1m" + "Anomaly Detection 
Task" + "\033[0m") 28 | print(f' {"Anomaly Ratio:":<20}{args.anomaly_ratio:<20}') 29 | print() 30 | 31 | print("\033[1m" + "Model Parameters" + "\033[0m") 32 | print(f' {"Top k:":<20}{args.top_k:<20}{"Num Kernels:":<20}{args.num_kernels:<20}') 33 | print(f' {"Enc In:":<20}{args.enc_in:<20}{"Dec In:":<20}{args.dec_in:<20}') 34 | print(f' {"C Out:":<20}{args.c_out:<20}{"d model:":<20}{args.d_model:<20}') 35 | print(f' {"n heads:":<20}{args.n_heads:<20}{"e layers:":<20}{args.e_layers:<20}') 36 | print(f' {"d layers:":<20}{args.d_layers:<20}{"d FF:":<20}{args.d_ff:<20}') 37 | print(f' {"Moving Avg:":<20}{args.moving_avg:<20}{"Factor:":<20}{args.factor:<20}') 38 | print(f' {"Distil:":<20}{args.distil:<20}{"Dropout:":<20}{args.dropout:<20}') 39 | print(f' {"Embed:":<20}{args.embed:<20}{"Activation:":<20}{args.activation:<20}') 40 | print() 41 | 42 | print("\033[1m" + "Run Parameters" + "\033[0m") 43 | print(f' {"Num Workers:":<20}{args.num_workers:<20}{"Itr:":<20}{args.itr:<20}') 44 | print(f' {"Train Epochs:":<20}{args.train_epochs:<20}{"Batch Size:":<20}{args.batch_size:<20}') 45 | print(f' {"Patience:":<20}{args.patience:<20}{"Learning Rate:":<20}{args.learning_rate:<20}') 46 | print(f' {"Des:":<20}{args.des:<20}{"Loss:":<20}{args.loss:<20}') 47 | print(f' {"Lradj:":<20}{args.lradj:<20}{"Use Amp:":<20}{args.use_amp:<20}') 48 | print() 49 | 50 | print("\033[1m" + "GPU" + "\033[0m") 51 | print(f' {"Use GPU:":<20}{args.use_gpu:<20}{"GPU:":<20}{args.gpu:<20}') 52 | print(f' {"Use Multi GPU:":<20}{args.use_multi_gpu:<20}{"Devices:":<20}{args.devices:<20}') 53 | print() 54 | 55 | print("\033[1m" + "De-stationary Projector Params" + "\033[0m") 56 | p_hidden_dims_str = ', '.join(map(str, args.p_hidden_dims)) 57 | print(f' {"P Hidden Dims:":<20}{p_hidden_dims_str:<20}{"P Hidden Layers:":<20}{args.p_hidden_layers:<20}') 58 | print() -------------------------------------------------------------------------------- /utils/timefeatures.py: 
-------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 14 | 15 | from typing import List 16 | 17 | import numpy as np 18 | import pandas as pd 19 | from pandas.tseries import offsets 20 | from pandas.tseries.frequencies import to_offset 21 | 22 | 23 | class TimeFeature: 24 | def __init__(self): 25 | pass 26 | 27 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 28 | pass 29 | 30 | def __repr__(self): 31 | return self.__class__.__name__ + "()" 32 | 33 | 34 | class SecondOfMinute(TimeFeature): 35 | """Minute of hour encoded as value between [-0.5, 0.5]""" 36 | 37 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 38 | return index.second / 59.0 - 0.5 39 | 40 | 41 | class MinuteOfHour(TimeFeature): 42 | """Minute of hour encoded as value between [-0.5, 0.5]""" 43 | 44 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 45 | return index.minute / 59.0 - 0.5 46 | 47 | 48 | class HourOfDay(TimeFeature): 49 | """Hour of day encoded as value between [-0.5, 0.5]""" 50 | 51 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 52 | return index.hour / 23.0 - 0.5 53 | 54 | 55 | class DayOfWeek(TimeFeature): 56 | """Hour of day encoded as value between [-0.5, 0.5]""" 57 | 58 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 59 | return 
index.dayofweek / 6.0 - 0.5 60 | 61 | 62 | class DayOfMonth(TimeFeature): 63 | """Day of month encoded as value between [-0.5, 0.5]""" 64 | 65 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 66 | return (index.day - 1) / 30.0 - 0.5 67 | 68 | 69 | class DayOfYear(TimeFeature): 70 | """Day of year encoded as value between [-0.5, 0.5]""" 71 | 72 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 73 | return (index.dayofyear - 1) / 365.0 - 0.5 74 | 75 | 76 | class MonthOfYear(TimeFeature): 77 | """Month of year encoded as value between [-0.5, 0.5]""" 78 | 79 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 80 | return (index.month - 1) / 11.0 - 0.5 81 | 82 | 83 | class WeekOfYear(TimeFeature): 84 | """Week of year encoded as value between [-0.5, 0.5]""" 85 | 86 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 87 | return (index.isocalendar().week - 1) / 52.0 - 0.5 88 | 89 | 90 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 91 | """ 92 | Returns a list of time features that will be appropriate for the given frequency string. 93 | Parameters 94 | ---------- 95 | freq_str 96 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 
97 | """ 98 | 99 | features_by_offsets = { 100 | offsets.YearEnd: [], 101 | offsets.QuarterEnd: [MonthOfYear], 102 | offsets.MonthEnd: [MonthOfYear], 103 | offsets.Week: [DayOfMonth, WeekOfYear], 104 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 105 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 106 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 107 | offsets.Minute: [ 108 | MinuteOfHour, 109 | HourOfDay, 110 | DayOfWeek, 111 | DayOfMonth, 112 | DayOfYear, 113 | ], 114 | offsets.Second: [ 115 | SecondOfMinute, 116 | MinuteOfHour, 117 | HourOfDay, 118 | DayOfWeek, 119 | DayOfMonth, 120 | DayOfYear, 121 | ], 122 | } 123 | 124 | offset = to_offset(freq_str) 125 | 126 | for offset_type, feature_classes in features_by_offsets.items(): 127 | if isinstance(offset, offset_type): 128 | return [cls() for cls in feature_classes] 129 | 130 | supported_freq_msg = f""" 131 | Unsupported frequency {freq_str} 132 | The following frequencies are supported: 133 | Y - yearly 134 | alias: A 135 | M - monthly 136 | W - weekly 137 | D - daily 138 | B - business days 139 | H - hourly 140 | T - minutely 141 | alias: min 142 | S - secondly 143 | """ 144 | raise RuntimeError(supported_freq_msg) 145 | 146 | 147 | def time_features(dates, freq='h'): 148 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]) 149 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | import matplotlib.pyplot as plt 8 | import pandas as pd 9 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 10 | 11 | plt.switch_backend('agg') 12 | 13 | 14 | def adjust_learning_rate(optimizer, epoch, args): 15 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 16 | if args.lradj == 'type1': 17 
| lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))} 18 | elif args.lradj == 'type2': 19 | lr_adjust = { 20 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 21 | 10: 5e-7, 15: 1e-7, 20: 5e-8 22 | } 23 | if epoch in lr_adjust.keys(): 24 | lr = lr_adjust[epoch] 25 | for param_group in optimizer.param_groups: 26 | param_group['lr'] = lr 27 | print('Updating learning rate to {}'.format(lr)) 28 | 29 | 30 | class EarlyStopping: 31 | def __init__(self, patience=7, verbose=False, delta=0): 32 | self.patience = patience 33 | self.verbose = verbose 34 | self.counter = 0 35 | self.best_score = None 36 | self.early_stop = False 37 | self.val_loss_min = np.Inf 38 | self.delta = delta 39 | 40 | def __call__(self, val_loss, model, path): 41 | score = -val_loss 42 | if self.best_score is None: 43 | self.best_score = score 44 | self.save_checkpoint(val_loss, model, path) 45 | elif score < self.best_score + self.delta: 46 | self.counter += 1 47 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 48 | if self.counter >= self.patience: 49 | self.early_stop = True 50 | else: 51 | self.best_score = score 52 | self.save_checkpoint(val_loss, model, path) 53 | self.counter = 0 54 | 55 | def save_checkpoint(self, val_loss, model, path): 56 | if self.verbose: 57 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 58 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 59 | self.val_loss_min = val_loss 60 | 61 | 62 | class dotdict(dict): 63 | """dot.notation access to dictionary attributes""" 64 | __getattr__ = dict.get 65 | __setattr__ = dict.__setitem__ 66 | __delattr__ = dict.__delitem__ 67 | 68 | 69 | class StandardScaler(): 70 | def __init__(self, mean, std): 71 | self.mean = mean 72 | self.std = std 73 | 74 | def transform(self, data): 75 | return (data - self.mean) / self.std 76 | 77 | def inverse_transform(self, data): 78 | return (data * self.std) + self.mean 79 | 80 | 81 | def visual(true, preds=None, name='./pic/test.pdf'): 82 | """ 83 | Results visualization 84 | """ 85 | plt.figure() 86 | plt.plot(true, label='GroundTruth', linewidth=2) 87 | if preds is not None: 88 | plt.plot(preds, label='Prediction', linewidth=2) 89 | plt.legend() 90 | plt.savefig(name, bbox_inches='tight') 91 | 92 | 93 | def adjustment(gt, pred): 94 | anomaly_state = False 95 | for i in range(len(gt)): 96 | if gt[i] == 1 and pred[i] == 1 and not anomaly_state: 97 | anomaly_state = True 98 | for j in range(i, 0, -1): 99 | if gt[j] == 0: 100 | break 101 | else: 102 | if pred[j] == 0: 103 | pred[j] = 1 104 | for j in range(i, len(gt)): 105 | if gt[j] == 0: 106 | break 107 | else: 108 | if pred[j] == 0: 109 | pred[j] = 1 110 | elif gt[i] == 0: 111 | anomaly_state = False 112 | if anomaly_state: 113 | pred[i] = 1 114 | return gt, pred 115 | 116 | 117 | def cal_accuracy(y_pred, y_true): 118 | return np.mean(y_pred == y_true) 119 | 120 | 121 | def custom_collate(batch): 122 | r"""source: pytorch 1.9.0, only one modification to original code """ 123 | 124 | elem = batch[0] 125 | elem_type = type(elem) 126 | if isinstance(elem, torch.Tensor): 127 | out = None 128 | if torch.utils.data.get_worker_info() is not None: 129 | # If we're in a background process, concatenate directly into a 130 | # shared memory tensor to avoid an extra copy 131 | numel = sum([x.numel() for x 
in batch]) 132 | storage = elem.storage()._new_shared(numel) 133 | out = elem.new(storage) 134 | return torch.stack(batch, 0, out=out) 135 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 136 | and elem_type.__name__ != 'string_': 137 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 138 | # array of string classes and object 139 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 140 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 141 | 142 | return custom_collate([torch.as_tensor(b) for b in batch]) 143 | elif elem.shape == (): # scalars 144 | return torch.as_tensor(batch) 145 | elif isinstance(elem, float): 146 | return torch.tensor(batch, dtype=torch.float64) 147 | elif isinstance(elem, int): 148 | return torch.tensor(batch) 149 | elif isinstance(elem, str): 150 | return batch 151 | elif isinstance(elem, collections.abc.Mapping): 152 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 153 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 154 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 155 | elif isinstance(elem, collections.abc.Sequence): 156 | # check to make sure that the elements in batch have consistent size 157 | it = iter(batch) 158 | elem_size = len(next(it)) 159 | if not all(len(elem) == elem_size for elem in it): 160 | raise RuntimeError('each element in list of batch should be of equal size') 161 | transposed = zip(*batch) 162 | return [custom_collate(samples) for samples in transposed] 163 | 164 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 165 | 166 | class HiddenPrints: 167 | def __init__(self, rank): 168 | # 如果rank是none,那么就是单机单卡,不需要隐藏打印,将rank设置为0 169 | if rank is None: 170 | rank = 0 171 | self.rank = rank 172 | def __enter__(self): 173 | if self.rank == 0: 174 | return 175 | self._original_stdout = sys.stdout 176 | sys.stdout = open(os.devnull, 'w') 177 | 178 | def 
__exit__(self, exc_type, exc_val, exc_tb): 179 | if self.rank == 0: 180 | return 181 | sys.stdout.close() 182 | sys.stdout = self._original_stdout 183 | --------------------------------------------------------------------------------