├── .gitignore ├── LICENSE ├── README.md ├── data_provider ├── __init__.py ├── data_factory.py ├── data_loader.py ├── m4.py └── uea.py ├── exp ├── __init__.py ├── exp_basic.py └── exp_long_term_forecasting.py ├── layers ├── AutoCorrelation.py ├── Autoformer_EncDec.py ├── Conv_Blocks.py ├── Crossformer_EncDec.py ├── ETSformer_EncDec.py ├── Embed.py ├── FourierCorrelation.py ├── MultiWaveletCorrelation.py ├── Pyraformer_EncDec.py ├── SelfAttention_Family.py ├── Transformer_EncDec.py ├── __init__.py ├── linear_family.py └── mamba_ssm │ ├── __init__.py │ ├── mamba2_simple.py │ ├── mamba_simple.py │ └── mixer2_seq_simple.py ├── models ├── Autoformer.py ├── Crossformer.py ├── DLinear.py ├── FEDformer.py ├── FourierGNN.py ├── MICN.py ├── MambaTS.py ├── PatchTST.py ├── __init__.py └── iTransformer.py ├── requirements.txt ├── run.py ├── scripts └── long_term_forecast │ ├── MambaTS_COV.sh │ ├── MambaTS_ECL.sh │ ├── MambaTS_ETTh2.sh │ ├── MambaTS_ETTm2.sh │ ├── MambaTS_PEMS.sh │ ├── MambaTS_SOL.sh │ ├── MambaTS_TFF.sh │ └── MambaTS_WTH.sh └── utils ├── __init__.py ├── losses.py ├── lr_scheduler.py ├── m4_summary.py ├── masking.py ├── metrics.py ├── optim_utils.py ├── print_args.py ├── timefeatures.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | checkpoints/* 3 | dataset/* 4 | visuals/* 5 | results/* 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other 
infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | /checkpoints/ 137 | /run_ane.py 138 | /dataset/ 139 | /result_long_term_forecast.txt 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Xiuding Cai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MambaTS 2 | 3 | This repo is the official implementation of the paper: [MambaTS: Improved Selective State Space Models for Long-term Time Series Forecasting](http://arxiv.org/abs/2405.16440). 4 | 5 | Key code: 6 | 7 | * For the architecture design of MambaTS, please refer primarily to `models/MambaTS.py`. 8 | * For variable permutation training (VPT), please focus on the `random_shuffle` and `unshuffle` functions in `utils/masking.py`. 9 | * For variable-aware scan along time (VAST), please refer mainly to `layers/mamba_ssm/mixer2_seq_simple.py`. 10 | 11 | Recently, we've also released a repo tracking the latest developments in Mamba. If you're interested, you can check it out at [Awesome-Mamba-Collection](https://github.com/XiudingCai/Awesome-Mamba-Collection) and enjoy it. 12 | 13 | ## Usage 14 | 15 | 1. Install Python 3.11. Then install the Python dependencies with the following command. 16 | 17 | ``` 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | 2. To set up the Mamba environment, please refer to https://github.com/state-spaces/mamba. On Linux, the following commands are usually sufficient (keep the quotes, otherwise the shell treats `>=` as an output redirection and drops the version constraint): 22 | 23 | ``` 24 | pip install "causal-conv1d>=1.2.0" 25 | pip install mamba-ssm 26 | ``` 27 | 28 | 3. Prepare data. You can obtain the pre-processed datasets from public sources such as [[Google Drive]](https://drive.google.com/drive/folders/13Cg1KYOlzM5C7K8gK8NfC-F3EYxkM3D2?usp=sharing) or [[Tsinghua Cloud]](https://cloud.tsinghua.edu.cn/f/2ea5ca3d621e4e5ba36a/). Then place the downloaded data in the folder `./dataset`.
29 | 30 | 4. Train and evaluate the model. We provide the experiment scripts for MambaTS under the folder `./scripts/`. You can reproduce the experiment results as in the following example: 31 | 32 | ``` 33 | # long-term forecast for ETTm2 dataset 34 | bash ./scripts/long_term_forecast/MambaTS_ETTm2.sh 35 | ``` 36 | 37 | ## Acknowledgement 38 | 39 | This library is built on the following repos: 40 | 41 | - Time-Series-Library: https://github.com/thuml/Time-Series-Library 42 | - Mamba: https://github.com/state-spaces/mamba 43 | 44 | All the experiment datasets are public, and we obtained them from the following links: 45 | 46 | - Long-term Forecasting: https://github.com/thuml/Autoformer and https://github.com/thuml/iTransformer. 47 | 48 | We also greatly appreciate [python-tsp](https://github.com/fillipe-gsm/python-tsp) for providing efficient solvers for the Asymmetric Traveling Salesperson Problem (ATSP). 49 | 50 | We extend our sincere thanks for their excellent work and repositories! 51 | 52 | ## Citation 53 | 54 | If you find this repo useful, please consider citing our paper.
55 | 56 | ``` 57 | @article{cai2024mambats, 58 | title={MambaTS: Improved Selective State Space Models for Long-term Time Series Forecasting}, 59 | author={Cai, Xiuding and Zhu, Yaoyao and Wang, Xueyao and Yao, Yu}, 60 | journal={arXiv preprint arXiv:2405.16440}, 61 | year={2024} 62 | } 63 | ``` -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data_provider/data_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \ 5 | MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, Dataset_Solar, Dataset_PEMS 6 | from data_provider.uea import collate_fn 7 | from torch.utils.data import DataLoader 8 | 9 | data_dict = { 10 | 'ETTh1': Dataset_ETT_hour, 11 | 'ETTh2': Dataset_ETT_hour, 12 | 'ETTm1': Dataset_ETT_minute, 13 | 'ETTm2': Dataset_ETT_minute, 14 | 'custom': Dataset_Custom, 15 | 'Solar': Dataset_Solar, 16 | 'PEMS': Dataset_PEMS, 17 | 'm4': Dataset_M4, 18 | 'PSM': PSMSegLoader, 19 | 'MSL': MSLSegLoader, 20 | 'SMAP': SMAPSegLoader, 21 | 'SMD': SMDSegLoader, 22 | 'SWAT': SWATSegLoader, 23 | 'UEA': UEAloader, 24 | } 25 | 26 | 27 | def data_provider(args, flag): 28 | Data = data_dict[args.data] 29 | if args.embed in ['fixed', 'learned']: 30 | timeenc = 0 31 | elif args.embed == 'timeF': 32 | timeenc = 1 33 | else: 34 | timeenc = 2 35 | 36 | if flag == 'test': 37 | shuffle_flag = False 38 | drop_last = True 39 | if args.task_name == 'anomaly_detection' or args.task_name == 'classification': 40 | batch_size = args.batch_size 41 | else: 42 | batch_size = 1 # bsz=1 for evaluation 43 | freq = args.freq 44 | else: 45 | 
shuffle_flag = True 46 | drop_last = True 47 | batch_size = args.batch_size # bsz for train and valid 48 | freq = args.freq 49 | 50 | if args.task_name == 'anomaly_detection': 51 | drop_last = False 52 | data_set = Data( 53 | root_path=args.root_path, 54 | win_size=args.seq_len, 55 | flag=flag, 56 | ) 57 | print(flag, len(data_set)) 58 | data_loader = DataLoader( 59 | data_set, 60 | batch_size=batch_size, 61 | shuffle=shuffle_flag, 62 | num_workers=args.num_workers, 63 | drop_last=drop_last) 64 | return data_set, data_loader 65 | elif args.task_name == 'classification': 66 | drop_last = False 67 | data_set = Data( 68 | root_path=args.root_path, 69 | flag=flag, 70 | ) 71 | 72 | data_loader = DataLoader( 73 | data_set, 74 | batch_size=batch_size, 75 | shuffle=shuffle_flag, 76 | num_workers=args.num_workers, 77 | drop_last=drop_last, 78 | collate_fn=lambda x: collate_fn(x, max_len=args.seq_len) 79 | ) 80 | return data_set, data_loader 81 | else: 82 | if args.data == 'm4': 83 | drop_last = False 84 | data_set = Data( 85 | root_path=args.root_path, 86 | data_path=args.data_path, 87 | flag=flag, 88 | size=[args.seq_len, args.label_len, args.pred_len], 89 | features=args.features, 90 | target=args.target, 91 | timeenc=timeenc, 92 | freq=freq, 93 | seasonal_patterns=args.seasonal_patterns, 94 | args=args 95 | ) 96 | print(flag, len(data_set)) 97 | data_loader = DataLoader( 98 | data_set, 99 | batch_size=batch_size, 100 | shuffle=shuffle_flag, 101 | num_workers=args.num_workers, 102 | drop_last=drop_last) 103 | return data_set, data_loader 104 | -------------------------------------------------------------------------------- /data_provider/m4.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. 
The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | M4 Dataset 17 | """ 18 | import logging 19 | import os 20 | from collections import OrderedDict 21 | from dataclasses import dataclass 22 | from glob import glob 23 | 24 | import numpy as np 25 | import pandas as pd 26 | import patoolib 27 | from tqdm import tqdm 28 | import logging 29 | import os 30 | import pathlib 31 | import sys 32 | from urllib import request 33 | 34 | 35 | def url_file_name(url: str) -> str: 36 | """ 37 | Extract file name from url. 38 | 39 | :param url: URL to extract file name from. 40 | :return: File name. 41 | """ 42 | return url.split('/')[-1] if len(url) > 0 else '' 43 | 44 | 45 | def download(url: str, file_path: str) -> None: 46 | """ 47 | Download a file to the given path. 48 | 49 | :param url: URL to download 50 | :param file_path: Where to download the content. 
51 | """ 52 | 53 | def progress(count, block_size, total_size): 54 | progress_pct = float(count * block_size) / float(total_size) * 100.0 55 | sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct)) 56 | sys.stdout.flush() 57 | 58 | if not os.path.isfile(file_path): 59 | opener = request.build_opener() 60 | opener.addheaders = [('User-agent', 'Mozilla/5.0')] 61 | request.install_opener(opener) 62 | pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True) 63 | f, _ = request.urlretrieve(url, file_path, progress) 64 | sys.stdout.write('\n') 65 | sys.stdout.flush() 66 | file_info = os.stat(f) 67 | logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.') 68 | else: 69 | file_info = os.stat(file_path) 70 | logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.') 71 | 72 | 73 | @dataclass() 74 | class M4Dataset: 75 | ids: np.ndarray 76 | groups: np.ndarray 77 | frequencies: np.ndarray 78 | horizons: np.ndarray 79 | values: np.ndarray 80 | 81 | @staticmethod 82 | def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset': 83 | """ 84 | Load cached dataset. 85 | 86 | :param training: Load training part if training is True, test part otherwise. 
87 | """ 88 | info_file = os.path.join(dataset_file, 'M4-info.csv') 89 | train_cache_file = os.path.join(dataset_file, 'training.npz') 90 | test_cache_file = os.path.join(dataset_file, 'test.npz') 91 | m4_info = pd.read_csv(info_file) 92 | return M4Dataset(ids=m4_info.M4id.values, 93 | groups=m4_info.SP.values, 94 | frequencies=m4_info.Frequency.values, 95 | horizons=m4_info.Horizon.values, 96 | values=np.load( 97 | train_cache_file if training else test_cache_file, 98 | allow_pickle=True)) 99 | 100 | 101 | @dataclass() 102 | class M4Meta: 103 | seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly'] 104 | horizons = [6, 8, 18, 13, 14, 48] 105 | frequencies = [1, 4, 12, 1, 1, 24] 106 | horizons_map = { 107 | 'Yearly': 6, 108 | 'Quarterly': 8, 109 | 'Monthly': 18, 110 | 'Weekly': 13, 111 | 'Daily': 14, 112 | 'Hourly': 48 113 | } # different predict length 114 | frequency_map = { 115 | 'Yearly': 1, 116 | 'Quarterly': 4, 117 | 'Monthly': 12, 118 | 'Weekly': 1, 119 | 'Daily': 1, 120 | 'Hourly': 24 121 | } 122 | history_size = { 123 | 'Yearly': 1.5, 124 | 'Quarterly': 1.5, 125 | 'Monthly': 1.5, 126 | 'Weekly': 10, 127 | 'Daily': 10, 128 | 'Hourly': 10 129 | } # from interpretable.gin 130 | 131 | 132 | def load_m4_info() -> pd.DataFrame: 133 | """ 134 | Load M4Info file. 135 | 136 | :return: Pandas DataFrame of M4Info. 137 | """ 138 | return pd.read_csv(INFO_FILE_PATH) 139 | -------------------------------------------------------------------------------- /data_provider/uea.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | 6 | 7 | def collate_fn(data, max_len=None): 8 | """Build mini-batch tensors from a list of (X, mask) tuples. Mask input. Create 9 | Args: 10 | data: len(batch_size) list of tuples (X, y). 11 | - X: torch tensor of shape (seq_length, feat_dim); variable seq_length. 
12 | - y: torch tensor of shape (num_labels,) : class indices or numerical targets 13 | (for classification or regression, respectively). num_labels > 1 for multi-task models 14 | max_len: global fixed sequence length. Used for architectures requiring fixed length input, 15 | where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s 16 | Returns: 17 | X: (batch_size, padded_length, feat_dim) torch tensor of features, zero-padded or clipped to padded_length 18 | targets: (batch_size, num_labels) torch tensor, the y entries stacked along a new batch dimension 19 | padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding 20 | padded_length is max_len if given, otherwise the length of the longest X in the batch. 21 | No target mask is returned. 22 | """ 23 | 24 | batch_size = len(data) 25 | features, labels = zip(*data) 26 | 27 | # Stack and pad features (convert 2D to 3D tensors, i.e.
add batch dimension) 28 | lengths = [X.shape[0] for X in features] # original sequence length for each time series 29 | if max_len is None: 30 | max_len = max(lengths) 31 | 32 | X = torch.zeros(batch_size, max_len, features[0].shape[-1]) # (batch_size, padded_length, feat_dim) 33 | for i in range(batch_size): 34 | end = min(lengths[i], max_len) 35 | X[i, :end, :] = features[i][:end, :] 36 | 37 | targets = torch.stack(labels, dim=0) # (batch_size, num_labels) 38 | 39 | padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.int16), 40 | max_len=max_len) # (batch_size, padded_length) boolean tensor, "1" means keep 41 | 42 | return X, targets, padding_masks 43 | 44 | 45 | def padding_mask(lengths, max_len=None): 46 | """ 47 | Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths, 48 | where 1 means keep element at this position (time step) 49 | """ 50 | batch_size = lengths.numel() 51 | max_len = max_len or lengths.max_val() # trick works because of overloading of 'or' operator for non-boolean types 52 | return (torch.arange(0, max_len, device=lengths.device) 53 | .type_as(lengths) 54 | .repeat(batch_size, 1) 55 | .lt(lengths.unsqueeze(1))) 56 | 57 | 58 | class Normalizer(object): 59 | """ 60 | Normalizes dataframe across ALL contained rows (time steps). Different from per-sample normalization. 61 | """ 62 | 63 | def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None): 64 | """ 65 | Args: 66 | norm_type: choose from: 67 | "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps) 68 | "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. 
across only its own rows) 69 | mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values 70 | """ 71 | 72 | self.norm_type = norm_type 73 | self.mean = mean 74 | self.std = std 75 | self.min_val = min_val 76 | self.max_val = max_val 77 | 78 | def normalize(self, df): 79 | """ 80 | Args: 81 | df: input dataframe 82 | Returns: 83 | df: normalized dataframe 84 | """ 85 | if self.norm_type == "standardization": 86 | if self.mean is None: 87 | self.mean = df.mean() 88 | self.std = df.std() 89 | return (df - self.mean) / (self.std + np.finfo(float).eps) 90 | 91 | elif self.norm_type == "minmax": 92 | if self.max_val is None: 93 | self.max_val = df.max() 94 | self.min_val = df.min() 95 | return (df - self.min_val) / (self.max_val - self.min_val + np.finfo(float).eps) 96 | 97 | elif self.norm_type == "per_sample_std": 98 | grouped = df.groupby(by=df.index) 99 | return (df - grouped.transform('mean')) / grouped.transform('std') 100 | 101 | elif self.norm_type == "per_sample_minmax": 102 | grouped = df.groupby(by=df.index) 103 | min_vals = grouped.transform('min') 104 | return (df - min_vals) / (grouped.transform('max') - min_vals + np.finfo(float).eps) 105 | 106 | else: 107 | raise (NameError(f'Normalize method "{self.norm_type}" not implemented')) 108 | 109 | 110 | def interpolate_missing(y): 111 | """ 112 | Replaces NaN values in pd.Series `y` using linear interpolation 113 | """ 114 | if y.isna().any(): 115 | y = y.interpolate(method='linear', limit_direction='both') 116 | return y 117 | 118 | 119 | def subsample(y, limit=256, factor=2): 120 | """ 121 | If a given Series is longer than `limit`, returns subsampled sequence by the specified integer factor 122 | """ 123 | if len(y) > limit: 124 | return y[::factor].reset_index(drop=True) 125 | return y 126 | -------------------------------------------------------------------------------- /exp/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XiudingCai/MambaTS-pytorch/3b6797d9bf5178a490e1bb3c3e9f4d223f2af51d/exp/__init__.py -------------------------------------------------------------------------------- /exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from models import Autoformer, DLinear, FEDformer, PatchTST, MICN, Crossformer, iTransformer, MambaTS, FourierGNN 4 | 5 | 6 | class Exp_Basic(object): 7 | def __init__(self, args): 8 | self.args = args 9 | self.model_dict = { 10 | 'Autoformer': Autoformer, 11 | 'DLinear': DLinear, 12 | 'FEDformer': FEDformer, 13 | 'PatchTST': PatchTST, 14 | 'MICN': MICN, 15 | 'Crossformer': Crossformer, 16 | 'iTransformer': iTransformer, 17 | 'FourierGNN': FourierGNN, 18 | 'MambaTS': MambaTS, 19 | } 20 | self.device = self._acquire_device() 21 | self.model = self._build_model().to(self.device) 22 | 23 | # from torchsummary import summary 24 | # from torchinfo import summary 25 | # # torch.Size([16, 720, 321]) torch.Size([16, 768, 321]) torch.Size([16, 720, 4]) torch.Size([16, 768, 4]) 26 | # # print(batch_x.shape, batch_y.shape, batch_x_mark.shape, batch_y_mark.shape) 27 | # summary(self.model, [(self.args.batch_size, self.args.seq_len, self.args.enc_in), 28 | # (self.args.batch_size, self.args.seq_len, self.args.enc_in), 29 | # (self.args.batch_size, self.args.seq_len, 4), 30 | # (self.args.batch_size, self.args.seq_len, 4)], device='cuda') 31 | 32 | def _build_model(self): 33 | raise NotImplementedError 34 | return None 35 | 36 | def _acquire_device(self): 37 | if self.args.use_gpu: 38 | os.environ["CUDA_VISIBLE_DEVICES"] = str( 39 | self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 40 | device = torch.device('cuda:{}'.format(self.args.gpu)) 41 | print('Use GPU: cuda:{}'.format(self.args.gpu)) 42 | else: 43 | device = torch.device('cpu') 44 | print('Use CPU') 45 | return device 46 | 47 | def _get_data(self): 48 | pass 49 | 50 | def 
vali(self): 51 | pass 52 | 53 | def train(self): 54 | pass 55 | 56 | def test(self): 57 | pass 58 | -------------------------------------------------------------------------------- /layers/AutoCorrelation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import math 7 | from math import sqrt 8 | import os 9 | 10 | 11 | class AutoCorrelation(nn.Module): 12 | """ 13 | AutoCorrelation Mechanism with the following two phases: 14 | (1) period-based dependencies discovery 15 | (2) time delay aggregation 16 | This block can replace the self-attention family mechanism seamlessly. 17 | """ 18 | 19 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): 20 | super(AutoCorrelation, self).__init__() 21 | self.factor = factor 22 | self.scale = scale 23 | self.mask_flag = mask_flag 24 | self.output_attention = output_attention 25 | self.dropout = nn.Dropout(attention_dropout) 26 | 27 | def time_delay_agg_training(self, values, corr): 28 | """ 29 | SpeedUp version of Autocorrelation (a batch-normalization style design) 30 | This is for the training phase. 
31 | """ 32 | head = values.shape[1] 33 | channel = values.shape[2] 34 | length = values.shape[3] 35 | # find top k 36 | top_k = int(self.factor * math.log(length)) 37 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 38 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] 39 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) 40 | # update corr 41 | tmp_corr = torch.softmax(weights, dim=-1) 42 | # aggregation 43 | tmp_values = values 44 | delays_agg = torch.zeros_like(values).float() 45 | for i in range(top_k): 46 | pattern = torch.roll(tmp_values, -int(index[i]), -1) 47 | delays_agg = delays_agg + pattern * \ 48 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 49 | return delays_agg 50 | 51 | def time_delay_agg_inference(self, values, corr): 52 | """ 53 | SpeedUp version of Autocorrelation (a batch-normalization style design) 54 | This is for the inference phase. 55 | """ 56 | batch = values.shape[0] 57 | head = values.shape[1] 58 | channel = values.shape[2] 59 | length = values.shape[3] 60 | # index init 61 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 62 | # find top k 63 | top_k = int(self.factor * math.log(length)) 64 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 65 | weights, delay = torch.topk(mean_value, top_k, dim=-1) 66 | # update corr 67 | tmp_corr = torch.softmax(weights, dim=-1) 68 | # aggregation 69 | tmp_values = values.repeat(1, 1, 1, 2) 70 | delays_agg = torch.zeros_like(values).float() 71 | for i in range(top_k): 72 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 73 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 74 | delays_agg = delays_agg + pattern * \ 75 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 76 | return delays_agg 77 | 78 | def 
time_delay_agg_full(self, values, corr): 79 | """ 80 | Standard version of Autocorrelation 81 | """ 82 | batch = values.shape[0] 83 | head = values.shape[1] 84 | channel = values.shape[2] 85 | length = values.shape[3] 86 | # index init 87 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 88 | # find top k 89 | top_k = int(self.factor * math.log(length)) 90 | weights, delay = torch.topk(corr, top_k, dim=-1) 91 | # update corr 92 | tmp_corr = torch.softmax(weights, dim=-1) 93 | # aggregation 94 | tmp_values = values.repeat(1, 1, 1, 2) 95 | delays_agg = torch.zeros_like(values).float() 96 | for i in range(top_k): 97 | tmp_delay = init_index + delay[..., i].unsqueeze(-1) 98 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 99 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) 100 | return delays_agg 101 | 102 | def forward(self, queries, keys, values, attn_mask): 103 | B, L, H, E = queries.shape 104 | _, S, _, D = values.shape 105 | if L > S: 106 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float() 107 | values = torch.cat([values, zeros], dim=1) 108 | keys = torch.cat([keys, zeros], dim=1) 109 | else: 110 | values = values[:, :L, :, :] 111 | keys = keys[:, :L, :, :] 112 | 113 | # period-based dependencies 114 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) 115 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) 116 | res = q_fft * torch.conj(k_fft) 117 | corr = torch.fft.irfft(res, dim=-1) 118 | 119 | # time delay agg 120 | if self.training: 121 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 122 | else: 123 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 124 | 125 | if self.output_attention: 126 | return (V.contiguous(), corr.permute(0, 3, 1, 2)) 127 | else: 128 | return (V.contiguous(), None) 129 | 130 | 131 | 
class AutoCorrelationLayer(nn.Module): 132 | def __init__(self, correlation, d_model, n_heads, d_keys=None, 133 | d_values=None): 134 | super(AutoCorrelationLayer, self).__init__() 135 | 136 | d_keys = d_keys or (d_model // n_heads) 137 | d_values = d_values or (d_model // n_heads) 138 | 139 | self.inner_correlation = correlation 140 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 141 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 142 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 143 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 144 | self.n_heads = n_heads 145 | 146 | def forward(self, queries, keys, values, attn_mask): 147 | B, L, _ = queries.shape 148 | _, S, _ = keys.shape 149 | H = self.n_heads 150 | 151 | queries = self.query_projection(queries).view(B, L, H, -1) 152 | keys = self.key_projection(keys).view(B, S, H, -1) 153 | values = self.value_projection(values).view(B, S, H, -1) 154 | 155 | out, attn = self.inner_correlation( 156 | queries, 157 | keys, 158 | values, 159 | attn_mask 160 | ) 161 | out = out.view(B, L, -1) 162 | 163 | return self.out_projection(out), attn 164 | -------------------------------------------------------------------------------- /layers/Autoformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class my_Layernorm(nn.Module): 7 | """ 8 | Special designed layernorm for the seasonal part 9 | """ 10 | 11 | def __init__(self, channels): 12 | super(my_Layernorm, self).__init__() 13 | self.layernorm = nn.LayerNorm(channels) 14 | 15 | def forward(self, x): 16 | x_hat = self.layernorm(x) 17 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) 18 | return x_hat - bias 19 | 20 | 21 | class moving_avg(nn.Module): 22 | """ 23 | Moving average block to highlight the trend of time series 24 | """ 25 | 26 | def __init__(self, 
kernel_size, stride): 27 | super(moving_avg, self).__init__() 28 | self.kernel_size = kernel_size 29 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 30 | 31 | def forward(self, x): 32 | # padding on the both ends of time series 33 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 34 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 35 | x = torch.cat([front, x, end], dim=1) 36 | x = self.avg(x.permute(0, 2, 1)) 37 | x = x.permute(0, 2, 1) 38 | return x 39 | 40 | 41 | class series_decomp(nn.Module): 42 | """ 43 | Series decomposition block 44 | """ 45 | 46 | def __init__(self, kernel_size): 47 | super(series_decomp, self).__init__() 48 | self.moving_avg = moving_avg(kernel_size, stride=1) 49 | 50 | def forward(self, x): 51 | moving_mean = self.moving_avg(x) 52 | res = x - moving_mean 53 | return res, moving_mean 54 | 55 | 56 | class series_decomp_multi(nn.Module): 57 | """ 58 | Multiple Series decomposition block from FEDformer 59 | """ 60 | 61 | def __init__(self, kernel_size): 62 | super(series_decomp_multi, self).__init__() 63 | self.kernel_size = kernel_size 64 | self.series_decomp = [series_decomp(kernel) for kernel in kernel_size] 65 | 66 | def forward(self, x): 67 | moving_mean = [] 68 | res = [] 69 | for func in self.series_decomp: 70 | sea, moving_avg = func(x) 71 | moving_mean.append(moving_avg) 72 | res.append(sea) 73 | 74 | sea = sum(res) / len(res) 75 | moving_mean = sum(moving_mean) / len(moving_mean) 76 | return sea, moving_mean 77 | 78 | 79 | class EncoderLayer(nn.Module): 80 | """ 81 | Autoformer encoder layer with the progressive decomposition architecture 82 | """ 83 | 84 | def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): 85 | super(EncoderLayer, self).__init__() 86 | d_ff = d_ff or 4 * d_model 87 | self.attention = attention 88 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 89 | self.conv2 = 
nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 90 | self.decomp1 = series_decomp(moving_avg) 91 | self.decomp2 = series_decomp(moving_avg) 92 | self.dropout = nn.Dropout(dropout) 93 | self.activation = F.relu if activation == "relu" else F.gelu 94 | 95 | def forward(self, x, attn_mask=None): 96 | new_x, attn = self.attention( 97 | x, x, x, 98 | attn_mask=attn_mask 99 | ) 100 | x = x + self.dropout(new_x) 101 | x, _ = self.decomp1(x) 102 | y = x 103 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 104 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 105 | res, _ = self.decomp2(x + y) 106 | return res, attn 107 | 108 | 109 | class Encoder(nn.Module): 110 | """ 111 | Autoformer encoder 112 | """ 113 | 114 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 115 | super(Encoder, self).__init__() 116 | self.attn_layers = nn.ModuleList(attn_layers) 117 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 118 | self.norm = norm_layer 119 | 120 | def forward(self, x, attn_mask=None): 121 | attns = [] 122 | if self.conv_layers is not None: 123 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 124 | x, attn = attn_layer(x, attn_mask=attn_mask) 125 | x = conv_layer(x) 126 | attns.append(attn) 127 | x, attn = self.attn_layers[-1](x) 128 | attns.append(attn) 129 | else: 130 | for attn_layer in self.attn_layers: 131 | x, attn = attn_layer(x, attn_mask=attn_mask) 132 | attns.append(attn) 133 | 134 | if self.norm is not None: 135 | x = self.norm(x) 136 | 137 | return x, attns 138 | 139 | 140 | class DecoderLayer(nn.Module): 141 | """ 142 | Autoformer decoder layer with the progressive decomposition architecture 143 | """ 144 | 145 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None, 146 | moving_avg=25, dropout=0.1, activation="relu"): 147 | super(DecoderLayer, self).__init__() 148 | d_ff = d_ff or 4 * d_model 149 | 
self.self_attention = self_attention 150 | self.cross_attention = cross_attention 151 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 152 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 153 | self.decomp1 = series_decomp(moving_avg) 154 | self.decomp2 = series_decomp(moving_avg) 155 | self.decomp3 = series_decomp(moving_avg) 156 | self.dropout = nn.Dropout(dropout) 157 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1, 158 | padding_mode='circular', bias=False) 159 | self.activation = F.relu if activation == "relu" else F.gelu 160 | 161 | def forward(self, x, cross, x_mask=None, cross_mask=None): 162 | x = x + self.dropout(self.self_attention( 163 | x, x, x, 164 | attn_mask=x_mask 165 | )[0]) 166 | x, trend1 = self.decomp1(x) 167 | x = x + self.dropout(self.cross_attention( 168 | x, cross, cross, 169 | attn_mask=cross_mask 170 | )[0]) 171 | x, trend2 = self.decomp2(x) 172 | y = x 173 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 174 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 175 | x, trend3 = self.decomp3(x + y) 176 | 177 | residual_trend = trend1 + trend2 + trend3 178 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) 179 | return x, residual_trend 180 | 181 | 182 | class Decoder(nn.Module): 183 | """ 184 | Autoformer encoder 185 | """ 186 | 187 | def __init__(self, layers, norm_layer=None, projection=None): 188 | super(Decoder, self).__init__() 189 | self.layers = nn.ModuleList(layers) 190 | self.norm = norm_layer 191 | self.projection = projection 192 | 193 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): 194 | for layer in self.layers: 195 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 196 | trend = trend + residual_trend 197 | 198 | if self.norm is not None: 199 | x = self.norm(x) 200 | 201 | if 
self.projection is not None: 202 | x = self.projection(x) 203 | return x, trend 204 | -------------------------------------------------------------------------------- /layers/Conv_Blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Inception_Block_V1(nn.Module): 6 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): 7 | super(Inception_Block_V1, self).__init__() 8 | self.in_channels = in_channels 9 | self.out_channels = out_channels 10 | self.num_kernels = num_kernels 11 | kernels = [] 12 | for i in range(self.num_kernels): 13 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)) 14 | self.kernels = nn.ModuleList(kernels) 15 | if init_weight: 16 | self._initialize_weights() 17 | 18 | def _initialize_weights(self): 19 | for m in self.modules(): 20 | if isinstance(m, nn.Conv2d): 21 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 22 | if m.bias is not None: 23 | nn.init.constant_(m.bias, 0) 24 | 25 | def forward(self, x): 26 | res_list = [] 27 | for i in range(self.num_kernels): 28 | res_list.append(self.kernels[i](x)) 29 | res = torch.stack(res_list, dim=-1).mean(-1) 30 | return res 31 | 32 | 33 | class Inception_Block_V2(nn.Module): 34 | def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True): 35 | super(Inception_Block_V2, self).__init__() 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | self.num_kernels = num_kernels 39 | kernels = [] 40 | for i in range(self.num_kernels // 2): 41 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, 2 * i + 3], padding=[0, i + 1])) 42 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[2 * i + 3, 1], padding=[i + 1, 0])) 43 | kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) 44 | self.kernels = nn.ModuleList(kernels) 45 | if init_weight: 46 | 
self._initialize_weights() 47 | 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | 55 | def forward(self, x): 56 | res_list = [] 57 | for i in range(self.num_kernels + 1): 58 | res_list.append(self.kernels[i](x)) 59 | res = torch.stack(res_list, dim=-1).mean(-1) 60 | return res 61 | -------------------------------------------------------------------------------- /layers/Crossformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange, repeat 4 | from layers.SelfAttention_Family import TwoStageAttentionLayer 5 | 6 | 7 | class SegMerging(nn.Module): 8 | def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm): 9 | super().__init__() 10 | self.d_model = d_model 11 | self.win_size = win_size 12 | self.linear_trans = nn.Linear(win_size * d_model, d_model) 13 | self.norm = norm_layer(win_size * d_model) 14 | 15 | def forward(self, x): 16 | batch_size, ts_d, seg_num, d_model = x.shape 17 | pad_num = seg_num % self.win_size 18 | if pad_num != 0: 19 | pad_num = self.win_size - pad_num 20 | x = torch.cat((x, x[:, :, -pad_num:, :]), dim=-2) 21 | 22 | seg_to_merge = [] 23 | for i in range(self.win_size): 24 | seg_to_merge.append(x[:, :, i::self.win_size, :]) 25 | x = torch.cat(seg_to_merge, -1) 26 | 27 | x = self.norm(x) 28 | x = self.linear_trans(x) 29 | 30 | return x 31 | 32 | 33 | class scale_block(nn.Module): 34 | def __init__(self, configs, win_size, d_model, n_heads, d_ff, depth, dropout, \ 35 | seg_num=10, factor=10): 36 | super(scale_block, self).__init__() 37 | 38 | if win_size > 1: 39 | self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm) 40 | else: 41 | self.merge_layer = None 42 | 43 | self.encode_layers = nn.ModuleList() 44 | 45 | for i in 
range(depth): 46 | self.encode_layers.append(TwoStageAttentionLayer(configs, seg_num, factor, d_model, n_heads, \ 47 | d_ff, dropout)) 48 | 49 | def forward(self, x, attn_mask=None, tau=None, delta=None): 50 | _, ts_dim, _, _ = x.shape 51 | 52 | if self.merge_layer is not None: 53 | x = self.merge_layer(x) 54 | 55 | for layer in self.encode_layers: 56 | x = layer(x) 57 | 58 | return x, None 59 | 60 | 61 | class Encoder(nn.Module): 62 | def __init__(self, attn_layers): 63 | super(Encoder, self).__init__() 64 | self.encode_blocks = nn.ModuleList(attn_layers) 65 | 66 | def forward(self, x): 67 | encode_x = [] 68 | encode_x.append(x) 69 | 70 | for block in self.encode_blocks: 71 | x, attns = block(x) 72 | encode_x.append(x) 73 | 74 | return encode_x, None 75 | 76 | 77 | class DecoderLayer(nn.Module): 78 | def __init__(self, self_attention, cross_attention, seg_len, d_model, d_ff=None, dropout=0.1): 79 | super(DecoderLayer, self).__init__() 80 | self.self_attention = self_attention 81 | self.cross_attention = cross_attention 82 | self.norm1 = nn.LayerNorm(d_model) 83 | self.norm2 = nn.LayerNorm(d_model) 84 | self.dropout = nn.Dropout(dropout) 85 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model), 86 | nn.GELU(), 87 | nn.Linear(d_model, d_model)) 88 | self.linear_pred = nn.Linear(d_model, seg_len) 89 | 90 | def forward(self, x, cross): 91 | batch = x.shape[0] 92 | x = self.self_attention(x) 93 | x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model') 94 | 95 | cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model') 96 | tmp, attn = self.cross_attention(x, cross, cross, None, None, None,) 97 | x = x + self.dropout(tmp) 98 | y = x = self.norm1(x) 99 | y = self.MLP1(y) 100 | dec_output = self.norm2(x + y) 101 | 102 | dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b=batch) 103 | layer_predict = self.linear_pred(dec_output) 104 | layer_predict = rearrange(layer_predict, 
'b out_d seg_num seg_len -> b (out_d seg_num) seg_len') 105 | 106 | return dec_output, layer_predict 107 | 108 | 109 | class Decoder(nn.Module): 110 | def __init__(self, layers): 111 | super(Decoder, self).__init__() 112 | self.decode_layers = nn.ModuleList(layers) 113 | 114 | 115 | def forward(self, x, cross): 116 | final_predict = None 117 | i = 0 118 | 119 | ts_d = x.shape[1] 120 | for layer in self.decode_layers: 121 | cross_enc = cross[i] 122 | x, layer_predict = layer(x, cross_enc) 123 | if final_predict is None: 124 | final_predict = layer_predict 125 | else: 126 | final_predict = final_predict + layer_predict 127 | i += 1 128 | 129 | final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d=ts_d) 130 | 131 | return final_predict 132 | -------------------------------------------------------------------------------- /layers/ETSformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.fft as fft 5 | from einops import rearrange, reduce, repeat 6 | import math, random 7 | from scipy.fftpack import next_fast_len 8 | 9 | 10 | class Transform: 11 | def __init__(self, sigma): 12 | self.sigma = sigma 13 | 14 | @torch.no_grad() 15 | def transform(self, x): 16 | return self.jitter(self.shift(self.scale(x))) 17 | 18 | def jitter(self, x): 19 | return x + (torch.randn(x.shape).to(x.device) * self.sigma) 20 | 21 | def scale(self, x): 22 | return x * (torch.randn(x.size(-1)).to(x.device) * self.sigma + 1) 23 | 24 | def shift(self, x): 25 | return x + (torch.randn(x.size(-1)).to(x.device) * self.sigma) 26 | 27 | 28 | def conv1d_fft(f, g, dim=-1): 29 | N = f.size(dim) 30 | M = g.size(dim) 31 | 32 | fast_len = next_fast_len(N + M - 1) 33 | 34 | F_f = fft.rfft(f, fast_len, dim=dim) 35 | F_g = fft.rfft(g, fast_len, dim=dim) 36 | 37 | F_fg = F_f * F_g.conj() 38 | out = fft.irfft(F_fg, fast_len, 
dim=dim) 39 | out = out.roll((-1,), dims=(dim,)) 40 | idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device) 41 | out = out.index_select(dim, idx) 42 | 43 | return out 44 | 45 | 46 | class ExponentialSmoothing(nn.Module): 47 | 48 | def __init__(self, dim, nhead, dropout=0.1, aux=False): 49 | super().__init__() 50 | self._smoothing_weight = nn.Parameter(torch.randn(nhead, 1)) 51 | self.v0 = nn.Parameter(torch.randn(1, 1, nhead, dim)) 52 | self.dropout = nn.Dropout(dropout) 53 | if aux: 54 | self.aux_dropout = nn.Dropout(dropout) 55 | 56 | def forward(self, values, aux_values=None): 57 | b, t, h, d = values.shape 58 | 59 | init_weight, weight = self.get_exponential_weight(t) 60 | output = conv1d_fft(self.dropout(values), weight, dim=1) 61 | output = init_weight * self.v0 + output 62 | 63 | if aux_values is not None: 64 | aux_weight = weight / (1 - self.weight) * self.weight 65 | aux_output = conv1d_fft(self.aux_dropout(aux_values), aux_weight) 66 | output = output + aux_output 67 | 68 | return output 69 | 70 | def get_exponential_weight(self, T): 71 | # Generate array [0, 1, ..., T-1] 72 | powers = torch.arange(T, dtype=torch.float, device=self.weight.device) 73 | 74 | # (1 - \alpha) * \alpha^t, for all t = T-1, T-2, ..., 0] 75 | weight = (1 - self.weight) * (self.weight ** torch.flip(powers, dims=(0,))) 76 | 77 | # \alpha^t for all t = 1, 2, ..., T 78 | init_weight = self.weight ** (powers + 1) 79 | 80 | return rearrange(init_weight, 'h t -> 1 t h 1'), \ 81 | rearrange(weight, 'h t -> 1 t h 1') 82 | 83 | @property 84 | def weight(self): 85 | return torch.sigmoid(self._smoothing_weight) 86 | 87 | 88 | class Feedforward(nn.Module): 89 | def __init__(self, d_model, dim_feedforward, dropout=0.1, activation='sigmoid'): 90 | # Implementation of Feedforward model 91 | super().__init__() 92 | self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) 93 | self.dropout1 = nn.Dropout(dropout) 94 | self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) 
95 | self.dropout2 = nn.Dropout(dropout) 96 | self.activation = getattr(F, activation) 97 | 98 | def forward(self, x): 99 | x = self.linear2(self.dropout1(self.activation(self.linear1(x)))) 100 | return self.dropout2(x) 101 | 102 | 103 | class GrowthLayer(nn.Module): 104 | 105 | def __init__(self, d_model, nhead, d_head=None, dropout=0.1): 106 | super().__init__() 107 | self.d_head = d_head or (d_model // nhead) 108 | self.d_model = d_model 109 | self.nhead = nhead 110 | 111 | self.z0 = nn.Parameter(torch.randn(self.nhead, self.d_head)) 112 | self.in_proj = nn.Linear(self.d_model, self.d_head * self.nhead) 113 | self.es = ExponentialSmoothing(self.d_head, self.nhead, dropout=dropout) 114 | self.out_proj = nn.Linear(self.d_head * self.nhead, self.d_model) 115 | 116 | assert self.d_head * self.nhead == self.d_model, "d_model must be divisible by nhead" 117 | 118 | def forward(self, inputs): 119 | """ 120 | :param inputs: shape: (batch, seq_len, dim) 121 | :return: shape: (batch, seq_len, dim) 122 | """ 123 | b, t, d = inputs.shape 124 | values = self.in_proj(inputs).view(b, t, self.nhead, -1) 125 | values = torch.cat([repeat(self.z0, 'h d -> b 1 h d', b=b), values], dim=1) 126 | values = values[:, 1:] - values[:, :-1] 127 | out = self.es(values) 128 | out = torch.cat([repeat(self.es.v0, '1 1 h d -> b 1 h d', b=b), out], dim=1) 129 | out = rearrange(out, 'b t h d -> b t (h d)') 130 | return self.out_proj(out) 131 | 132 | 133 | class FourierLayer(nn.Module): 134 | 135 | def __init__(self, d_model, pred_len, k=None, low_freq=1): 136 | super().__init__() 137 | self.d_model = d_model 138 | self.pred_len = pred_len 139 | self.k = k 140 | self.low_freq = low_freq 141 | 142 | def forward(self, x): 143 | """x: (b, t, d)""" 144 | b, t, d = x.shape 145 | x_freq = fft.rfft(x, dim=1) 146 | 147 | if t % 2 == 0: 148 | x_freq = x_freq[:, self.low_freq:-1] 149 | f = fft.rfftfreq(t)[self.low_freq:-1] 150 | else: 151 | x_freq = x_freq[:, self.low_freq:] 152 | f = 
fft.rfftfreq(t)[self.low_freq:] 153 | 154 | x_freq, index_tuple = self.topk_freq(x_freq) 155 | f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2)) 156 | f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device) 157 | 158 | return self.extrapolate(x_freq, f, t) 159 | 160 | def extrapolate(self, x_freq, f, t): 161 | x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) 162 | f = torch.cat([f, -f], dim=1) 163 | t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float), 164 | 't -> () () t ()').to(x_freq.device) 165 | 166 | amp = rearrange(x_freq.abs() / t, 'b f d -> b f () d') 167 | phase = rearrange(x_freq.angle(), 'b f d -> b f () d') 168 | 169 | x_time = amp * torch.cos(2 * math.pi * f * t_val + phase) 170 | 171 | return reduce(x_time, 'b f t d -> b t d', 'sum') 172 | 173 | def topk_freq(self, x_freq): 174 | values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) 175 | mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2))) 176 | index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1)) 177 | x_freq = x_freq[index_tuple] 178 | 179 | return x_freq, index_tuple 180 | 181 | 182 | class LevelLayer(nn.Module): 183 | 184 | def __init__(self, d_model, c_out, dropout=0.1): 185 | super().__init__() 186 | self.d_model = d_model 187 | self.c_out = c_out 188 | 189 | self.es = ExponentialSmoothing(1, self.c_out, dropout=dropout, aux=True) 190 | self.growth_pred = nn.Linear(self.d_model, self.c_out) 191 | self.season_pred = nn.Linear(self.d_model, self.c_out) 192 | 193 | def forward(self, level, growth, season): 194 | b, t, _ = level.shape 195 | growth = self.growth_pred(growth).view(b, t, self.c_out, 1) 196 | season = self.season_pred(season).view(b, t, self.c_out, 1) 197 | growth = growth.view(b, t, self.c_out, 1) 198 | season = season.view(b, t, self.c_out, 1) 199 | level = level.view(b, t, self.c_out, 1) 200 | out = self.es(level - season, aux_values=growth) 201 | out = 
rearrange(out, 'b t h d -> b t (h d)') 202 | return out 203 | 204 | 205 | class EncoderLayer(nn.Module): 206 | 207 | def __init__(self, d_model, nhead, c_out, seq_len, pred_len, k, dim_feedforward=None, dropout=0.1, 208 | activation='sigmoid', layer_norm_eps=1e-5): 209 | super().__init__() 210 | self.d_model = d_model 211 | self.nhead = nhead 212 | self.c_out = c_out 213 | self.seq_len = seq_len 214 | self.pred_len = pred_len 215 | dim_feedforward = dim_feedforward or 4 * d_model 216 | self.dim_feedforward = dim_feedforward 217 | 218 | self.growth_layer = GrowthLayer(d_model, nhead, dropout=dropout) 219 | self.seasonal_layer = FourierLayer(d_model, pred_len, k=k) 220 | self.level_layer = LevelLayer(d_model, c_out, dropout=dropout) 221 | 222 | # Implementation of Feedforward model 223 | self.ff = Feedforward(d_model, dim_feedforward, dropout=dropout, activation=activation) 224 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 225 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 226 | 227 | self.dropout1 = nn.Dropout(dropout) 228 | self.dropout2 = nn.Dropout(dropout) 229 | 230 | def forward(self, res, level, attn_mask=None): 231 | season = self._season_block(res) 232 | res = res - season[:, :-self.pred_len] 233 | growth = self._growth_block(res) 234 | res = self.norm1(res - growth[:, 1:]) 235 | res = self.norm2(res + self.ff(res)) 236 | 237 | level = self.level_layer(level, growth[:, :-1], season[:, :-self.pred_len]) 238 | return res, level, growth, season 239 | 240 | def _growth_block(self, x): 241 | x = self.growth_layer(x) 242 | return self.dropout1(x) 243 | 244 | def _season_block(self, x): 245 | x = self.seasonal_layer(x) 246 | return self.dropout2(x) 247 | 248 | 249 | class Encoder(nn.Module): 250 | 251 | def __init__(self, layers): 252 | super().__init__() 253 | self.layers = nn.ModuleList(layers) 254 | 255 | def forward(self, res, level, attn_mask=None): 256 | growths = [] 257 | seasons = [] 258 | for layer in self.layers: 259 | res, level, 
growth, season = layer(res, level, attn_mask=None) 260 | growths.append(growth) 261 | seasons.append(season) 262 | 263 | return level, growths, seasons 264 | 265 | 266 | class DampingLayer(nn.Module): 267 | 268 | def __init__(self, pred_len, nhead, dropout=0.1): 269 | super().__init__() 270 | self.pred_len = pred_len 271 | self.nhead = nhead 272 | self._damping_factor = nn.Parameter(torch.randn(1, nhead)) 273 | self.dropout = nn.Dropout(dropout) 274 | 275 | def forward(self, x): 276 | x = repeat(x, 'b 1 d -> b t d', t=self.pred_len) 277 | b, t, d = x.shape 278 | 279 | powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1 280 | powers = powers.view(self.pred_len, 1) 281 | damping_factors = self.damping_factor ** powers 282 | damping_factors = damping_factors.cumsum(dim=0) 283 | x = x.view(b, t, self.nhead, -1) 284 | x = self.dropout(x) * damping_factors.unsqueeze(-1) 285 | return x.view(b, t, d) 286 | 287 | @property 288 | def damping_factor(self): 289 | return torch.sigmoid(self._damping_factor) 290 | 291 | 292 | class DecoderLayer(nn.Module): 293 | 294 | def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1): 295 | super().__init__() 296 | self.d_model = d_model 297 | self.nhead = nhead 298 | self.c_out = c_out 299 | self.pred_len = pred_len 300 | 301 | self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout) 302 | self.dropout1 = nn.Dropout(dropout) 303 | 304 | def forward(self, growth, season): 305 | growth_horizon = self.growth_damping(growth[:, -1:]) 306 | growth_horizon = self.dropout1(growth_horizon) 307 | 308 | seasonal_horizon = season[:, -self.pred_len:] 309 | return growth_horizon, seasonal_horizon 310 | 311 | 312 | class Decoder(nn.Module): 313 | 314 | def __init__(self, layers): 315 | super().__init__() 316 | self.d_model = layers[0].d_model 317 | self.c_out = layers[0].c_out 318 | self.pred_len = layers[0].pred_len 319 | self.nhead = layers[0].nhead 320 | 321 | self.layers = nn.ModuleList(layers) 322 | 
self.pred = nn.Linear(self.d_model, self.c_out) 323 | 324 | def forward(self, growths, seasons): 325 | growth_repr = [] 326 | season_repr = [] 327 | 328 | for idx, layer in enumerate(self.layers): 329 | growth_horizon, season_horizon = layer(growths[idx], seasons[idx]) 330 | growth_repr.append(growth_horizon) 331 | season_repr.append(season_horizon) 332 | growth_repr = sum(growth_repr) 333 | season_repr = sum(season_repr) 334 | return self.pred(growth_repr), self.pred(season_repr) 335 | -------------------------------------------------------------------------------- /layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm 5 | import math 6 | from einops.layers.torch import Rearrange 7 | 8 | 9 | class PositionalEmbedding(nn.Module): 10 | def __init__(self, d_model, max_len=5000): 11 | super(PositionalEmbedding, self).__init__() 12 | # Compute the positional encodings once in log space. 
13 | pe = torch.zeros(max_len, d_model).float() 14 | pe.require_grad = False 15 | 16 | position = torch.arange(0, max_len).float().unsqueeze(1) 17 | div_term = (torch.arange(0, d_model, 2).float() 18 | * -(math.log(10000.0) / d_model)).exp() 19 | 20 | pe[:, 0::2] = torch.sin(position * div_term) 21 | pe[:, 1::2] = torch.cos(position * div_term) 22 | 23 | pe = pe.unsqueeze(0) 24 | self.register_buffer('pe', pe) 25 | 26 | def forward(self, x): 27 | return self.pe[:, :x.size(1)] 28 | 29 | 30 | class TokenEmbedding(nn.Module): 31 | def __init__(self, c_in, d_model): 32 | super(TokenEmbedding, self).__init__() 33 | padding = 1 if torch.__version__ >= '1.5.0' else 2 34 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 35 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv1d): 38 | nn.init.kaiming_normal_( 39 | m.weight, mode='fan_in', nonlinearity='leaky_relu') 40 | 41 | def forward(self, x): 42 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 43 | return x 44 | 45 | 46 | class FixedEmbedding(nn.Module): 47 | def __init__(self, c_in, d_model): 48 | super(FixedEmbedding, self).__init__() 49 | 50 | w = torch.zeros(c_in, d_model).float() 51 | w.require_grad = False 52 | 53 | position = torch.arange(0, c_in).float().unsqueeze(1) 54 | div_term = (torch.arange(0, d_model, 2).float() 55 | * -(math.log(10000.0) / d_model)).exp() 56 | 57 | w[:, 0::2] = torch.sin(position * div_term) 58 | w[:, 1::2] = torch.cos(position * div_term) 59 | 60 | self.emb = nn.Embedding(c_in, d_model) 61 | self.emb.weight = nn.Parameter(w, requires_grad=False) 62 | 63 | def forward(self, x): 64 | return self.emb(x).detach() 65 | 66 | 67 | class TemporalEmbedding(nn.Module): 68 | def __init__(self, d_model, embed_type='fixed', freq='h'): 69 | super(TemporalEmbedding, self).__init__() 70 | 71 | minute_size = 4 72 | hour_size = 24 73 | weekday_size = 7 74 | day_size = 32 75 | month_size = 13 76 | 77 | 
Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding 78 | if freq == 't': 79 | self.minute_embed = Embed(minute_size, d_model) 80 | self.hour_embed = Embed(hour_size, d_model) 81 | self.weekday_embed = Embed(weekday_size, d_model) 82 | self.day_embed = Embed(day_size, d_model) 83 | self.month_embed = Embed(month_size, d_model) 84 | 85 | def forward(self, x): 86 | x = x.long() 87 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr( 88 | self, 'minute_embed') else 0. 89 | hour_x = self.hour_embed(x[:, :, 3]) 90 | weekday_x = self.weekday_embed(x[:, :, 2]) 91 | day_x = self.day_embed(x[:, :, 1]) 92 | month_x = self.month_embed(x[:, :, 0]) 93 | 94 | return hour_x + weekday_x + day_x + month_x + minute_x 95 | 96 | 97 | class TimeFeatureEmbedding(nn.Module): 98 | def __init__(self, d_model, embed_type='timeF', freq='h'): 99 | super(TimeFeatureEmbedding, self).__init__() 100 | 101 | freq_map = {'h': 4, 't': 5, 's': 6, 102 | 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 103 | d_inp = freq_map[freq] 104 | self.embed = nn.Linear(d_inp, d_model, bias=False) 105 | 106 | def forward(self, x): 107 | return self.embed(x) 108 | 109 | 110 | class DataEmbedding(nn.Module): 111 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 112 | super(DataEmbedding, self).__init__() 113 | 114 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 115 | self.position_embedding = PositionalEmbedding(d_model=d_model) 116 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 117 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 118 | d_model=d_model, embed_type=embed_type, freq=freq) 119 | self.dropout = nn.Dropout(p=dropout) 120 | 121 | def forward(self, x, x_mark): 122 | if x_mark is None: 123 | x = self.value_embedding(x) + self.position_embedding(x) 124 | else: 125 | # print(x.shape, x_mark.shape) 126 | # print(self.value_embedding) 127 | # print(self.temporal_embedding) 128 | # 
print(self.position_embedding) 129 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 130 | return self.dropout(x) 131 | 132 | 133 | class DataEmbedding_inverted(nn.Module): 134 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 135 | super(DataEmbedding_inverted, self).__init__() 136 | self.value_embedding = nn.Linear(c_in, d_model) 137 | self.dropout = nn.Dropout(p=dropout) 138 | 139 | def forward(self, x, x_mark): 140 | x = x.permute(0, 2, 1) 141 | # x: [Batch Variate Time] 142 | if x_mark is None: 143 | x = self.value_embedding(x) 144 | else: 145 | x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) 146 | # x: [Batch Variate d_model] 147 | return self.dropout(x) 148 | 149 | 150 | class DataEmbedding_wo_pos(nn.Module): 151 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 152 | super(DataEmbedding_wo_pos, self).__init__() 153 | 154 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 155 | self.position_embedding = PositionalEmbedding(d_model=d_model) 156 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 157 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 158 | d_model=d_model, embed_type=embed_type, freq=freq) 159 | self.dropout = nn.Dropout(p=dropout) 160 | 161 | def forward(self, x, x_mark): 162 | if x_mark is None: 163 | x = self.value_embedding(x) 164 | else: 165 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 166 | return self.dropout(x) 167 | 168 | 169 | class PatchEmbedding(nn.Module): 170 | def __init__(self, d_model, patch_len, stride, padding, dropout): 171 | super(PatchEmbedding, self).__init__() 172 | # Patching 173 | self.patch_len = patch_len 174 | self.stride = stride 175 | self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) 176 | 177 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space 178 | 
self.value_embedding = nn.Linear(patch_len, d_model, bias=False) 179 | 180 | # Positional embedding 181 | self.position_embedding = PositionalEmbedding(d_model) 182 | 183 | # Residual dropout 184 | self.dropout = nn.Dropout(dropout) 185 | 186 | def forward(self, x): 187 | # do patching 188 | n_vars = x.shape[1] 189 | x = self.padding_patch_layer(x) 190 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 191 | x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) 192 | # Input encoding 193 | x = self.value_embedding(x) + self.position_embedding(x) 194 | return self.dropout(x), n_vars 195 | 196 | 197 | class BlendEmbedding(nn.Module): 198 | def __init__(self, d_model, patch_len, n_vars, stride, padding, dropout): 199 | super(BlendEmbedding, self).__init__() 200 | # Patching 201 | self.patch_len = patch_len 202 | self.stride = stride 203 | self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) 204 | 205 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space 206 | self.value_embedding = nn.Linear(patch_len * n_vars, d_model, bias=False) 207 | 208 | # Positional embedding 209 | self.position_embedding = PositionalEmbedding(d_model) 210 | 211 | # Residual dropout 212 | self.dropout = nn.Dropout(dropout) 213 | 214 | def forward(self, x, y=None): 215 | # do patching 216 | n_vars = x.shape[1] 217 | 218 | if y is not None: 219 | x = torch.cat([x, y], dim=1) 220 | 221 | x = self.padding_patch_layer(x) 222 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 223 | x = torch.reshape(x, (x.shape[0], x.shape[2], x.shape[3] * x.shape[1])) 224 | 225 | # Input encoding 226 | x = self.value_embedding(x) + self.position_embedding(x) 227 | 228 | # if y is not None: 229 | # y = self.padding_patch_layer(y) 230 | # y = y.unfold(dimension=-1, size=self.patch_len, step=self.stride) 231 | # y = torch.reshape(y, (y.shape[0], y.shape[2], y.shape[3] * y.shape[1])) 232 | # return self.dropout(x), y, n_vars 233 
| # else: 234 | return self.dropout(x), n_vars 235 | 236 | 237 | class BlendDSCEmbedding(nn.Module): 238 | def __init__(self, d_model, patch_len, n_vars, stride, padding, dropout): 239 | super(BlendDSCEmbedding, self).__init__() 240 | # Patching 241 | self.patch_len = patch_len 242 | self.stride = stride 243 | self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) 244 | 245 | # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space 246 | nc_reduction = int(0.25 * n_vars) 247 | # if expand_radio == 1: 248 | self.pre_embedding = nn.Sequential( 249 | nn.Conv1d(in_channels=n_vars, out_channels=n_vars, kernel_size=7, padding=3, groups=n_vars, ), 250 | nn.LeakyReLU(True), 251 | nn.Conv1d(in_channels=n_vars, out_channels=nc_reduction, kernel_size=1, padding=0, ), 252 | # Rearrange('b m l k -> b l (k m)'), 253 | # nn.Linear(patch_len * n_vars, d_model, bias=False) 254 | ) 255 | # else: 256 | # self.pre_embedding = nn.Sequential( 257 | # nn.Conv1d(in_channels=n_vars, out_channels=n_vars * expand_radio, kernel_size=1, padding=0, bias=False), 258 | # # nn.ReLU6(True), 259 | # nn.Conv1d(in_channels=n_vars * expand_radio, out_channels=n_vars * expand_radio, 260 | # kernel_size=7, padding=3, groups=n_vars * expand_radio), 261 | # nn.Conv1d(in_channels=n_vars * expand_radio, out_channels=n_vars, kernel_size=1, padding=0, bias=False), 262 | # Rearrange('b m l k -> b l (k m)'), 263 | # nn.Linear(patch_len * n_vars, d_model, bias=False) 264 | # ) 265 | 266 | self.value_embedding = nn.Sequential( 267 | Rearrange('b m l k -> b l (k m)'), 268 | nn.Linear(patch_len * nc_reduction, d_model, bias=False) 269 | ) 270 | 271 | # Positional embedding 272 | self.position_embedding = PositionalEmbedding(d_model) 273 | 274 | # Residual dropout 275 | self.dropout = nn.Dropout(dropout) 276 | 277 | def forward(self, x): 278 | # x: b, n_vars, seq_len 279 | # do patching 280 | n_vars = x.shape[1] 281 | x = self.padding_patch_layer(x) 282 | # print(x.shape, 
self.pre_embedding(x).shape) # torch.Size([128, 21, 736]) torch.Size([128, 5, 736]) 283 | x = self.pre_embedding(x) 284 | 285 | # B, n_vars, num_patch, patch_len 286 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 287 | 288 | # print(x.shape) 289 | # Input encoding 290 | x = self.value_embedding(x) 291 | 292 | x = x + self.position_embedding(x) 293 | 294 | return self.dropout(x), n_vars 295 | -------------------------------------------------------------------------------- /layers/FourierCorrelation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # author=maziqing 3 | # email=maziqing.mzq@alibaba-inc.com 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def get_frequency_modes(seq_len, modes=64, mode_select_method='random'): 11 | """ 12 | get modes on frequency domain: 13 | 'random' means sampling randomly; 14 | 'else' means sampling the lowest modes; 15 | """ 16 | modes = min(modes, seq_len // 2) 17 | if mode_select_method == 'random': 18 | index = list(range(0, seq_len // 2)) 19 | np.random.shuffle(index) 20 | index = index[:modes] 21 | else: 22 | index = list(range(0, modes)) 23 | index.sort() 24 | return index 25 | 26 | 27 | # ########## fourier layer ############# 28 | class FourierBlock(nn.Module): 29 | def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'): 30 | super(FourierBlock, self).__init__() 31 | print('fourier enhanced block used!') 32 | """ 33 | 1D Fourier block. It performs representation learning on frequency domain, 34 | it does FFT, linear transform, and Inverse FFT. 
35 | """ 36 | # get modes on frequency domain 37 | self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method) 38 | print('modes={}, index={}'.format(modes, self.index)) 39 | 40 | self.scale = (1 / (in_channels * out_channels)) 41 | self.weights1 = nn.Parameter( 42 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) 43 | self.weights2 = nn.Parameter( 44 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float)) 45 | 46 | # Complex multiplication 47 | def compl_mul1d(self, order, x, weights): 48 | x_flag = True 49 | w_flag = True 50 | if not torch.is_complex(x): 51 | x_flag = False 52 | x = torch.complex(x, torch.zeros_like(x).to(x.device)) 53 | if not torch.is_complex(weights): 54 | w_flag = False 55 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) 56 | if x_flag or w_flag: 57 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), 58 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) 59 | else: 60 | return torch.einsum(order, x.real, weights.real) 61 | 62 | def forward(self, q, k, v, mask): 63 | # size = [B, L, H, E] 64 | B, L, H, E = q.shape 65 | x = q.permute(0, 2, 3, 1) 66 | # Compute Fourier coefficients 67 | x_ft = torch.fft.rfft(x, dim=-1) 68 | # Perform Fourier neural operations 69 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) 70 | for wi, i in enumerate(self.index): 71 | if i >= x_ft.shape[3] or wi >= out_ft.shape[3]: 72 | continue 73 | out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i], 74 | torch.complex(self.weights1, self.weights2)[:, :, :, wi]) 75 | # Return to time domain 76 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 77 | return (x, None) 78 | 79 | 80 | # ########## Fourier Cross Former #################### 81 | class 
FourierCrossAttention(nn.Module): 82 | def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random', 83 | activation='tanh', policy=0, num_heads=8): 84 | super(FourierCrossAttention, self).__init__() 85 | print(' fourier enhanced cross attention used!') 86 | """ 87 | 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT. 88 | """ 89 | self.activation = activation 90 | self.in_channels = in_channels 91 | self.out_channels = out_channels 92 | # get modes for queries and keys (& values) on frequency domain 93 | self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method) 94 | self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method) 95 | 96 | print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q)) 97 | print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv)) 98 | 99 | self.scale = (1 / (in_channels * out_channels)) 100 | self.weights1 = nn.Parameter( 101 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) 102 | self.weights2 = nn.Parameter( 103 | self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float)) 104 | 105 | # Complex multiplication 106 | def compl_mul1d(self, order, x, weights): 107 | x_flag = True 108 | w_flag = True 109 | if not torch.is_complex(x): 110 | x_flag = False 111 | x = torch.complex(x, torch.zeros_like(x).to(x.device)) 112 | if not torch.is_complex(weights): 113 | w_flag = False 114 | weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device)) 115 | if x_flag or w_flag: 116 | return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag), 117 | torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real)) 118 | 
else: 119 | return torch.einsum(order, x.real, weights.real) 120 | 121 | def forward(self, q, k, v, mask): 122 | # size = [B, L, H, E] 123 | B, L, H, E = q.shape 124 | xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] 125 | xk = k.permute(0, 2, 3, 1) 126 | xv = v.permute(0, 2, 3, 1) 127 | 128 | # Compute Fourier coefficients 129 | xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) 130 | xq_ft = torch.fft.rfft(xq, dim=-1) 131 | for i, j in enumerate(self.index_q): 132 | if j >= xq_ft.shape[3]: 133 | continue 134 | xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] 135 | xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat) 136 | xk_ft = torch.fft.rfft(xk, dim=-1) 137 | for i, j in enumerate(self.index_kv): 138 | if j >= xk_ft.shape[3]: 139 | continue 140 | xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] 141 | 142 | # perform attention mechanism on frequency domain 143 | xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_)) 144 | if self.activation == 'tanh': 145 | xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh()) 146 | elif self.activation == 'softmax': 147 | xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) 148 | xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) 149 | else: 150 | raise Exception('{} actiation function is not implemented'.format(self.activation)) 151 | xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_) 152 | xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2)) 153 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) 154 | for i, j in enumerate(self.index_q): 155 | if i >= xqkvw.shape[3] or j >= out_ft.shape[3]: 156 | continue 157 | out_ft[:, :, :, j] = xqkvw[:, :, :, i] 158 | # Return to time domain 159 | out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) 160 | return (out, None) 161 | 
-------------------------------------------------------------------------------- /layers/Pyraformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.modules.linear import Linear 5 | from layers.SelfAttention_Family import AttentionLayer, FullAttention 6 | from layers.Embed import DataEmbedding 7 | import math 8 | 9 | 10 | def get_mask(input_size, window_size, inner_size): 11 | """Get the attention mask of PAM-Naive""" 12 | # Get the size of all layers 13 | all_size = [] 14 | all_size.append(input_size) 15 | for i in range(len(window_size)): 16 | layer_size = math.floor(all_size[i] / window_size[i]) 17 | all_size.append(layer_size) 18 | 19 | seq_length = sum(all_size) 20 | mask = torch.zeros(seq_length, seq_length) 21 | 22 | # get intra-scale mask 23 | inner_window = inner_size // 2 24 | for layer_idx in range(len(all_size)): 25 | start = sum(all_size[:layer_idx]) 26 | for i in range(start, start + all_size[layer_idx]): 27 | left_side = max(i - inner_window, start) 28 | right_side = min(i + inner_window + 1, start + all_size[layer_idx]) 29 | mask[i, left_side:right_side] = 1 30 | 31 | # get inter-scale mask 32 | for layer_idx in range(1, len(all_size)): 33 | start = sum(all_size[:layer_idx]) 34 | for i in range(start, start + all_size[layer_idx]): 35 | left_side = (start - all_size[layer_idx - 1]) + \ 36 | (i - start) * window_size[layer_idx - 1] 37 | if i == (start + all_size[layer_idx] - 1): 38 | right_side = start 39 | else: 40 | right_side = ( 41 | start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1] 42 | mask[i, left_side:right_side] = 1 43 | mask[left_side:right_side, i] = 1 44 | 45 | mask = (1 - mask).bool() 46 | 47 | return mask, all_size 48 | 49 | 50 | def refer_points(all_sizes, window_size): 51 | """Gather features from PAM's pyramid sequences""" 52 | input_size = all_sizes[0] 53 | indexes = 
torch.zeros(input_size, len(all_sizes)) 54 | 55 | for i in range(input_size): 56 | indexes[i][0] = i 57 | former_index = i 58 | for j in range(1, len(all_sizes)): 59 | start = sum(all_sizes[:j]) 60 | inner_layer_idx = former_index - (start - all_sizes[j - 1]) 61 | former_index = start + \ 62 | min(inner_layer_idx // window_size[j - 1], all_sizes[j] - 1) 63 | indexes[i][j] = former_index 64 | 65 | indexes = indexes.unsqueeze(0).unsqueeze(3) 66 | 67 | return indexes.long() 68 | 69 | 70 | class RegularMask(): 71 | def __init__(self, mask): 72 | self._mask = mask.unsqueeze(1) 73 | 74 | @property 75 | def mask(self): 76 | return self._mask 77 | 78 | 79 | class EncoderLayer(nn.Module): 80 | """ Compose with two layers """ 81 | 82 | def __init__(self, d_model, d_inner, n_head, dropout=0.1, normalize_before=True): 83 | super(EncoderLayer, self).__init__() 84 | 85 | self.slf_attn = AttentionLayer( 86 | FullAttention(mask_flag=True, factor=0, 87 | attention_dropout=dropout, output_attention=False), 88 | d_model, n_head) 89 | self.pos_ffn = PositionwiseFeedForward( 90 | d_model, d_inner, dropout=dropout, normalize_before=normalize_before) 91 | 92 | def forward(self, enc_input, slf_attn_mask=None): 93 | attn_mask = RegularMask(slf_attn_mask) 94 | enc_output, _ = self.slf_attn( 95 | enc_input, enc_input, enc_input, attn_mask=attn_mask) 96 | enc_output = self.pos_ffn(enc_output) 97 | return enc_output 98 | 99 | 100 | class Encoder(nn.Module): 101 | """ A encoder model with self attention mechanism. 
""" 102 | 103 | def __init__(self, configs, window_size, inner_size): 104 | super().__init__() 105 | 106 | d_bottleneck = configs.d_model//4 107 | 108 | self.mask, self.all_size = get_mask( 109 | configs.seq_len, window_size, inner_size) 110 | self.indexes = refer_points(self.all_size, window_size) 111 | self.layers = nn.ModuleList([ 112 | EncoderLayer(configs.d_model, configs.d_ff, configs.n_heads, dropout=configs.dropout, 113 | normalize_before=False) for _ in range(configs.e_layers) 114 | ]) # naive pyramid attention 115 | 116 | self.enc_embedding = DataEmbedding( 117 | configs.enc_in, configs.d_model, configs.dropout) 118 | self.conv_layers = Bottleneck_Construct( 119 | configs.d_model, window_size, d_bottleneck) 120 | 121 | def forward(self, x_enc, x_mark_enc): 122 | seq_enc = self.enc_embedding(x_enc, x_mark_enc) 123 | 124 | mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device) 125 | seq_enc = self.conv_layers(seq_enc) 126 | 127 | for i in range(len(self.layers)): 128 | seq_enc = self.layers[i](seq_enc, mask) 129 | 130 | indexes = self.indexes.repeat(seq_enc.size( 131 | 0), 1, 1, seq_enc.size(2)).to(seq_enc.device) 132 | indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2)) 133 | all_enc = torch.gather(seq_enc, 1, indexes) 134 | seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1) 135 | 136 | return seq_enc 137 | 138 | 139 | class ConvLayer(nn.Module): 140 | def __init__(self, c_in, window_size): 141 | super(ConvLayer, self).__init__() 142 | self.downConv = nn.Conv1d(in_channels=c_in, 143 | out_channels=c_in, 144 | kernel_size=window_size, 145 | stride=window_size) 146 | self.norm = nn.BatchNorm1d(c_in) 147 | self.activation = nn.ELU() 148 | 149 | def forward(self, x): 150 | x = self.downConv(x) 151 | x = self.norm(x) 152 | x = self.activation(x) 153 | return x 154 | 155 | 156 | class Bottleneck_Construct(nn.Module): 157 | """Bottleneck convolution CSCM""" 158 | 159 | def __init__(self, d_model, window_size, d_inner): 160 | 
super(Bottleneck_Construct, self).__init__() 161 | if not isinstance(window_size, list): 162 | self.conv_layers = nn.ModuleList([ 163 | ConvLayer(d_inner, window_size), 164 | ConvLayer(d_inner, window_size), 165 | ConvLayer(d_inner, window_size) 166 | ]) 167 | else: 168 | self.conv_layers = [] 169 | for i in range(len(window_size)): 170 | self.conv_layers.append(ConvLayer(d_inner, window_size[i])) 171 | self.conv_layers = nn.ModuleList(self.conv_layers) 172 | self.up = Linear(d_inner, d_model) 173 | self.down = Linear(d_model, d_inner) 174 | self.norm = nn.LayerNorm(d_model) 175 | 176 | def forward(self, enc_input): 177 | temp_input = self.down(enc_input).permute(0, 2, 1) 178 | all_inputs = [] 179 | for i in range(len(self.conv_layers)): 180 | temp_input = self.conv_layers[i](temp_input) 181 | all_inputs.append(temp_input) 182 | 183 | all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2) 184 | all_inputs = self.up(all_inputs) 185 | all_inputs = torch.cat([enc_input, all_inputs], dim=1) 186 | 187 | all_inputs = self.norm(all_inputs) 188 | return all_inputs 189 | 190 | 191 | class PositionwiseFeedForward(nn.Module): 192 | """ Two-layer position-wise feed-forward neural network. 
""" 193 | 194 | def __init__(self, d_in, d_hid, dropout=0.1, normalize_before=True): 195 | super().__init__() 196 | 197 | self.normalize_before = normalize_before 198 | 199 | self.w_1 = nn.Linear(d_in, d_hid) 200 | self.w_2 = nn.Linear(d_hid, d_in) 201 | 202 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 203 | self.dropout = nn.Dropout(dropout) 204 | 205 | def forward(self, x): 206 | residual = x 207 | if self.normalize_before: 208 | x = self.layer_norm(x) 209 | 210 | x = F.gelu(self.w_1(x)) 211 | x = self.dropout(x) 212 | x = self.w_2(x) 213 | x = self.dropout(x) 214 | x = x + residual 215 | 216 | if not self.normalize_before: 217 | x = self.layer_norm(x) 218 | return x 219 | -------------------------------------------------------------------------------- /layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, c_in): 8 | super(ConvLayer, self).__init__() 9 | self.downConv = nn.Conv1d(in_channels=c_in, 10 | out_channels=c_in, 11 | kernel_size=3, 12 | padding=2, 13 | padding_mode='circular') 14 | self.norm = nn.BatchNorm1d(c_in) 15 | self.activation = nn.ELU() 16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 17 | 18 | def forward(self, x): 19 | x = self.downConv(x.permute(0, 2, 1)) 20 | x = self.norm(x) 21 | x = self.activation(x) 22 | x = self.maxPool(x) 23 | x = x.transpose(1, 2) 24 | return x 25 | 26 | 27 | class EncoderLayer(nn.Module): 28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 29 | super(EncoderLayer, self).__init__() 30 | d_ff = d_ff or 4 * d_model 31 | self.attention = attention 32 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 33 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 34 | self.norm1 = nn.LayerNorm(d_model) 35 | self.norm2 = 
nn.LayerNorm(d_model) 36 | self.dropout = nn.Dropout(dropout) 37 | self.activation = F.relu if activation == "relu" else F.gelu 38 | 39 | def forward(self, x, attn_mask=None, tau=None, delta=None): 40 | new_x, attn = self.attention( 41 | x, x, x, 42 | attn_mask=attn_mask, 43 | tau=tau, delta=delta 44 | ) 45 | x = x + self.dropout(new_x) 46 | 47 | y = x = self.norm1(x) 48 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 49 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 50 | 51 | return self.norm2(x + y), attn 52 | 53 | 54 | class EncoderLayerWoAttn(nn.Module): 55 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 56 | super(EncoderLayerWoAttn, self).__init__() 57 | d_ff = d_ff or 4 * d_model 58 | self.attention = attention 59 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 60 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 61 | self.norm1 = nn.LayerNorm(d_model) 62 | self.norm2 = nn.LayerNorm(d_model) 63 | self.dropout = nn.Dropout(dropout) 64 | self.activation = F.relu if activation == "relu" else F.gelu 65 | 66 | def forward(self, x, attn_mask=None, tau=None, delta=None): 67 | new_x, attn = self.attention( 68 | x, x, x, 69 | attn_mask=attn_mask, 70 | tau=tau, delta=delta 71 | ) 72 | x = x + self.dropout(new_x) 73 | 74 | y = x = self.norm1(x) 75 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 76 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 77 | 78 | return self.norm2(x + y) 79 | 80 | 81 | class Encoder(nn.Module): 82 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 83 | super(Encoder, self).__init__() 84 | self.attn_layers = nn.ModuleList(attn_layers) 85 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 86 | self.norm = norm_layer 87 | 88 | def forward(self, x, attn_mask=None, tau=None, delta=None): 89 | # x [B, L, D] 90 | attns = [] 91 | if self.conv_layers is not 
None: 92 | for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): 93 | delta = delta if i == 0 else None 94 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 95 | x = conv_layer(x) 96 | attns.append(attn) 97 | x, attn = self.attn_layers[-1](x, tau=tau, delta=None) 98 | attns.append(attn) 99 | else: 100 | for attn_layer in self.attn_layers: 101 | x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) 102 | attns.append(attn) 103 | 104 | if self.norm is not None: 105 | x = self.norm(x) 106 | 107 | return x, attns 108 | 109 | 110 | class DecoderLayer(nn.Module): 111 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 112 | dropout=0.1, activation="relu"): 113 | super(DecoderLayer, self).__init__() 114 | d_ff = d_ff or 4 * d_model 115 | self.self_attention = self_attention 116 | self.cross_attention = cross_attention 117 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 118 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 119 | self.norm1 = nn.LayerNorm(d_model) 120 | self.norm2 = nn.LayerNorm(d_model) 121 | self.norm3 = nn.LayerNorm(d_model) 122 | self.dropout = nn.Dropout(dropout) 123 | self.activation = F.relu if activation == "relu" else F.gelu 124 | 125 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 126 | x = x + self.dropout(self.self_attention( 127 | x, x, x, 128 | attn_mask=x_mask, 129 | tau=tau, delta=None 130 | )[0]) 131 | x = self.norm1(x) 132 | 133 | x = x + self.dropout(self.cross_attention( 134 | x, cross, cross, 135 | attn_mask=cross_mask, 136 | tau=tau, delta=delta 137 | )[0]) 138 | 139 | y = x = self.norm2(x) 140 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 141 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 142 | 143 | return self.norm3(x + y) 144 | 145 | 146 | class Decoder(nn.Module): 147 | def __init__(self, layers, norm_layer=None, 
projection=None): 148 | super(Decoder, self).__init__() 149 | self.layers = nn.ModuleList(layers) 150 | self.norm = norm_layer 151 | self.projection = projection 152 | 153 | def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): 154 | for layer in self.layers: 155 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta) 156 | 157 | if self.norm is not None: 158 | x = self.norm(x) 159 | 160 | if self.projection is not None: 161 | x = self.projection(x) 162 | return x 163 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiudingCai/MambaTS-pytorch/3b6797d9bf5178a490e1bb3c3e9f4d223f2af51d/layers/__init__.py -------------------------------------------------------------------------------- /layers/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiudingCai/MambaTS-pytorch/3b6797d9bf5178a490e1bb3c3e9f4d223f2af51d/layers/mamba_ssm/__init__.py -------------------------------------------------------------------------------- /models/Autoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Embed import DataEmbedding, DataEmbedding_wo_pos 5 | from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer 6 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp 7 | import math 8 | import numpy as np 9 | 10 | 11 | class Model(nn.Module): 12 | """ 13 | Autoformer is the first method to achieve the series-wise connection, 14 | with inherent O(LlogL) complexity 15 | Paper link: https://openreview.net/pdf?id=I55UqU-M11y 16 | """ 17 | 18 | def __init__(self, configs): 19 | 
super(Model, self).__init__() 20 | self.task_name = configs.task_name 21 | self.seq_len = configs.seq_len 22 | self.label_len = configs.label_len 23 | self.pred_len = configs.pred_len 24 | self.output_attention = configs.output_attention 25 | 26 | # Decomp 27 | kernel_size = configs.moving_avg 28 | self.decomp = series_decomp(kernel_size) 29 | 30 | # Embedding 31 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 32 | configs.dropout) 33 | # Encoder 34 | self.encoder = Encoder( 35 | [ 36 | EncoderLayer( 37 | AutoCorrelationLayer( 38 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 39 | output_attention=configs.output_attention), 40 | configs.d_model, configs.n_heads), 41 | configs.d_model, 42 | configs.d_ff, 43 | moving_avg=configs.moving_avg, 44 | dropout=configs.dropout, 45 | activation=configs.activation 46 | ) for l in range(configs.e_layers) 47 | ], 48 | norm_layer=my_Layernorm(configs.d_model) 49 | ) 50 | # Decoder 51 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 52 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 53 | configs.dropout) 54 | self.decoder = Decoder( 55 | [ 56 | DecoderLayer( 57 | AutoCorrelationLayer( 58 | AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout, 59 | output_attention=False), 60 | configs.d_model, configs.n_heads), 61 | AutoCorrelationLayer( 62 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 63 | output_attention=False), 64 | configs.d_model, configs.n_heads), 65 | configs.d_model, 66 | configs.c_out, 67 | configs.d_ff, 68 | moving_avg=configs.moving_avg, 69 | dropout=configs.dropout, 70 | activation=configs.activation, 71 | ) 72 | for l in range(configs.d_layers) 73 | ], 74 | norm_layer=my_Layernorm(configs.d_model), 75 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 76 | ) 77 | if 
self.task_name == 'imputation': 78 | self.projection = nn.Linear( 79 | configs.d_model, configs.c_out, bias=True) 80 | if self.task_name == 'anomaly_detection': 81 | self.projection = nn.Linear( 82 | configs.d_model, configs.c_out, bias=True) 83 | if self.task_name == 'classification': 84 | self.act = F.gelu 85 | self.dropout = nn.Dropout(configs.dropout) 86 | self.projection = nn.Linear( 87 | configs.d_model * configs.seq_len, configs.num_class) 88 | 89 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 90 | if x_mark_enc.shape[-1] != 4: 91 | x_mark_enc = None 92 | if x_mark_dec.shape[-1] != 4: 93 | x_mark_dec = None 94 | 95 | # decomp init 96 | mean = torch.mean(x_enc, dim=1).unsqueeze( 97 | 1).repeat(1, self.pred_len, 1) 98 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, 99 | x_dec.shape[2]], device=x_enc.device) 100 | seasonal_init, trend_init = self.decomp(x_enc) 101 | # decoder input 102 | trend_init = torch.cat( 103 | [trend_init[:, -self.label_len:, :], mean], dim=1) 104 | seasonal_init = torch.cat( 105 | [seasonal_init[:, -self.label_len:, :], zeros], dim=1) 106 | # enc 107 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 108 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 109 | # dec 110 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 111 | seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, 112 | trend=trend_init) 113 | # final 114 | dec_out = trend_part + seasonal_part 115 | return dec_out 116 | 117 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 118 | # enc 119 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 120 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 121 | # final 122 | dec_out = self.projection(enc_out) 123 | return dec_out 124 | 125 | def anomaly_detection(self, x_enc): 126 | # enc 127 | enc_out = self.enc_embedding(x_enc, None) 128 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 129 | # final 130 | dec_out = 
self.projection(enc_out) 131 | return dec_out 132 | 133 | def classification(self, x_enc, x_mark_enc): 134 | # enc 135 | enc_out = self.enc_embedding(x_enc, None) 136 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 137 | 138 | # Output 139 | # the output transformer encoder/decoder embeddings don't include non-linearity 140 | output = self.act(enc_out) 141 | output = self.dropout(output) 142 | # zero-out padding embeddings 143 | output = output * x_mark_enc.unsqueeze(-1) 144 | # (batch_size, seq_length * d_model) 145 | output = output.reshape(output.shape[0], -1) 146 | output = self.projection(output) # (batch_size, num_classes) 147 | return output 148 | 149 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 150 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 151 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 152 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 153 | if self.task_name == 'imputation': 154 | dec_out = self.imputation( 155 | x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 156 | return dec_out # [B, L, D] 157 | if self.task_name == 'anomaly_detection': 158 | dec_out = self.anomaly_detection(x_enc) 159 | return dec_out # [B, L, D] 160 | if self.task_name == 'classification': 161 | dec_out = self.classification(x_enc, x_mark_enc) 162 | return dec_out # [B, N] 163 | return None 164 | -------------------------------------------------------------------------------- /models/Crossformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from layers.Crossformer_EncDec import scale_block, Encoder, Decoder, DecoderLayer 6 | from layers.Embed import PatchEmbedding 7 | from layers.SelfAttention_Family import AttentionLayer, FullAttention, TwoStageAttentionLayer 8 | from models.PatchTST import FlattenHead 9 | 10 | from math 
import ceil 11 | 12 | 13 | class Model(nn.Module): 14 | """ 15 | Paper link: https://openreview.net/pdf?id=vSVLM2j9eie 16 | """ 17 | 18 | def __init__(self, configs): 19 | super(Model, self).__init__() 20 | self.enc_in = configs.enc_in 21 | self.seq_len = configs.seq_len 22 | self.pred_len = configs.pred_len 23 | self.seg_len = 12 24 | self.win_size = 2 25 | self.task_name = configs.task_name 26 | 27 | # The padding operation to handle invisible sgemnet length 28 | self.pad_in_len = ceil(1.0 * configs.seq_len / self.seg_len) * self.seg_len 29 | self.pad_out_len = ceil(1.0 * configs.pred_len / self.seg_len) * self.seg_len 30 | self.in_seg_num = self.pad_in_len // self.seg_len 31 | self.out_seg_num = ceil(self.in_seg_num / (self.win_size ** (configs.e_layers - 1))) 32 | self.head_nf = configs.d_model * self.out_seg_num 33 | 34 | # Embedding 35 | self.enc_value_embedding = PatchEmbedding(configs.d_model, self.seg_len, self.seg_len, 36 | self.pad_in_len - configs.seq_len, 0) 37 | self.enc_pos_embedding = nn.Parameter( 38 | torch.randn(1, configs.enc_in, self.in_seg_num, configs.d_model)) 39 | self.pre_norm = nn.LayerNorm(configs.d_model) 40 | 41 | # Encoder 42 | self.encoder = Encoder( 43 | [ 44 | scale_block(configs, 1 if l is 0 else self.win_size, configs.d_model, configs.n_heads, configs.d_ff, 45 | 1, configs.dropout, 46 | self.in_seg_num if l is 0 else ceil(self.in_seg_num / self.win_size ** l), configs.factor 47 | ) for l in range(configs.e_layers) 48 | ] 49 | ) 50 | # Decoder 51 | self.dec_pos_embedding = nn.Parameter( 52 | torch.randn(1, configs.enc_in, (self.pad_out_len // self.seg_len), configs.d_model)) 53 | 54 | self.decoder = Decoder( 55 | [ 56 | DecoderLayer( 57 | TwoStageAttentionLayer(configs, (self.pad_out_len // self.seg_len), configs.factor, configs.d_model, 58 | configs.n_heads, 59 | configs.d_ff, configs.dropout), 60 | AttentionLayer( 61 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 62 | output_attention=False), 63 | 
configs.d_model, configs.n_heads), 64 | self.seg_len, 65 | configs.d_model, 66 | configs.d_ff, 67 | dropout=configs.dropout, 68 | # activation=configs.activation, 69 | ) 70 | for l in range(configs.e_layers + 1) 71 | ], 72 | ) 73 | if self.task_name == 'imputation' or self.task_name == 'anomaly_detection': 74 | self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len, 75 | head_dropout=configs.dropout) 76 | elif self.task_name == 'classification': 77 | self.flatten = nn.Flatten(start_dim=-2) 78 | self.dropout = nn.Dropout(configs.dropout) 79 | self.projection = nn.Linear( 80 | self.head_nf * configs.enc_in, configs.num_class) 81 | 82 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 83 | # Normalization from Non-stationary Transformer 84 | means = x_enc.mean(1, keepdim=True).detach() 85 | x_enc = x_enc - means 86 | stdev = torch.sqrt( 87 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 88 | x_enc /= stdev 89 | 90 | # embedding 91 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1)) 92 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars) 93 | x_enc += self.enc_pos_embedding 94 | x_enc = self.pre_norm(x_enc) 95 | enc_out, attns = self.encoder(x_enc) 96 | 97 | dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat=x_enc.shape[0]) 98 | dec_out = self.decoder(dec_in, enc_out) 99 | 100 | # De-Normalization from Non-stationary Transformer 101 | dec_out = dec_out * \ 102 | (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 103 | dec_out = dec_out + \ 104 | (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 105 | 106 | return dec_out 107 | 108 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 109 | # embedding 110 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1)) 111 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars) 112 | x_enc += self.enc_pos_embedding 113 | x_enc = 
self.pre_norm(x_enc) 114 | enc_out, attns = self.encoder(x_enc) 115 | 116 | dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1) 117 | 118 | return dec_out 119 | 120 | def anomaly_detection(self, x_enc): 121 | # embedding 122 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1)) 123 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars) 124 | x_enc += self.enc_pos_embedding 125 | x_enc = self.pre_norm(x_enc) 126 | enc_out, attns = self.encoder(x_enc) 127 | 128 | dec_out = self.head(enc_out[-1].permute(0, 1, 3, 2)).permute(0, 2, 1) 129 | return dec_out 130 | 131 | def classification(self, x_enc, x_mark_enc): 132 | # embedding 133 | x_enc, n_vars = self.enc_value_embedding(x_enc.permute(0, 2, 1)) 134 | 135 | x_enc = rearrange(x_enc, '(b d) seg_num d_model -> b d seg_num d_model', d=n_vars) 136 | x_enc += self.enc_pos_embedding 137 | x_enc = self.pre_norm(x_enc) 138 | enc_out, attns = self.encoder(x_enc) 139 | # Output from Non-stationary Transformer 140 | output = self.flatten(enc_out[-1].permute(0, 1, 3, 2)) 141 | output = self.dropout(output) 142 | output = output.reshape(output.shape[0], -1) 143 | output = self.projection(output) 144 | return output 145 | 146 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 147 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 148 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 149 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 150 | if self.task_name == 'imputation': 151 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 152 | return dec_out # [B, L, D] 153 | if self.task_name == 'anomaly_detection': 154 | dec_out = self.anomaly_detection(x_enc) 155 | return dec_out # [B, L, D] 156 | if self.task_name == 'classification': 157 | dec_out = self.classification(x_enc, x_mark_enc) 158 | return dec_out # [B, N] 159 | return None 160 | 
-------------------------------------------------------------------------------- /models/DLinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Autoformer_EncDec import series_decomp 5 | 6 | 7 | class Model(nn.Module): 8 | """ 9 | Paper link: https://arxiv.org/pdf/2205.13504.pdf 10 | """ 11 | 12 | def __init__(self, configs, individual=False): 13 | """ 14 | individual: Bool, whether shared model among different variates. 15 | """ 16 | super(Model, self).__init__() 17 | self.task_name = configs.task_name 18 | self.seq_len = configs.seq_len 19 | if self.task_name == 'classification' or self.task_name == 'anomaly_detection' or self.task_name == 'imputation': 20 | self.pred_len = configs.seq_len 21 | else: 22 | self.pred_len = configs.pred_len 23 | # Series decomposition block from Autoformer 24 | self.decompsition = series_decomp(configs.moving_avg) 25 | self.individual = individual 26 | self.channels = configs.enc_in 27 | 28 | if self.individual: 29 | self.Linear_Seasonal = nn.ModuleList() 30 | self.Linear_Trend = nn.ModuleList() 31 | 32 | for i in range(self.channels): 33 | self.Linear_Seasonal.append( 34 | nn.Linear(self.seq_len, self.pred_len)) 35 | self.Linear_Trend.append( 36 | nn.Linear(self.seq_len, self.pred_len)) 37 | 38 | self.Linear_Seasonal[i].weight = nn.Parameter( 39 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 40 | self.Linear_Trend[i].weight = nn.Parameter( 41 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 42 | else: 43 | self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len) 44 | self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len) 45 | 46 | self.Linear_Seasonal.weight = nn.Parameter( 47 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 48 | self.Linear_Trend.weight = nn.Parameter( 49 | (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len])) 50 
| 51 | if self.task_name == 'classification': 52 | self.act = F.gelu 53 | self.dropout = nn.Dropout(configs.dropout) 54 | self.projection = nn.Linear( 55 | configs.enc_in * configs.seq_len, configs.num_class) 56 | 57 | def encoder(self, x): 58 | seasonal_init, trend_init = self.decompsition(x) 59 | seasonal_init, trend_init = seasonal_init.permute( 60 | 0, 2, 1), trend_init.permute(0, 2, 1) 61 | if self.individual: 62 | seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len], 63 | dtype=seasonal_init.dtype).to(seasonal_init.device) 64 | trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len], 65 | dtype=trend_init.dtype).to(trend_init.device) 66 | for i in range(self.channels): 67 | seasonal_output[:, i, :] = self.Linear_Seasonal[i]( 68 | seasonal_init[:, i, :]) 69 | trend_output[:, i, :] = self.Linear_Trend[i]( 70 | trend_init[:, i, :]) 71 | else: 72 | seasonal_output = self.Linear_Seasonal(seasonal_init) 73 | trend_output = self.Linear_Trend(trend_init) 74 | x = seasonal_output + trend_output 75 | return x.permute(0, 2, 1) 76 | 77 | def forecast(self, x_enc): 78 | # Normalization from Non-stationary Transformer 79 | means = x_enc.mean(1, keepdim=True).detach() 80 | x_enc = x_enc - means 81 | stdev = torch.sqrt( 82 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 83 | x_enc /= stdev 84 | 85 | dec_out = self.encoder(x_enc) 86 | 87 | # De-Normalization from Non-stationary Transformer 88 | dec_out = dec_out * \ 89 | (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 90 | dec_out = dec_out + \ 91 | (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 92 | 93 | return dec_out 94 | 95 | def imputation(self, x_enc): 96 | # Encoder 97 | return self.encoder(x_enc) 98 | 99 | def anomaly_detection(self, x_enc): 100 | # Encoder 101 | return self.encoder(x_enc) 102 | 103 | def classification(self, x_enc): 104 | # Encoder 105 | enc_out = self.encoder(x_enc) 106 | # Output 107 | # 
(batch_size, seq_length * d_model) 108 | output = enc_out.reshape(enc_out.shape[0], -1) 109 | # (batch_size, num_classes) 110 | output = self.projection(output) 111 | return output 112 | 113 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 114 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 115 | dec_out = self.forecast(x_enc) 116 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 117 | if self.task_name == 'imputation': 118 | dec_out = self.imputation(x_enc) 119 | return dec_out # [B, L, D] 120 | if self.task_name == 'anomaly_detection': 121 | dec_out = self.anomaly_detection(x_enc) 122 | return dec_out # [B, L, D] 123 | if self.task_name == 'classification': 124 | dec_out = self.classification(x_enc) 125 | return dec_out # [B, N] 126 | return None 127 | -------------------------------------------------------------------------------- /models/FEDformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Embed import DataEmbedding 5 | from layers.AutoCorrelation import AutoCorrelationLayer 6 | from layers.FourierCorrelation import FourierBlock, FourierCrossAttention 7 | from layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform 8 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp 9 | 10 | 11 | class Model(nn.Module): 12 | """ 13 | FEDformer performs the attention mechanism on frequency domain and achieved O(N) complexity 14 | Paper link: https://proceedings.mlr.press/v162/zhou22g.html 15 | """ 16 | 17 | def __init__(self, configs, version='fourier', mode_select='random', modes=32): 18 | """ 19 | version: str, for FEDformer, there are two versions to choose, options: [Fourier, Wavelets]. 20 | mode_select: str, for FEDformer, there are two mode selection method, options: [random, low]. 
21 | modes: int, modes to be selected. 22 | """ 23 | super(Model, self).__init__() 24 | self.task_name = configs.task_name 25 | self.seq_len = configs.seq_len 26 | self.label_len = configs.label_len 27 | self.pred_len = configs.pred_len 28 | 29 | self.version = version 30 | self.mode_select = mode_select 31 | self.modes = modes 32 | 33 | # Decomp 34 | self.decomp = series_decomp(configs.moving_avg) 35 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 36 | configs.dropout) 37 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 38 | configs.dropout) 39 | 40 | if self.version == 'Wavelets': 41 | encoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=1, base='legendre') 42 | decoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=1, base='legendre') 43 | decoder_cross_att = MultiWaveletCross(in_channels=configs.d_model, 44 | out_channels=configs.d_model, 45 | seq_len_q=self.seq_len // 2 + self.pred_len, 46 | seq_len_kv=self.seq_len, 47 | modes=self.modes, 48 | ich=configs.d_model, 49 | base='legendre', 50 | activation='tanh') 51 | else: 52 | encoder_self_att = FourierBlock(in_channels=configs.d_model, 53 | out_channels=configs.d_model, 54 | seq_len=self.seq_len, 55 | modes=self.modes, 56 | mode_select_method=self.mode_select) 57 | decoder_self_att = FourierBlock(in_channels=configs.d_model, 58 | out_channels=configs.d_model, 59 | seq_len=self.seq_len // 2 + self.pred_len, 60 | modes=self.modes, 61 | mode_select_method=self.mode_select) 62 | decoder_cross_att = FourierCrossAttention(in_channels=configs.d_model, 63 | out_channels=configs.d_model, 64 | seq_len_q=self.seq_len // 2 + self.pred_len, 65 | seq_len_kv=self.seq_len, 66 | modes=self.modes, 67 | mode_select_method=self.mode_select, 68 | num_heads=configs.n_heads) 69 | # Encoder 70 | self.encoder = Encoder( 71 | [ 72 | EncoderLayer( 73 | AutoCorrelationLayer( 74 | encoder_self_att, # instead of 
multi-head attention in transformer 75 | configs.d_model, configs.n_heads), 76 | configs.d_model, 77 | configs.d_ff, 78 | moving_avg=configs.moving_avg, 79 | dropout=configs.dropout, 80 | activation=configs.activation 81 | ) for l in range(configs.e_layers) 82 | ], 83 | norm_layer=my_Layernorm(configs.d_model) 84 | ) 85 | # Decoder 86 | self.decoder = Decoder( 87 | [ 88 | DecoderLayer( 89 | AutoCorrelationLayer( 90 | decoder_self_att, 91 | configs.d_model, configs.n_heads), 92 | AutoCorrelationLayer( 93 | decoder_cross_att, 94 | configs.d_model, configs.n_heads), 95 | configs.d_model, 96 | configs.c_out, 97 | configs.d_ff, 98 | moving_avg=configs.moving_avg, 99 | dropout=configs.dropout, 100 | activation=configs.activation, 101 | ) 102 | for l in range(configs.d_layers) 103 | ], 104 | norm_layer=my_Layernorm(configs.d_model), 105 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 106 | ) 107 | 108 | if self.task_name == 'imputation': 109 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 110 | if self.task_name == 'anomaly_detection': 111 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 112 | if self.task_name == 'classification': 113 | self.act = F.gelu 114 | self.dropout = nn.Dropout(configs.dropout) 115 | self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class) 116 | 117 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 118 | if x_mark_enc.shape[-1] != 4: 119 | x_mark_enc = None 120 | if x_mark_dec.shape[-1] != 4: 121 | x_mark_dec = None 122 | # decomp init 123 | mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) 124 | seasonal_init, trend_init = self.decomp(x_enc) # x - moving_avg, moving_avg 125 | # decoder input 126 | trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) 127 | seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len)) 128 | # enc 129 | enc_out = self.enc_embedding(x_enc, 
x_mark_enc) 130 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 131 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 132 | # dec 133 | seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, trend=trend_init) 134 | # final 135 | dec_out = trend_part + seasonal_part 136 | return dec_out 137 | 138 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 139 | # enc 140 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 141 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 142 | # final 143 | dec_out = self.projection(enc_out) 144 | return dec_out 145 | 146 | def anomaly_detection(self, x_enc): 147 | # enc 148 | enc_out = self.enc_embedding(x_enc, None) 149 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 150 | # final 151 | dec_out = self.projection(enc_out) 152 | return dec_out 153 | 154 | def classification(self, x_enc, x_mark_enc): 155 | # enc 156 | enc_out = self.enc_embedding(x_enc, None) 157 | enc_out, attns = self.encoder(enc_out, attn_mask=None) 158 | 159 | # Output 160 | output = self.act(enc_out) 161 | output = self.dropout(output) 162 | output = output * x_mark_enc.unsqueeze(-1) 163 | output = output.reshape(output.shape[0], -1) 164 | output = self.projection(output) 165 | return output 166 | 167 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 168 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 169 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 170 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 171 | if self.task_name == 'imputation': 172 | dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 173 | return dec_out # [B, L, D] 174 | if self.task_name == 'anomaly_detection': 175 | dec_out = self.anomaly_detection(x_enc) 176 | return dec_out # [B, L, D] 177 | if self.task_name == 'classification': 178 | dec_out = self.classification(x_enc, x_mark_enc) 179 | return dec_out # [B, N] 180 | 
return None 181 | -------------------------------------------------------------------------------- /models/FourierGNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Model(nn.Module): 7 | def __init__(self, configs): 8 | super().__init__() 9 | 10 | self.task_name = configs.task_name 11 | self.seq_len = configs.seq_len 12 | self.pred_len = configs.pred_len 13 | 14 | self.embed_size = 64 15 | self.d_model = configs.d_model 16 | self.number_frequency = 1 17 | # self.pred_len = pre_length 18 | # self.feature_size = feature_size 19 | # self.seq_len = seq_len 20 | self.frequency_size = self.embed_size // self.number_frequency 21 | self.hidden_size_factor = 1 22 | self.sparsity_threshold = 0.01 23 | self.hard_thresholding_fraction = 1 24 | self.scale = 0.02 25 | self.embeddings = nn.Parameter(torch.randn(1, self.embed_size)) 26 | 27 | self.w1 = nn.Parameter( 28 | self.scale * torch.randn(2, self.frequency_size, self.frequency_size * self.hidden_size_factor)) 29 | self.b1 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor)) 30 | self.w2 = nn.Parameter( 31 | self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor, self.frequency_size)) 32 | self.b2 = nn.Parameter(self.scale * torch.randn(2, self.frequency_size)) 33 | self.w3 = nn.Parameter( 34 | self.scale * torch.randn(2, self.frequency_size, 35 | self.frequency_size * self.hidden_size_factor)) 36 | self.b3 = nn.Parameter( 37 | self.scale * torch.randn(2, self.frequency_size * self.hidden_size_factor)) 38 | self.embeddings_10 = nn.Parameter(torch.randn(self.seq_len, 8)) 39 | self.fc = nn.Sequential( 40 | nn.Linear(self.embed_size * 8, 64), 41 | nn.LeakyReLU(), 42 | nn.Linear(64, self.d_model), 43 | nn.LeakyReLU(), 44 | nn.Linear(self.d_model, self.pred_len) 45 | ) 46 | self.to('cuda:0') 47 | 48 | def tokenEmb(self, x): 49 | x = 
x.unsqueeze(2) 50 | y = self.embeddings 51 | return x * y 52 | 53 | # FourierGNN 54 | def fourierGC(self, x, B, N, L): 55 | o1_real = torch.zeros([B, (N * L) // 2 + 1, self.frequency_size * self.hidden_size_factor], 56 | device=x.device) 57 | o1_imag = torch.zeros([B, (N * L) // 2 + 1, self.frequency_size * self.hidden_size_factor], 58 | device=x.device) 59 | o2_real = torch.zeros(x.shape, device=x.device) 60 | o2_imag = torch.zeros(x.shape, device=x.device) 61 | 62 | o3_real = torch.zeros(x.shape, device=x.device) 63 | o3_imag = torch.zeros(x.shape, device=x.device) 64 | 65 | o1_real = F.relu( 66 | torch.einsum('bli,ii->bli', x.real, self.w1[0]) - \ 67 | torch.einsum('bli,ii->bli', x.imag, self.w1[1]) + \ 68 | self.b1[0] 69 | ) 70 | 71 | o1_imag = F.relu( 72 | torch.einsum('bli,ii->bli', x.imag, self.w1[0]) + \ 73 | torch.einsum('bli,ii->bli', x.real, self.w1[1]) + \ 74 | self.b1[1] 75 | ) 76 | 77 | # 1 layer 78 | y = torch.stack([o1_real, o1_imag], dim=-1) 79 | y = F.softshrink(y, lambd=self.sparsity_threshold) 80 | 81 | o2_real = F.relu( 82 | torch.einsum('bli,ii->bli', o1_real, self.w2[0]) - \ 83 | torch.einsum('bli,ii->bli', o1_imag, self.w2[1]) + \ 84 | self.b2[0] 85 | ) 86 | 87 | o2_imag = F.relu( 88 | torch.einsum('bli,ii->bli', o1_imag, self.w2[0]) + \ 89 | torch.einsum('bli,ii->bli', o1_real, self.w2[1]) + \ 90 | self.b2[1] 91 | ) 92 | 93 | # 2 layer 94 | x = torch.stack([o2_real, o2_imag], dim=-1) 95 | x = F.softshrink(x, lambd=self.sparsity_threshold) 96 | x = x + y 97 | 98 | o3_real = F.relu( 99 | torch.einsum('bli,ii->bli', o2_real, self.w3[0]) - \ 100 | torch.einsum('bli,ii->bli', o2_imag, self.w3[1]) + \ 101 | self.b3[0] 102 | ) 103 | 104 | o3_imag = F.relu( 105 | torch.einsum('bli,ii->bli', o2_imag, self.w3[0]) + \ 106 | torch.einsum('bli,ii->bli', o2_real, self.w3[1]) + \ 107 | self.b3[1] 108 | ) 109 | 110 | # 3 layer 111 | z = torch.stack([o3_real, o3_imag], dim=-1) 112 | z = F.softshrink(z, lambd=self.sparsity_threshold) 113 | z = z + x 114 | z 
= torch.view_as_complex(z) 115 | return z 116 | 117 | def forecast(self, x, x_mark_enc, x_dec, x_mark_dec): 118 | x = x.permute(0, 2, 1).contiguous() 119 | B, N, L = x.shape 120 | # B*N*L ==> B*NL 121 | x = x.reshape(B, -1) 122 | # embedding B*NL ==> B*NL*D 123 | x = self.tokenEmb(x) 124 | 125 | # FFT B*NL*D ==> B*NT/2*D 126 | x = torch.fft.rfft(x, dim=1, norm='ortho') 127 | 128 | x = x.reshape(B, (N * L) // 2 + 1, self.frequency_size) 129 | 130 | bias = x 131 | 132 | # FourierGNN 133 | x = self.fourierGC(x, B, N, L) 134 | 135 | x = x + bias 136 | 137 | x = x.reshape(B, (N * L) // 2 + 1, self.embed_size) 138 | 139 | # ifft 140 | x = torch.fft.irfft(x, n=N * L, dim=1, norm="ortho") 141 | 142 | x = x.reshape(B, N, L, self.embed_size) 143 | x = x.permute(0, 1, 3, 2) # B, N, D, L 144 | 145 | # projection 146 | x = torch.matmul(x, self.embeddings_10) 147 | x = x.reshape(B, N, -1) 148 | x = self.fc(x) 149 | 150 | x = x.permute(0, 2, 1) 151 | 152 | return x 153 | 154 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 155 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 156 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 157 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 158 | if self.task_name == 'imputation': 159 | dec_out = self.imputation( 160 | x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 161 | return dec_out # [B, L, D] 162 | if self.task_name == 'anomaly_detection': 163 | dec_out = self.anomaly_detection(x_enc) 164 | return dec_out # [B, L, D] 165 | if self.task_name == 'classification': 166 | dec_out = self.classification(x_enc, x_mark_enc) 167 | return dec_out # [B, N] 168 | return None 169 | 170 | 171 | class FlattenHead(nn.Module): 172 | def __init__(self, n_vars, nf, target_window, head_dropout=0): 173 | super().__init__() 174 | self.n_vars = n_vars 175 | self.flatten = nn.Flatten(start_dim=-2) 176 | self.linear = nn.Linear(nf, target_window) 177 | self.dropout = 
nn.Dropout(head_dropout) 178 | 179 | def forward(self, x): # x: [bs x nvars x d_model x patch_num] 180 | x = self.flatten(x) 181 | x = self.linear(x) 182 | x = self.dropout(x) 183 | return x 184 | -------------------------------------------------------------------------------- /models/MICN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from layers.Embed import DataEmbedding 4 | from layers.Autoformer_EncDec import series_decomp, series_decomp_multi 5 | import torch.nn.functional as F 6 | 7 | 8 | class MIC(nn.Module): 9 | """ 10 | MIC layer to extract local and global features 11 | """ 12 | 13 | def __init__(self, feature_size=512, n_heads=8, dropout=0.05, decomp_kernel=[32], conv_kernel=[24], 14 | isometric_kernel=[18, 6], device='cuda'): 15 | super(MIC, self).__init__() 16 | self.conv_kernel = conv_kernel 17 | self.device = device 18 | 19 | # isometric convolution 20 | self.isometric_conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size, 21 | kernel_size=i, padding=0, stride=1) 22 | for i in isometric_kernel]) 23 | 24 | # downsampling convolution: padding=i//2, stride=i 25 | self.conv = nn.ModuleList([nn.Conv1d(in_channels=feature_size, out_channels=feature_size, 26 | kernel_size=i, padding=i // 2, stride=i) 27 | for i in conv_kernel]) 28 | 29 | # upsampling convolution 30 | self.conv_trans = nn.ModuleList([nn.ConvTranspose1d(in_channels=feature_size, out_channels=feature_size, 31 | kernel_size=i, padding=0, stride=i) 32 | for i in conv_kernel]) 33 | 34 | self.decomp = nn.ModuleList([series_decomp(k) for k in decomp_kernel]) 35 | self.merge = torch.nn.Conv2d(in_channels=feature_size, out_channels=feature_size, 36 | kernel_size=(len(self.conv_kernel), 1)) 37 | 38 | # feedforward network 39 | self.conv1 = nn.Conv1d(in_channels=feature_size, out_channels=feature_size * 4, kernel_size=1) 40 | self.conv2 = nn.Conv1d(in_channels=feature_size * 4, 
out_channels=feature_size, kernel_size=1) 41 | self.norm1 = nn.LayerNorm(feature_size) 42 | self.norm2 = nn.LayerNorm(feature_size) 43 | 44 | self.norm = torch.nn.LayerNorm(feature_size) 45 | self.act = torch.nn.Tanh() 46 | self.drop = torch.nn.Dropout(0.05) 47 | 48 | def conv_trans_conv(self, input, conv1d, conv1d_trans, isometric): 49 | batch, seq_len, channel = input.shape 50 | x = input.permute(0, 2, 1) 51 | 52 | # downsampling convolution 53 | x1 = self.drop(self.act(conv1d(x))) 54 | x = x1 55 | 56 | # isometric convolution 57 | zeros = torch.zeros((x.shape[0], x.shape[1], x.shape[2] - 1), device=self.device) 58 | x = torch.cat((zeros, x), dim=-1) 59 | x = self.drop(self.act(isometric(x))) 60 | x = self.norm((x + x1).permute(0, 2, 1)).permute(0, 2, 1) 61 | 62 | # upsampling convolution 63 | x = self.drop(self.act(conv1d_trans(x))) 64 | x = x[:, :, :seq_len] # truncate 65 | 66 | x = self.norm(x.permute(0, 2, 1) + input) 67 | return x 68 | 69 | def forward(self, src): 70 | # multi-scale 71 | multi = [] 72 | for i in range(len(self.conv_kernel)): 73 | src_out, trend1 = self.decomp[i](src) 74 | src_out = self.conv_trans_conv(src_out, self.conv[i], self.conv_trans[i], self.isometric_conv[i]) 75 | multi.append(src_out) 76 | 77 | # merge 78 | mg = torch.tensor([], device=self.device) 79 | for i in range(len(self.conv_kernel)): 80 | mg = torch.cat((mg, multi[i].unsqueeze(1)), dim=1) 81 | mg = self.merge(mg.permute(0, 3, 1, 2)).squeeze(-2).permute(0, 2, 1) 82 | 83 | y = self.norm1(mg) 84 | y = self.conv2(self.conv1(y.transpose(-1, 1))).transpose(-1, 1) 85 | 86 | return self.norm2(mg + y) 87 | 88 | 89 | class SeasonalPrediction(nn.Module): 90 | def __init__(self, embedding_size=512, n_heads=8, dropout=0.05, d_layers=1, decomp_kernel=[32], c_out=1, 91 | conv_kernel=[2, 4], isometric_kernel=[18, 6], device='cuda'): 92 | super(SeasonalPrediction, self).__init__() 93 | 94 | self.mic = nn.ModuleList([MIC(feature_size=embedding_size, n_heads=n_heads, 95 | 
decomp_kernel=decomp_kernel, conv_kernel=conv_kernel, 96 | isometric_kernel=isometric_kernel, device=device) 97 | for i in range(d_layers)]) 98 | 99 | self.projection = nn.Linear(embedding_size, c_out) 100 | 101 | def forward(self, dec): 102 | for mic_layer in self.mic: 103 | dec = mic_layer(dec) 104 | return self.projection(dec) 105 | 106 | 107 | class Model(nn.Module): 108 | """ 109 | Paper link: https://openreview.net/pdf?id=zt53IDUR1U 110 | """ 111 | def __init__(self, configs, conv_kernel=[12, 16]): 112 | """ 113 | conv_kernel: downsampling and upsampling convolution kernel_size 114 | """ 115 | super(Model, self).__init__() 116 | 117 | decomp_kernel = [] # kernel of decomposition operation 118 | isometric_kernel = [] # kernel of isometric convolution 119 | for ii in conv_kernel: 120 | if ii % 2 == 0: # the kernel of decomposition operation must be odd 121 | decomp_kernel.append(ii + 1) 122 | isometric_kernel.append((configs.seq_len + configs.pred_len + ii) // ii) 123 | else: 124 | decomp_kernel.append(ii) 125 | isometric_kernel.append((configs.seq_len + configs.pred_len + ii - 1) // ii) 126 | 127 | self.task_name = configs.task_name 128 | self.pred_len = configs.pred_len 129 | self.seq_len = configs.seq_len 130 | 131 | # Multiple Series decomposition block from FEDformer 132 | self.decomp_multi = series_decomp_multi(decomp_kernel) 133 | 134 | # embedding 135 | self.dec_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 136 | configs.dropout) 137 | 138 | self.conv_trans = SeasonalPrediction(embedding_size=configs.d_model, n_heads=configs.n_heads, 139 | dropout=configs.dropout, 140 | d_layers=configs.d_layers, decomp_kernel=decomp_kernel, 141 | c_out=configs.c_out, conv_kernel=conv_kernel, 142 | isometric_kernel=isometric_kernel, device=torch.device('cuda:0')) 143 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 144 | # refer to DLinear 145 | self.regression = nn.Linear(configs.seq_len, 
configs.pred_len) 146 | self.regression.weight = nn.Parameter( 147 | (1 / configs.pred_len) * torch.ones([configs.pred_len, configs.seq_len]), 148 | requires_grad=True) 149 | if self.task_name == 'imputation': 150 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 151 | if self.task_name == 'anomaly_detection': 152 | self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True) 153 | if self.task_name == 'classification': 154 | self.act = F.gelu 155 | self.dropout = nn.Dropout(configs.dropout) 156 | self.projection = nn.Linear(configs.c_out * configs.seq_len, configs.num_class) 157 | 158 | def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 159 | # Normalization from Non-stationary Transformer 160 | means = x_enc.mean(1, keepdim=True).detach() 161 | x_enc = x_enc - means 162 | stdev = torch.sqrt( 163 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 164 | x_enc /= stdev 165 | 166 | # prevention for x_mark_dec being single channel case 167 | if x_mark_dec.shape[-1] != 4: 168 | # x_mark_dec = x_mark_dec.repeat(1, 1, 4) 169 | x_mark_dec = None 170 | 171 | # print(x_enc.shape, x_dec.shape, x_mark_dec.shape) 172 | # Multi-scale Hybrid Decomposition 173 | seasonal_init_enc, trend = self.decomp_multi(x_enc) 174 | trend = self.regression(trend.permute(0, 2, 1)).permute(0, 2, 1) 175 | 176 | # embedding 177 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device) 178 | seasonal_init_dec = torch.cat([seasonal_init_enc[:, -self.seq_len:, :], zeros], dim=1) 179 | dec_out = self.dec_embedding(seasonal_init_dec, x_mark_dec) 180 | dec_out = self.conv_trans(dec_out) 181 | dec_out = dec_out[:, -self.pred_len:, :] + trend[:, -self.pred_len:, :] 182 | 183 | # De-Normalization from Non-stationary Transformer 184 | dec_out = dec_out * \ 185 | (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 186 | dec_out = dec_out + \ 187 | (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 188 | 189 | 
return dec_out 190 | 191 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 192 | # Multi-scale Hybrid Decomposition 193 | seasonal_init_enc, trend = self.decomp_multi(x_enc) 194 | 195 | # embedding 196 | dec_out = self.dec_embedding(seasonal_init_enc, x_mark_dec) 197 | dec_out = self.conv_trans(dec_out) 198 | dec_out = dec_out + trend 199 | return dec_out 200 | 201 | def anomaly_detection(self, x_enc): 202 | # Multi-scale Hybrid Decomposition 203 | seasonal_init_enc, trend = self.decomp_multi(x_enc) 204 | 205 | # embedding 206 | dec_out = self.dec_embedding(seasonal_init_enc, None) 207 | dec_out = self.conv_trans(dec_out) 208 | dec_out = dec_out + trend 209 | return dec_out 210 | 211 | def classification(self, x_enc, x_mark_enc): 212 | # Multi-scale Hybrid Decomposition 213 | seasonal_init_enc, trend = self.decomp_multi(x_enc) 214 | # embedding 215 | dec_out = self.dec_embedding(seasonal_init_enc, None) 216 | dec_out = self.conv_trans(dec_out) 217 | dec_out = dec_out + trend 218 | 219 | # Output from Non-stationary Transformer 220 | output = self.act(dec_out) # the output transformer encoder/decoder embeddings don't include non-linearity 221 | output = self.dropout(output) 222 | output = output * x_mark_enc.unsqueeze(-1) # zero-out padding embeddings 223 | output = output.reshape(output.shape[0], -1) # (batch_size, seq_length * d_model) 224 | output = self.projection(output) # (batch_size, num_classes) 225 | return output 226 | 227 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 228 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 229 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 230 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 231 | if self.task_name == 'imputation': 232 | dec_out = self.imputation( 233 | x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 234 | return dec_out # [B, L, D] 235 | if self.task_name == 'anomaly_detection': 236 | dec_out = 
self.anomaly_detection(x_enc) 237 | return dec_out # [B, L, D] 238 | if self.task_name == 'classification': 239 | dec_out = self.classification(x_enc, x_mark_enc) 240 | return dec_out # [B, N] 241 | return None 242 | -------------------------------------------------------------------------------- /models/MambaTS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops.layers.torch import Rearrange 4 | from einops import rearrange 5 | 6 | 7 | class PredictionHead(nn.Module): 8 | def __init__(self, in_nf, out_nf): 9 | super().__init__() 10 | self.linear = nn.Linear(in_nf, out_nf) 11 | 12 | def forward(self, x, n_vars): 13 | x = rearrange(x, 'b (c l) d -> (b c) (l d)', c=n_vars) 14 | x = self.linear(x) 15 | x = rearrange(x, '(b c) p -> b p c', c=n_vars) 16 | 17 | return x 18 | 19 | 20 | class Model(nn.Module): 21 | """ 22 | Paper link: http://arxiv.org/abs/2405.16440 23 | """ 24 | 25 | def __init__(self, configs): 26 | super().__init__() 27 | self.task_name = configs.task_name 28 | self.seq_len = configs.seq_len 29 | self.pred_len = configs.pred_len 30 | self.patch_len = configs.patch_len 31 | self.stride = configs.stride 32 | 33 | # patching and embedding 34 | self.value_embedding = nn.Sequential( 35 | Rearrange('b c l d -> (b c) l d'), 36 | nn.Linear(configs.patch_len, configs.d_model, bias=False), 37 | ) 38 | 39 | # Encoder 40 | from layers.mamba_ssm.mixer2_seq_simple import MixerTSModel as Mamba 41 | self.encoder = Mamba( 42 | d_model=configs.d_model, # Model dimension d_model 43 | n_layer=configs.e_layers, 44 | n_vars=configs.enc_in, 45 | dropout=configs.dropout, 46 | ssm_cfg={'layer': 'Mamba1'}, 47 | VPT_mode=configs.VPT_mode, 48 | ATSP_solver=configs.ATSP_solver, 49 | use_casual_conv=configs.use_casual_conv, 50 | fused_add_norm=True, 51 | ) 52 | 53 | # Prediction Head 54 | self.head = PredictionHead(configs.d_model * configs.seq_len // configs.patch_len, configs.pred_len) 55 | 
56 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec): 57 | b, _, n_vars = x_enc.shape 58 | 59 | # Norm 60 | means = x_enc.mean(1, keepdim=True).detach() 61 | x_enc = x_enc - means 62 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() 63 | x_enc /= stdev 64 | 65 | # Do patching and embedding 66 | x = x_enc.permute(0, 2, 1) 67 | x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) 68 | enc_in = self.value_embedding(x) 69 | 70 | # Variable Scan along Time (VST) 71 | enc_in = rearrange(enc_in, '(b c) l d -> b (c l) d', c=n_vars) 72 | 73 | # Encoder 74 | enc_out, attns = self.encoder(enc_in) 75 | 76 | # Decoder 77 | dec_out = self.head(enc_out, n_vars=n_vars) 78 | 79 | # De-norm 80 | dec_out = dec_out * (stdev[:, [0], :].repeat(1, self.pred_len, 1)) 81 | dec_out = dec_out + (means[:, [0], :].repeat(1, self.pred_len, 1)) 82 | 83 | return dec_out 84 | 85 | def batch_update_state(self, cost_tensor): 86 | self.encoder.batch_update_state(cost_tensor) 87 | 88 | def set_reordering_index(self, reordering_index): 89 | self.encoder.set_reordering_index(reordering_index) 90 | 91 | def reset_ids_shuffle(self): 92 | self.encoder.reset_ids_shuffle() 93 | -------------------------------------------------------------------------------- /models/PatchTST.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from layers.Transformer_EncDec import Encoder, EncoderLayer 4 | from layers.SelfAttention_Family import FullAttention, AttentionLayer 5 | from layers.Embed import PatchEmbedding 6 | 7 | 8 | class FlattenHead(nn.Module): 9 | def __init__(self, n_vars, nf, target_window, head_dropout=0): 10 | super().__init__() 11 | self.n_vars = n_vars 12 | self.flatten = nn.Flatten(start_dim=-2) 13 | self.linear = nn.Linear(nf, target_window) 14 | self.dropout = nn.Dropout(head_dropout) 15 | 16 | def forward(self, x): # x: [bs x nvars x d_model x patch_num] 17 | x 
= self.flatten(x) 18 | x = self.linear(x) 19 | x = self.dropout(x) 20 | return x 21 | 22 | 23 | class Model(nn.Module): 24 | """ 25 | Paper link: https://arxiv.org/pdf/2211.14730.pdf 26 | """ 27 | 28 | def __init__(self, configs): 29 | """ 30 | patch_len: int, patch len for patch_embedding 31 | stride: int, stride for patch_embedding 32 | """ 33 | super().__init__() 34 | self.task_name = configs.task_name 35 | self.seq_len = configs.seq_len 36 | self.pred_len = configs.pred_len 37 | patch_len = configs.patch_len 38 | stride = configs.stride 39 | 40 | padding = stride 41 | 42 | # patching and embedding 43 | self.patch_embedding = PatchEmbedding( 44 | configs.d_model, patch_len, stride, padding, configs.dropout) 45 | 46 | # Encoder 47 | self.encoder = Encoder( 48 | [ 49 | EncoderLayer( 50 | AttentionLayer( 51 | FullAttention(False, configs.factor, attention_dropout=configs.dropout, 52 | output_attention=configs.output_attention), configs.d_model, configs.n_heads), 53 | configs.d_model, 54 | configs.d_ff, 55 | dropout=configs.dropout, 56 | activation=configs.activation 57 | ) for l in range(configs.e_layers) 58 | ], 59 | norm_layer=torch.nn.LayerNorm(configs.d_model) 60 | ) 61 | 62 | # Prediction Head 63 | self.head_nf = configs.d_model * \ 64 | int((configs.seq_len - patch_len) / stride + 2) 65 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 66 | self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len, 67 | head_dropout=configs.dropout) 68 | elif self.task_name == 'imputation' or self.task_name == 'anomaly_detection': 69 | self.head = FlattenHead(configs.enc_in, self.head_nf, configs.seq_len, 70 | head_dropout=configs.dropout) 71 | elif self.task_name == 'classification': 72 | self.flatten = nn.Flatten(start_dim=-2) 73 | self.dropout = nn.Dropout(configs.dropout) 74 | self.projection = nn.Linear( 75 | self.head_nf * configs.enc_in, configs.num_class) 76 | 77 | def forecast(self, x_enc, x_mark_enc, x_dec, 
x_mark_dec): 78 | # Normalization from Non-stationary Transformer 79 | means = x_enc.mean(1, keepdim=True).detach() 80 | x_enc = x_enc - means 81 | stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 82 | x_enc /= stdev 83 | 84 | # do patching and embedding 85 | x_enc = x_enc.permute(0, 2, 1) 86 | # u: [bs * nvars x patch_num x d_model] 87 | enc_out, n_vars = self.patch_embedding(x_enc) 88 | 89 | # Encoder 90 | # z: [bs * nvars x patch_num x d_model] 91 | enc_out, attns = self.encoder(enc_out) 92 | # z: [bs x nvars x patch_num x d_model] 93 | enc_out = torch.reshape( 94 | enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])) 95 | # z: [bs x nvars x d_model x patch_num] 96 | enc_out = enc_out.permute(0, 1, 3, 2) 97 | 98 | # Decoder 99 | dec_out = self.head(enc_out) # z: [bs x nvars x target_window] 100 | dec_out = dec_out.permute(0, 2, 1) 101 | 102 | # De-Normalization from Non-stationary Transformer 103 | dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 104 | dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) 105 | 106 | return dec_out 107 | 108 | def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask): 109 | # Normalization from Non-stationary Transformer 110 | means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1) 111 | means = means.unsqueeze(1).detach() 112 | x_enc = x_enc - means 113 | x_enc = x_enc.masked_fill(mask == 0, 0) 114 | stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / 115 | torch.sum(mask == 1, dim=1) + 1e-5) 116 | stdev = stdev.unsqueeze(1).detach() 117 | x_enc /= stdev 118 | 119 | # do patching and embedding 120 | x_enc = x_enc.permute(0, 2, 1) 121 | # u: [bs * nvars x patch_num x d_model] 122 | enc_out, n_vars = self.patch_embedding(x_enc) 123 | 124 | # Encoder 125 | # z: [bs * nvars x patch_num x d_model] 126 | enc_out, attns = self.encoder(enc_out) 127 | # z: [bs x nvars x patch_num x d_model] 128 | enc_out = torch.reshape( 129 | 
enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])) 130 | # z: [bs x nvars x d_model x patch_num] 131 | enc_out = enc_out.permute(0, 1, 3, 2) 132 | 133 | # Decoder 134 | dec_out = self.head(enc_out) # z: [bs x nvars x target_window] 135 | dec_out = dec_out.permute(0, 2, 1) 136 | 137 | # De-Normalization from Non-stationary Transformer 138 | dec_out = dec_out * \ 139 | (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1)) 140 | dec_out = dec_out + \ 141 | (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1)) 142 | return dec_out 143 | 144 | def anomaly_detection(self, x_enc): 145 | # Normalization from Non-stationary Transformer 146 | means = x_enc.mean(1, keepdim=True).detach() 147 | x_enc = x_enc - means 148 | stdev = torch.sqrt( 149 | torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 150 | x_enc /= stdev 151 | 152 | # do patching and embedding 153 | x_enc = x_enc.permute(0, 2, 1) 154 | # u: [bs * nvars x patch_num x d_model] 155 | enc_out, n_vars = self.patch_embedding(x_enc) 156 | 157 | # Encoder 158 | # z: [bs * nvars x patch_num x d_model] 159 | enc_out, attns = self.encoder(enc_out) 160 | # z: [bs x nvars x patch_num x d_model] 161 | enc_out = torch.reshape( 162 | enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])) 163 | # z: [bs x nvars x d_model x patch_num] 164 | enc_out = enc_out.permute(0, 1, 3, 2) 165 | 166 | # Decoder 167 | dec_out = self.head(enc_out) # z: [bs x nvars x target_window] 168 | dec_out = dec_out.permute(0, 2, 1) 169 | 170 | # De-Normalization from Non-stationary Transformer 171 | dec_out = dec_out * \ 172 | (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1)) 173 | dec_out = dec_out + \ 174 | (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1)) 175 | return dec_out 176 | 177 | def classification(self, x_enc, x_mark_enc): 178 | # Normalization from Non-stationary Transformer 179 | means = x_enc.mean(1, keepdim=True).detach() 180 | x_enc = x_enc - means 181 | stdev = torch.sqrt( 182 | 
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) 183 | x_enc /= stdev 184 | 185 | # do patching and embedding 186 | x_enc = x_enc.permute(0, 2, 1) 187 | # u: [bs * nvars x patch_num x d_model] 188 | enc_out, n_vars = self.patch_embedding(x_enc) 189 | 190 | # Encoder 191 | # z: [bs * nvars x patch_num x d_model] 192 | enc_out, attns = self.encoder(enc_out) 193 | # z: [bs x nvars x patch_num x d_model] 194 | enc_out = torch.reshape( 195 | enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])) 196 | # z: [bs x nvars x d_model x patch_num] 197 | enc_out = enc_out.permute(0, 1, 3, 2) 198 | 199 | # Decoder 200 | output = self.flatten(enc_out) 201 | output = self.dropout(output) 202 | output = output.reshape(output.shape[0], -1) 203 | output = self.projection(output) # (batch_size, num_classes) 204 | return output 205 | 206 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None): 207 | if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast': 208 | dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec) 209 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 210 | if self.task_name == 'imputation': 211 | dec_out = self.imputation( 212 | x_enc, x_mark_enc, x_dec, x_mark_dec, mask) 213 | return dec_out # [B, L, D] 214 | if self.task_name == 'anomaly_detection': 215 | dec_out = self.anomaly_detection(x_enc) 216 | return dec_out # [B, L, D] 217 | if self.task_name == 'classification': 218 | dec_out = self.classification(x_enc, x_mark_enc) 219 | return dec_out # [B, N] 220 | return None 221 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiudingCai/MambaTS-pytorch/3b6797d9bf5178a490e1bb3c3e9f4d223f2af51d/models/__init__.py -------------------------------------------------------------------------------- /models/iTransformer.py: 
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Transformer_EncDec import Encoder, EncoderLayer
from layers.SelfAttention_Family import FullAttention, AttentionLayer
from layers.Embed import DataEmbedding_inverted
import numpy as np


class Model(nn.Module):
    """
    Paper link: https://arxiv.org/abs/2310.06625

    The inverted embedding yields one token per variable plus one per time-mark
    feature (e.g. 9 variables + 4 marks -> 13 tokens). Outputs keep only the
    first N (variable) tokens after projection.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention

        # Embedding
        self.enc_embedding = DataEmbedding_inverted(configs.seq_len, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=configs.output_attention), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder: a per-token linear projection whose output width depends on the task.
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            self.projection = nn.Linear(configs.d_model, configs.pred_len, bias=True)
        if self.task_name == 'imputation':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(configs.d_model, configs.seq_len, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.enc_in, configs.num_class)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast ``pred_len`` steps; returns [B, pred_len, N] with N = x_enc.shape[-1].

        ``x_dec`` and ``x_mark_dec`` are not used.
        """
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, _, N = x_enc.shape

        # Embedding: e.g. x_enc [32, 96, 9] with x_mark_enc [32, 96, 4] -> [32, 13, 512];
        # the encoder preserves that shape.
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Project each token to pred_len values, put time on dim 1 and keep only
        # the N variable tokens (time-mark tokens are dropped).
        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))

        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Reconstruct the input window; returns [B, L, N] matching ``x_enc``.

        ``x_dec``, ``x_mark_dec`` and ``mask`` are not used by this model.
        """
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Keep only the N variable tokens.
        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def anomaly_detection(self, x_enc):
        """Reconstruct the input window without time marks; returns [B, L, N]."""
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        _, L, N = x_enc.shape

        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Keep only the N variable tokens (slice bound must be the local N unpacked above).
        dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, L, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        """Return class logits (batch_size, num_classes); ``x_mark_enc`` is not used."""
        # Embedding
        enc_out = self.enc_embedding(x_enc, None)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output.reshape(output.shape[0], -1)  # (batch_size, c_in * d_model)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``self.task_name``; returns None for an unknown task."""
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None

# --------------------------------------------------------------------------------
# /requirements.txt:
-------------------------------------------------------------------------------- 1 | causal-conv1d>=1.2.0 2 | mamba-ssm 3 | python-tsp 4 | einops 5 | matplotlib 6 | numpy 7 | pandas 8 | patool 9 | reformer-pytorch==1.4.4 10 | scikit-learn==1.2.2 11 | scipy==1.10.1 12 | sktime==0.16.1 13 | sympy 14 | torch 15 | tqdm 16 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from exp.exp_long_term_forecasting import Exp_Long_Term_Forecast 5 | from utils.print_args import print_args 6 | from utils.tools import set_seed 7 | 8 | if __name__ == '__main__': 9 | 10 | parser = argparse.ArgumentParser(description='MambaTS') 11 | 12 | # basic config 13 | parser.add_argument('--task_name', type=str, required=True, default='long_term_forecast', 14 | help='task name, options:[long_term_forecast, short_term_forecast]') 15 | parser.add_argument('--is_training', type=int, required=True, default=1, help='status') 16 | parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') 17 | parser.add_argument('--seed', type=int, default=3047, help='random seed') 18 | parser.add_argument('--model', type=str, required=True, default='MambaTS', 19 | help='model name, options: [Autoformer, Transformer, MambaTS]') 20 | 21 | # data loader 22 | parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type') 23 | parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file') 24 | parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file') 25 | parser.add_argument('--features', type=str, default='M', 26 | help='forecasting task, options:[M, S, MS]; ' 27 | 'M:multivariate predict multivariate, ' 28 | 'S:univariate predict univariate, ' 29 | 'MS:multivariate predict univariate') 30 | 
# run.py (continued): remaining CLI options; `parser` is created earlier in run.py.
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
                    help='freq for time features encoding, '
                         'options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], '
                         'you can also use more detailed freq like 15min or 3h')
parser.add_argument('--checkpoints', type=str,
                    default='./checkpoints/',
                    help='location of model checkpoints')
parser.add_argument('--visualization', type=str,
                    default='./test_results',
                    help='location of visualization outputs')
parser.add_argument('--results', type=str,
                    default='./results',
                    help='location of result files')

# forecasting task
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
parser.add_argument('--label_len', type=int, default=48, help='start token length')
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
parser.add_argument('--max_pred_len', type=int, default=-1, help='maximum prediction sequence length')
parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')
parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)

# model define
parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
parser.add_argument('--factor', type=int, default=1, help='attn factor')
parser.add_argument('--distil', action='store_false',
                    help='whether to use distilling in encoder, using this argument means not using distilling',
                    default=True)
parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
parser.add_argument('--embed', type=str, default='timeF',
                    help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')

# optimization
parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--loss', type=str, default='MSE', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

parser.add_argument('--optimizer', type=str, default='adam', help='optimizer')
parser.add_argument('--no_lradj', action='store_true')
parser.add_argument('--reg_weight', default=1e-5, type=float, help='regularization weight')
parser.add_argument('--momentum', default=0.99, type=float, help='momentum')
# update lr scheduler by iter
parser.add_argument('--lradj_by_iter', action='store_true')
parser.add_argument('--warmup_steps', default=0.1, type=float, help='warmup')
parser.add_argument('--iters_per_epoch', default=None, type=str,
                    help='iterations per epoch, given as a string expression')


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is wrong for argparse: ``bool('False')`` is True, so any
    non-empty value would enable the flag. Accepts the usual spellings and
    rejects anything else with an argparse usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# GPU
parser.add_argument('--use_gpu', type=_str2bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multiple gpus')

# de-stationary projector params
parser.add_argument('--p_hidden_dims', type=int, nargs='+', default=[128, 128],
                    help='hidden layer dimensions of projector (List)')
parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector')

# PatchTST
parser.add_argument('--patch_len', type=int, default=None, help='patch length for patch embedding')
parser.add_argument('--stride', type=int, default=None, help='stride for patch embedding')

# MambaTS
parser.add_argument('--VPT_mode', type=int, default=1,
                    help='variable permutation training mode, 0 for no use, 1 for default')
parser.add_argument('--ATSP_solver', type=str, default='SA', help='ATSP_solver',
                    choices=['Random', 'SA', 'GD', 'LS', 'LK'])
parser.add_argument('--use_casual_conv', action='store_true',
                    help='use causal convolution in the Mamba encoder', default=False)

args = parser.parse_args()
# Only use the GPU when one is actually available.
args.use_gpu = bool(torch.cuda.is_available() and args.use_gpu)
| if args.iters_per_epoch is not None: 122 | args.iters_per_epoch = eval(args.iters_per_epoch) 123 | 124 | if args.use_gpu and args.use_multi_gpu: 125 | args.devices = args.devices.replace(' ', '') 126 | device_ids = args.devices.split(',') 127 | args.device_ids = [int(id_) for id_ in device_ids] 128 | args.gpu = args.device_ids[0] 129 | 130 | set_seed(args.seed) 131 | 132 | print('Args in experiment:') 133 | print_args(args) 134 | 135 | if args.task_name == 'long_term_forecast': 136 | Exp = Exp_Long_Term_Forecast 137 | 138 | if args.is_training: 139 | for ii in range(args.itr): 140 | # setting record of experiments 141 | exp = Exp(args) # set experiments 142 | setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 143 | args.task_name, 144 | args.model_id, 145 | args.model, 146 | args.data, 147 | args.features, 148 | args.seq_len, 149 | args.label_len, 150 | args.pred_len, 151 | args.d_model, 152 | args.n_heads, 153 | args.e_layers, 154 | args.d_layers, 155 | args.d_ff, 156 | args.factor, 157 | args.embed, 158 | args.distil, 159 | args.des, ii) 160 | 161 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) 162 | exp.train(setting) 163 | 164 | if args.VPT_mode in [1]: 165 | print(f"resetting ids shuffle...") 166 | exp.model.reset_ids_shuffle() 167 | 168 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 169 | set_seed(args.seed) 170 | exp.test(setting) 171 | 172 | torch.cuda.empty_cache() 173 | else: 174 | ii = 0 175 | setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format( 176 | args.task_name, 177 | args.model_id, 178 | args.model, 179 | args.data, 180 | args.features, 181 | args.seq_len, 182 | args.label_len, 183 | args.pred_len, 184 | args.d_model, 185 | args.n_heads, 186 | args.e_layers, 187 | args.d_layers, 188 | args.d_ff, 189 | args.factor, 190 | args.embed, 191 | args.distil, 192 | args.des, ii) 193 | 194 | exp = 
Exp(args) # set experiments 195 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 196 | set_seed(args.seed) 197 | exp.test(setting, test=1) 198 | 199 | torch.cuda.empty_cache() 200 | -------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_COV.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/Covid-19 6 | data_path_name=Covid-19.csv 7 | model_id_name=Covid19 8 | data_name=custom 9 | seq_len=96 10 | 11 | 12 | for pred_len in 96 48 24 12 13 | do 14 | python -u run.py \ 15 | --task_name long_term_forecast \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --enc_in 948 \ 20 | --dec_in 948 \ 21 | --c_out 948 \ 22 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 23 | --model $model_name \ 24 | --data $data_name \ 25 | --features M \ 26 | --seq_len $seq_len \ 27 | --label_len 48 \ 28 | --pred_len $pred_len \ 29 | --e_layers 4 \ 30 | --d_layers 2 \ 31 | --factor 1 \ 32 | --des 'Exp' \ 33 | --itr 1 \ 34 | --n_heads 16 \ 35 | --d_model 128 \ 36 | --d_ff 512 \ 37 | --dropout 0.2 \ 38 | --patch_len 48 --stride 48 \ 39 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.001 --VPT_mode 1 --ATSP_solver SA 40 | done 41 | -------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_ECL.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/electricity 6 | data_path_name=electricity.csv 7 | model_id_name=electricity 8 | data_name=custom 9 | seq_len=720 10 | 11 | 12 | for pred_len in 96 192 336 720 13 | do 14 | python -u run.py \ 15 | --task_name long_term_forecast \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 
18 | --data_path $data_path_name \ 19 | --enc_in 321 \ 20 | --dec_in 321 \ 21 | --c_out 321 \ 22 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 23 | --model $model_name \ 24 | --data $data_name \ 25 | --features M \ 26 | --seq_len $seq_len \ 27 | --label_len 48 \ 28 | --pred_len $pred_len \ 29 | --e_layers 4 \ 30 | --d_layers 2 \ 31 | --factor 1 \ 32 | --des 'Exp' \ 33 | --itr 1 \ 34 | --n_heads 16 \ 35 | --d_model 128 \ 36 | --dropout 0.2 \ 37 | --patch_len 48 --stride 48 \ 38 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.0005 --VPT_mode 1 --ATSP_solver SA 39 | done 40 | -------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_ETTh2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/ETT-small 6 | data_path_name=ETTh2.csv 7 | model_id_name=ETTh2 8 | data_name=ETTh2 9 | seq_len=720 10 | 11 | 12 | pred_len=96 13 | python -u run.py \ 14 | --task_name long_term_forecast \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --enc_in 7 \ 19 | --dec_in 7 \ 20 | --c_out 7 \ 21 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 22 | --model $model_name \ 23 | --data $data_name \ 24 | --features M \ 25 | --seq_len $seq_len \ 26 | --label_len 48 \ 27 | --pred_len $pred_len \ 28 | --e_layers 5 \ 29 | --d_layers 2 \ 30 | --factor 1 \ 31 | --des 'Exp' \ 32 | --itr 1 \ 33 | --n_heads 16 \ 34 | --d_model 16 \ 35 | --dropout 0.3 \ 36 | --patch_len 48 --stride 48 --VPT_mode 1 --ATSP_solver SA \ 37 | --train_epochs 10 --patience 3 --batch_size 8 --learning_rate 0.001 38 | 39 | pred_len=192 40 | python -u run.py \ 41 | --task_name long_term_forecast \ 42 | --is_training 1 \ 43 | --root_path $root_path_name \ 44 | --data_path $data_path_name \ 45 | --enc_in 7 \ 46 | --dec_in 7 \ 47 | --c_out 7 \ 48 | --model_id 
$model_id_name'_'$seq_len'_'$pred_len \ 49 | --model $model_name \ 50 | --data $data_name \ 51 | --features M \ 52 | --seq_len $seq_len \ 53 | --label_len 48 \ 54 | --pred_len $pred_len \ 55 | --e_layers 3 \ 56 | --d_layers 2 \ 57 | --factor 1 \ 58 | --des 'Exp' \ 59 | --itr 1 \ 60 | --n_heads 16 \ 61 | --d_model 16 \ 62 | --dropout 0.3 \ 63 | --patch_len 48 --stride 48 --VPT_mode 1 --ATSP_solver SA \ 64 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.001 65 | 66 | pred_len=336 67 | python -u run.py \ 68 | --task_name long_term_forecast \ 69 | --is_training 1 \ 70 | --root_path $root_path_name \ 71 | --data_path $data_path_name \ 72 | --enc_in 7 \ 73 | --dec_in 7 \ 74 | --c_out 7 \ 75 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 76 | --model $model_name \ 77 | --data $data_name \ 78 | --features M \ 79 | --seq_len $seq_len \ 80 | --label_len 48 \ 81 | --pred_len $pred_len \ 82 | --e_layers 2 \ 83 | --d_layers 2 \ 84 | --factor 1 \ 85 | --des 'Exp' \ 86 | --itr 1 \ 87 | --n_heads 16 \ 88 | --d_model 32 \ 89 | --dropout 0.3 \ 90 | --patch_len 48 --stride 48 --VPT_mode 1 --ATSP_solver SA \ 91 | --train_epochs 10 --patience 3 --batch_size 32 --learning_rate 0.0005 92 | 93 | pred_len=720 94 | python -u run.py \ 95 | --task_name long_term_forecast \ 96 | --is_training 1 \ 97 | --root_path $root_path_name \ 98 | --data_path $data_path_name \ 99 | --enc_in 7 \ 100 | --dec_in 7 \ 101 | --c_out 7 \ 102 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 103 | --model $model_name \ 104 | --data $data_name \ 105 | --features M \ 106 | --seq_len $seq_len \ 107 | --label_len 48 \ 108 | --pred_len $pred_len \ 109 | --e_layers 1 \ 110 | --d_layers 2 \ 111 | --factor 1 \ 112 | --des 'Exp' \ 113 | --itr 1 \ 114 | --n_heads 16 \ 115 | --d_model 128 \ 116 | --dropout 0.3 \ 117 | --patch_len 48 --stride 48 --VPT_mode 1 --ATSP_solver SA \ 118 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.0005 119 | 
-------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_ETTm2.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/ETT-small 6 | data_path_name=ETTm2.csv 7 | model_id_name=ETTm2 8 | data_name=ETTm2 9 | seq_len=720 10 | 11 | 12 | for pred_len in 96 192 336 720 13 | do 14 | python -u run.py \ 15 | --task_name long_term_forecast \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --enc_in 7 \ 20 | --dec_in 7 \ 21 | --c_out 7 \ 22 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 23 | --model $model_name \ 24 | --data $data_name \ 25 | --features M \ 26 | --seq_len $seq_len \ 27 | --label_len 48 \ 28 | --pred_len $pred_len \ 29 | --e_layers 3 \ 30 | --d_layers 2 \ 31 | --factor 1 \ 32 | --des 'Exp' \ 33 | --itr 1 \ 34 | --n_heads 16 \ 35 | --d_model 16 \ 36 | --dropout 0.3 \ 37 | --patch_len 48 --stride 48 --VPT_mode 1 --ATSP_solver SA \ 38 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.001 39 | done 40 | -------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_PEMS.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/PEMS 6 | data_path_name=PEMS03.npz 7 | model_id_name=PEMS 8 | data_name=PEMS 9 | seq_len=720 10 | 11 | 12 | model_name=MambaTSv3 13 | for pred_len in 96 48 24 12 14 | do 15 | python -u run.py \ 16 | --task_name long_term_forecast \ 17 | --is_training 1 \ 18 | --root_path $root_path_name \ 19 | --data_path $data_path_name \ 20 | --enc_in 358 \ 21 | --dec_in 358 \ 22 | --c_out 358 \ 23 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 24 | --model $model_name \ 25 | --data $data_name \ 26 | --features M \ 27 | 
--seq_len $seq_len \ 28 | --label_len 48 \ 29 | --pred_len $pred_len \ 30 | --e_layers 6 \ 31 | --d_layers 2 \ 32 | --factor 1 \ 33 | --des 'Exp' \ 34 | --itr 1 \ 35 | --n_heads 16 \ 36 | --d_model 128 \ 37 | --d_ff 512 \ 38 | --dropout 0.2 \ 39 | --patch_len 48 --stride 48 \ 40 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.001 --VPT_mode 1 --ATSP_solver SA 41 | done 42 | -------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_SOL.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/Solar 6 | data_path_name=solar_AL.txt 7 | model_id_name=Solar 8 | data_name=Solar 9 | seq_len=720 10 | 11 | 12 | for pred_len in 96 192 336 720 13 | do 14 | python -u run.py \ 15 | --task_name long_term_forecast \ 16 | --is_training 1 \ 17 | --root_path $root_path_name \ 18 | --data_path $data_path_name \ 19 | --enc_in 137 \ 20 | --dec_in 137 \ 21 | --c_out 137 \ 22 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 23 | --model $model_name \ 24 | --data $data_name \ 25 | --features M \ 26 | --seq_len $seq_len \ 27 | --label_len 48 \ 28 | --pred_len $pred_len \ 29 | --e_layers 3 \ 30 | --d_layers 2 \ 31 | --factor 1 \ 32 | --des 'Exp' \ 33 | --itr 1 \ 34 | --n_heads 16 \ 35 | --d_model 32 \ 36 | --dropout 0.3 \ 37 | --patch_len 48 --stride 48 --VPT_mode 1 --ATSP_solver SA \ 38 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.0005 39 | done 40 | -------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_TFF.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/traffic 6 | data_path_name=traffic.csv 7 | model_id_name=traffic 8 | data_name=custom 9 | seq_len=720 10 | 11 | 12 | for 
pred_len in 96 192 336 720 13 | do 14 | python -u run.py \ 15 | --task_name long_term_forecast \ 16 | --is_training 1 \ 17 | --num_workers 0 \ 18 | --root_path $root_path_name \ 19 | --data_path $data_path_name \ 20 | --enc_in 862 \ 21 | --dec_in 862 \ 22 | --c_out 862 \ 23 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 24 | --model $model_name \ 25 | --data $data_name \ 26 | --features M \ 27 | --seq_len $seq_len \ 28 | --label_len 48 \ 29 | --pred_len $pred_len \ 30 | --e_layers 4 \ 31 | --d_layers 2 \ 32 | --factor 1 \ 33 | --des 'Exp' \ 34 | --itr 1 \ 35 | --n_heads 16 \ 36 | --d_model 512 \ 37 | --d_ff 512 \ 38 | --dropout 0.2 \ 39 | --patch_len $seq_len --stride $seq_len \ 40 | --train_epochs 10 --patience 3 --batch_size 8 --learning_rate 0.0005 --VPT_mode 1 --ATSP_solver SA 41 | done 42 | -------------------------------------------------------------------------------- /scripts/long_term_forecast/MambaTS_WTH.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | model_name=MambaTS 4 | 5 | root_path_name=./dataset/weather 6 | data_path_name=weather.csv 7 | model_id_name=weather 8 | data_name=custom 9 | seq_len=720 10 | 11 | 12 | pred_len=96 13 | python -u run.py \ 14 | --task_name long_term_forecast \ 15 | --is_training 1 \ 16 | --root_path $root_path_name \ 17 | --data_path $data_path_name \ 18 | --enc_in 21 \ 19 | --dec_in 21 \ 20 | --c_out 21 \ 21 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 22 | --model $model_name \ 23 | --data $data_name \ 24 | --features M \ 25 | --seq_len $seq_len \ 26 | --label_len 48 \ 27 | --pred_len $pred_len \ 28 | --e_layers 3 \ 29 | --d_layers 2 \ 30 | --factor 1 \ 31 | --des 'Exp' \ 32 | --itr 1 \ 33 | --n_heads 16 \ 34 | --d_model 32 \ 35 | --d_ff 512 \ 36 | --dropout 0.2 \ 37 | --patch_len 48 --stride 48 \ 38 | --train_epochs 10 --patience 3 --batch_size 8 --learning_rate 0.0005 --VPT_mode 1 --ATSP_solver SA 39 | 40 | pred_len=192 41 | 
python -u run.py \ 42 | --task_name long_term_forecast \ 43 | --is_training 1 \ 44 | --root_path $root_path_name \ 45 | --data_path $data_path_name \ 46 | --enc_in 21 \ 47 | --dec_in 21 \ 48 | --c_out 21 \ 49 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 50 | --model $model_name \ 51 | --data $data_name \ 52 | --features M \ 53 | --seq_len $seq_len \ 54 | --label_len 48 \ 55 | --pred_len $pred_len \ 56 | --e_layers 2 \ 57 | --d_layers 2 \ 58 | --factor 1 \ 59 | --des 'Exp' \ 60 | --itr 1 \ 61 | --n_heads 16 \ 62 | --d_model 32 \ 63 | --d_ff 512 \ 64 | --dropout 0.3 \ 65 | --patch_len 48 --stride 48 \ 66 | --train_epochs 10 --patience 3 --batch_size 16 --learning_rate 0.001 --VPT_mode 1 --ATSP_solver SA 67 | 68 | pred_len=336 69 | python -u run.py \ 70 | --task_name long_term_forecast \ 71 | --is_training 1 \ 72 | --root_path $root_path_name \ 73 | --data_path $data_path_name \ 74 | --enc_in 21 \ 75 | --dec_in 21 \ 76 | --c_out 21 \ 77 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 78 | --model $model_name \ 79 | --data $data_name \ 80 | --features M \ 81 | --seq_len $seq_len \ 82 | --label_len 48 \ 83 | --pred_len $pred_len \ 84 | --e_layers 2 \ 85 | --d_layers 2 \ 86 | --factor 1 \ 87 | --des 'Exp' \ 88 | --itr 1 \ 89 | --n_heads 16 \ 90 | --d_model 32 \ 91 | --d_ff 512 \ 92 | --dropout 0.3 \ 93 | --patch_len 48 --stride 48 \ 94 | --train_epochs 10 --patience 3 --batch_size 8 --learning_rate 0.001 --VPT_mode 1 --ATSP_solver SA 95 | 96 | pred_len=720 97 | python -u run.py \ 98 | --task_name long_term_forecast \ 99 | --is_training 1 \ 100 | --root_path $root_path_name \ 101 | --data_path $data_path_name \ 102 | --enc_in 21 \ 103 | --dec_in 21 \ 104 | --c_out 21 \ 105 | --model_id $model_id_name'_'$seq_len'_'$pred_len \ 106 | --model $model_name \ 107 | --data $data_name \ 108 | --features M \ 109 | --seq_len $seq_len \ 110 | --label_len 48 \ 111 | --pred_len $pred_len \ 112 | --e_layers 2 \ 113 | --d_layers 2 \ 114 | --factor 1 \ 115 | --des 'Exp' \ 116 
| --itr 1 \ 117 | --n_heads 16 \ 118 | --d_model 32 \ 119 | --d_ff 512 \ 120 | --dropout 0.3 \ 121 | --patch_len 48 --stride 48 \ 122 | --train_epochs 10 --patience 3 --batch_size 8 --learning_rate 0.001 --VPT_mode 1 --ATSP_solver SA 123 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiudingCai/MambaTS-pytorch/3b6797d9bf5178a490e1bb3c3e9f4d223f2af51d/utils/__init__.py -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright © 2020 Element AI Inc. All rights reserved. 14 | 15 | """ 16 | Loss functions for PyTorch. 
17 | """ 18 | 19 | import torch as t 20 | import torch.nn as nn 21 | import numpy as np 22 | import pdb 23 | import torch 24 | from scipy.special import comb 25 | 26 | 27 | def divide_no_nan(a, b): 28 | """ 29 | a/b where the resulted NaN or Inf are replaced by 0. 30 | """ 31 | result = a / b 32 | result[result != result] = .0 33 | result[result == np.inf] = .0 34 | return result 35 | 36 | 37 | class mape_loss(nn.Module): 38 | def __init__(self): 39 | super(mape_loss, self).__init__() 40 | 41 | def forward(self, insample: t.Tensor, freq: int, 42 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 43 | """ 44 | MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error 45 | 46 | :param forecast: Forecast values. Shape: batch, time 47 | :param target: Target values. Shape: batch, time 48 | :param mask: 0/1 mask. Shape: batch, time 49 | :return: Loss value 50 | """ 51 | weights = divide_no_nan(mask, target) 52 | return t.mean(t.abs((forecast - target) * weights)) 53 | 54 | 55 | class smape_loss(nn.Module): 56 | def __init__(self): 57 | super(smape_loss, self).__init__() 58 | 59 | def forward(self, insample: t.Tensor, freq: int, 60 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 61 | """ 62 | sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993) 63 | 64 | :param forecast: Forecast values. Shape: batch, time 65 | :param target: Target values. Shape: batch, time 66 | :param mask: 0/1 mask. 
Shape: batch, time 67 | :return: Loss value 68 | """ 69 | return 200 * t.mean(divide_no_nan(t.abs(forecast - target), 70 | t.abs(forecast.data) + t.abs(target.data)) * mask) 71 | 72 | 73 | class mase_loss(nn.Module): 74 | def __init__(self): 75 | super(mase_loss, self).__init__() 76 | 77 | def forward(self, insample: t.Tensor, freq: int, 78 | forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float: 79 | """ 80 | MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf 81 | 82 | :param insample: Insample values. Shape: batch, time_i 83 | :param freq: Frequency value 84 | :param forecast: Forecast values. Shape: batch, time_o 85 | :param target: Target values. Shape: batch, time_o 86 | :param mask: 0/1 mask. Shape: batch, time_o 87 | :return: Loss value 88 | """ 89 | masep = t.mean(t.abs(insample[:, freq:] - insample[:, :-freq]), dim=1) 90 | masked_masep_inv = divide_no_nan(mask, masep[:, None]) 91 | return t.mean(t.abs(target - forecast) * masked_masep_inv) 92 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | import math 13 | import warnings 14 | from typing import List 15 | 16 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 17 | from torch import nn as nn 18 | from torch.optim import Adam, Optimizer 19 | from torch.optim.lr_scheduler import _LRScheduler 20 | 21 | __all__ = ["LinearLR", "ExponentialLR"] 22 | 23 | 24 | class _LRSchedulerMONAI(_LRScheduler): 25 | """Base class for increasing the learning rate between two boundaries over a number 26 | of iterations""" 27 | 28 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: 29 | """ 30 | Args: 31 | optimizer: wrapped optimizer. 32 | end_lr: the final learning rate. 33 | num_iter: the number of iterations over which the test occurs. 34 | last_epoch: the index of last epoch. 35 | Returns: 36 | None 37 | """ 38 | self.end_lr = end_lr 39 | self.num_iter = num_iter 40 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) 41 | 42 | 43 | class LinearLR(_LRSchedulerMONAI): 44 | """Linearly increases the learning rate between two boundaries over a number of 45 | iterations. 46 | """ 47 | 48 | def get_lr(self): 49 | r = self.last_epoch / (self.num_iter - 1) 50 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 51 | 52 | 53 | class ExponentialLR(_LRSchedulerMONAI): 54 | """Exponentially increases the learning rate between two boundaries over a number of 55 | iterations. 56 | """ 57 | 58 | def get_lr(self): 59 | r = self.last_epoch / (self.num_iter - 1) 60 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 61 | 62 | 63 | class WarmupCosineSchedule(LambdaLR): 64 | """Linear warmup and then cosine decay. 65 | Based on https://huggingface.co/ implementation. 66 | """ 67 | 68 | def __init__( 69 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 70 | ) -> None: 71 | """ 72 | Args: 73 | optimizer: wrapped optimizer. 
74 | warmup_steps: number of warmup iterations. 75 | t_total: total number of training iterations. 76 | cycles: cosine cycles parameter. 77 | last_epoch: the index of last epoch. 78 | Returns: 79 | None 80 | """ 81 | self.warmup_steps = warmup_steps 82 | self.t_total = t_total 83 | self.cycles = cycles 84 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) 85 | 86 | def lr_lambda(self, step): 87 | if step < self.warmup_steps: 88 | return float(step) / float(max(1.0, self.warmup_steps)) 89 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 90 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 91 | 92 | 93 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 94 | 95 | def __init__( 96 | self, 97 | optimizer: Optimizer, 98 | warmup_epochs: int, 99 | max_epochs: int, 100 | warmup_start_lr: float = 0.0, 101 | eta_min: float = 0.0, 102 | last_epoch: int = -1, 103 | ) -> None: 104 | """ 105 | Args: 106 | optimizer (Optimizer): Wrapped optimizer. 107 | warmup_epochs (int): Maximum number of iterations for linear warmup 108 | max_epochs (int): Maximum number of iterations 109 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 110 | eta_min (float): Minimum learning rate. Default: 0. 111 | last_epoch (int): The index of last epoch. Default: -1. 
112 | """ 113 | self.warmup_epochs = warmup_epochs 114 | self.max_epochs = max_epochs 115 | self.warmup_start_lr = warmup_start_lr 116 | self.eta_min = eta_min 117 | 118 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 119 | 120 | def get_lr(self) -> List[float]: 121 | """ 122 | Compute learning rate using chainable form of the scheduler 123 | """ 124 | if not self._get_lr_called_within_step: 125 | warnings.warn( 126 | "To get the last learning rate computed by the scheduler, " 127 | "please use `get_last_lr()`.", 128 | UserWarning, 129 | ) 130 | 131 | if self.last_epoch == 0: 132 | return [self.warmup_start_lr] * len(self.base_lrs) 133 | elif self.last_epoch < self.warmup_epochs: 134 | return [ 135 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 136 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 137 | ] 138 | elif self.last_epoch == self.warmup_epochs: 139 | return self.base_lrs 140 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 141 | return [ 142 | group["lr"] + (base_lr - self.eta_min) * 143 | (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 144 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 145 | ] 146 | 147 | return [ 148 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) / 149 | ( 150 | 1 + 151 | math.cos( 152 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)) 153 | ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups 154 | ] 155 | 156 | def _get_closed_form_lr(self) -> List[float]: 157 | """ 158 | Called when epoch is passed as a param to the `step` function of the scheduler. 
159 | """ 160 | if self.last_epoch < self.warmup_epochs: 161 | return [ 162 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 163 | for base_lr in self.base_lrs 164 | ] 165 | 166 | return [ 167 | self.eta_min + 0.5 * (base_lr - self.eta_min) * 168 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 169 | for base_lr in self.base_lrs 170 | ] 171 | -------------------------------------------------------------------------------- /utils/m4_summary.py: -------------------------------------------------------------------------------- 1 | # This source code is provided for the purposes of scientific reproducibility 2 | # under the following limited license from Element AI Inc. The code is an 3 | # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis 4 | # expansion analysis for interpretable time series forecasting, 5 | # https://arxiv.org/abs/1905.10437). The copyright to the source code is 6 | # licensed under the Creative Commons - Attribution-NonCommercial 4.0 7 | # International license (CC BY-NC 4.0): 8 | # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether 9 | # for the benefit of third parties or internally in production) requires an 10 | # explicit license. The subject-matter of the N-BEATS model and associated 11 | # materials are the property of Element AI Inc. and may be subject to patent 12 | # protection. No license to patents is granted hereunder (whether express or 13 | # implied). Copyright 2020 Element AI Inc. All rights reserved. 
14 | 15 | """ 16 | M4 Summary 17 | """ 18 | from collections import OrderedDict 19 | 20 | import numpy as np 21 | import pandas as pd 22 | 23 | from data_provider.m4 import M4Dataset 24 | from data_provider.m4 import M4Meta 25 | import os 26 | 27 | 28 | def group_values(values, groups, group_name): 29 | return np.array([v[~np.isnan(v)] for v in values[groups == group_name]]) 30 | 31 | 32 | def mase(forecast, insample, outsample, frequency): 33 | return np.mean(np.abs(forecast - outsample)) / np.mean(np.abs(insample[:-frequency] - insample[frequency:])) 34 | 35 | 36 | def smape_2(forecast, target): 37 | denom = np.abs(target) + np.abs(forecast) 38 | # divide by 1.0 instead of 0.0, in case when denom is zero the enumerator will be 0.0 anyway. 39 | denom[denom == 0.0] = 1.0 40 | return 200 * np.abs(forecast - target) / denom 41 | 42 | 43 | def mape(forecast, target): 44 | denom = np.abs(target) 45 | # divide by 1.0 instead of 0.0, in case when denom is zero the enumerator will be 0.0 anyway. 46 | denom[denom == 0.0] = 1.0 47 | return 100 * np.abs(forecast - target) / denom 48 | 49 | 50 | class M4Summary: 51 | def __init__(self, file_path, root_path): 52 | self.file_path = file_path 53 | self.training_set = M4Dataset.load(training=True, dataset_file=root_path) 54 | self.test_set = M4Dataset.load(training=False, dataset_file=root_path) 55 | self.naive_path = os.path.join(root_path, 'submission-Naive2.csv') 56 | 57 | def evaluate(self): 58 | """ 59 | Evaluate forecasts using M4 test dataset. 60 | 61 | :param forecast: Forecasts. Shape: timeseries, time. 62 | :return: sMAPE and OWA grouped by seasonal patterns. 
63 | """ 64 | grouped_owa = OrderedDict() 65 | 66 | naive2_forecasts = pd.read_csv(self.naive_path).values[:, 1:].astype(np.float32) 67 | naive2_forecasts = np.array([v[~np.isnan(v)] for v in naive2_forecasts]) 68 | 69 | model_mases = {} 70 | naive2_smapes = {} 71 | naive2_mases = {} 72 | grouped_smapes = {} 73 | grouped_mapes = {} 74 | for group_name in M4Meta.seasonal_patterns: 75 | file_name = self.file_path + group_name + "_forecast.csv" 76 | if os.path.exists(file_name): 77 | model_forecast = pd.read_csv(file_name).values 78 | 79 | naive2_forecast = group_values(naive2_forecasts, self.test_set.groups, group_name) 80 | target = group_values(self.test_set.values, self.test_set.groups, group_name) 81 | # all timeseries within group have same frequency 82 | frequency = self.training_set.frequencies[self.test_set.groups == group_name][0] 83 | insample = group_values(self.training_set.values, self.test_set.groups, group_name) 84 | 85 | model_mases[group_name] = np.mean([mase(forecast=model_forecast[i], 86 | insample=insample[i], 87 | outsample=target[i], 88 | frequency=frequency) for i in range(len(model_forecast))]) 89 | naive2_mases[group_name] = np.mean([mase(forecast=naive2_forecast[i], 90 | insample=insample[i], 91 | outsample=target[i], 92 | frequency=frequency) for i in range(len(model_forecast))]) 93 | 94 | naive2_smapes[group_name] = np.mean(smape_2(naive2_forecast, target)) 95 | grouped_smapes[group_name] = np.mean(smape_2(forecast=model_forecast, target=target)) 96 | grouped_mapes[group_name] = np.mean(mape(forecast=model_forecast, target=target)) 97 | 98 | grouped_smapes = self.summarize_groups(grouped_smapes) 99 | grouped_mapes = self.summarize_groups(grouped_mapes) 100 | grouped_model_mases = self.summarize_groups(model_mases) 101 | grouped_naive2_smapes = self.summarize_groups(naive2_smapes) 102 | grouped_naive2_mases = self.summarize_groups(naive2_mases) 103 | for k in grouped_model_mases.keys(): 104 | grouped_owa[k] = (grouped_model_mases[k] / 
grouped_naive2_mases[k] + 105 | grouped_smapes[k] / grouped_naive2_smapes[k]) / 2 106 | 107 | def round_all(d): 108 | return dict(map(lambda kv: (kv[0], np.round(kv[1], 3)), d.items())) 109 | 110 | return round_all(grouped_smapes), round_all(grouped_owa), round_all(grouped_mapes), round_all( 111 | grouped_model_mases) 112 | 113 | def summarize_groups(self, scores): 114 | """ 115 | Re-group scores respecting M4 rules. 116 | :param scores: Scores per group. 117 | :return: Grouped scores. 118 | """ 119 | scores_summary = OrderedDict() 120 | 121 | def group_count(group_name): 122 | return len(np.where(self.test_set.groups == group_name)[0]) 123 | 124 | weighted_score = {} 125 | for g in ['Yearly', 'Quarterly', 'Monthly']: 126 | weighted_score[g] = scores[g] * group_count(g) 127 | scores_summary[g] = scores[g] 128 | 129 | others_score = 0 130 | others_count = 0 131 | for g in ['Weekly', 'Daily', 'Hourly']: 132 | others_score += scores[g] * group_count(g) 133 | others_count += group_count(g) 134 | weighted_score['Others'] = others_score 135 | scores_summary['Others'] = others_score / others_count 136 | 137 | average = np.sum(list(weighted_score.values())) / len(self.test_set.groups) 138 | scores_summary['Average'] = average 139 | 140 | return scores_summary 141 | -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TriangularCausalMask(): 5 | def __init__(self, B, L, device="cpu"): 6 | mask_shape = [B, 1, L, L] 7 | with torch.no_grad(): 8 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 9 | 10 | @property 11 | def mask(self): 12 | return self._mask 13 | 14 | 15 | class ProbMask(): 16 | def __init__(self, B, H, L, index, scores, device="cpu"): 17 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 18 | _mask_ex = _mask[None, None, :].expand(B, 
H, L, scores.shape[-1]) 19 | indicator = _mask_ex[torch.arange(B)[:, None, None], 20 | torch.arange(H)[None, :, None], 21 | index, :].to(device) 22 | self._mask = indicator.view(scores.shape).to(device) 23 | 24 | @property 25 | def mask(self): 26 | return self._mask 27 | 28 | 29 | def random_shuffle(x, ids_shuffle=None, return_ids_shuffle=False, mask_ratio=0): 30 | """ 31 | Perform per-sample random masking by per-sample shuffling. 32 | Per-sample shuffling is done by argsort random noise. 33 | x: [N, L, D], sequence 34 | """ 35 | N, L, D = x.shape # batch, length, dim 36 | len_keep = int(L * (1 - mask_ratio)) 37 | 38 | if ids_shuffle is None: 39 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 40 | 41 | # sort noise for each sample 42 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 43 | 44 | # keep the first subset 45 | ids_shuffle_ = ids_shuffle[:, :len_keep] 46 | 47 | if ids_shuffle_.shape[0] != x.shape[0]: 48 | ids_shuffle_ = ids_shuffle_[[0]].repeat(x.shape[0], 1) 49 | 50 | ids_restore = torch.argsort(ids_shuffle_, dim=1) 51 | 52 | x_shuffled = torch.gather(x, dim=1, index=ids_shuffle_.unsqueeze(-1).repeat(1, 1, D)) 53 | 54 | if return_ids_shuffle: 55 | return x_shuffled, ids_shuffle, ids_restore 56 | else: 57 | return x_shuffled, ids_restore 58 | 59 | 60 | def unshuffle(x, ids_restore): 61 | x_restore = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1])) 62 | return x_restore 63 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def RSE(pred, true): 5 | return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2)) 6 | 7 | 8 | def CORR(pred, true): 9 | u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0) 10 | d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0)) 
11 | return (u / d).mean(-1) 12 | 13 | 14 | def MAE(pred, true): 15 | return np.mean(np.abs(pred - true)) 16 | 17 | 18 | def MSE(pred, true): 19 | return np.mean((pred - true) ** 2) 20 | 21 | 22 | def RMSE(pred, true): 23 | return np.sqrt(MSE(pred, true)) 24 | 25 | 26 | 27 | def metric_with_mask(pred, true): 28 | # mask 29 | mask = true != 0 30 | pred = pred[mask] 31 | true = true[mask] 32 | # compute 33 | mae = MAE(pred, true) 34 | mse = MSE(pred, true) 35 | rmse = RMSE(pred, true) 36 | mape = MAPE(pred, true) 37 | mspe = MSPE(pred, true) 38 | 39 | return mae, mse, rmse, mape, mspe 40 | 41 | def MAPE(pred, true): 42 | return np.mean(np.abs((pred - true) / true)) 43 | 44 | 45 | def MSPE(pred, true): 46 | return np.mean(np.square((pred - true) / true)) 47 | 48 | 49 | def metric(pred, true): 50 | mae = MAE(pred, true) 51 | mse = MSE(pred, true) 52 | rmse = RMSE(pred, true) 53 | mape = MAPE(pred, true) 54 | mspe = MSPE(pred, true) 55 | 56 | return mae, mse, rmse, mape, mspe 57 | -------------------------------------------------------------------------------- /utils/optim_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_optimizer(optim_name, model, args): 5 | optim_name = optim_name.lower() 6 | model_weights = filter(lambda p: p.requires_grad, model.parameters()) 7 | 8 | if optim_name == 'adam': 9 | optimizer = torch.optim.Adam(model_weights, lr=args.learning_rate, weight_decay=args.reg_weight) 10 | elif optim_name == 'adamw': 11 | optimizer = torch.optim.AdamW(model_weights, lr=args.learning_rate, weight_decay=args.reg_weight) 12 | elif optim_name == 'sgd': 13 | optimizer = torch.optim.SGD(model_weights, lr=args.learning_rate, momentum=args.momentum, 14 | nesterov=True, weight_decay=args.reg_weight) 15 | elif optim_name == 'lion': 16 | from lion_pytorch import Lion 17 | optimizer = Lion(model_weights, lr=args.learning_rate, weight_decay=args.reg_weight) 18 | elif optim_name == 'adabound': 
19 | import adabound 20 | optimizer = adabound.AdaBound(model_weights, lr=args.learning_rate, final_lr=0.1) 21 | elif optim_name == 'prodigy': 22 | from prodigyopt import Prodigy 23 | # you can choose weight decay value based on your problem, 0 by default 24 | optimizer = Prodigy(model_weights, lr=1., weight_decay=args.reg_weight) 25 | elif optim_name == 'd_adam': 26 | from dadaptation import DAdaptAdam 27 | optimizer = DAdaptAdam(model_weights, lr=1., ) 28 | elif optim_name == 'd_lion': 29 | from dadaptation import DAdaptLion 30 | optimizer = DAdaptLion(model_weights, lr=1., ) 31 | else: 32 | raise ValueError('Unsupported Optimization Procedure: ' + str(optim_name)) 33 | 34 | return optimizer 35 | 36 | 37 | def get_lr_scheduler(optimizer, total_steps, args): 38 | if args.lr_schedule == 'warmup_cosine': 39 | from utils.lr_scheduler import LinearWarmupCosineAnnealingLR 40 | scheduler = LinearWarmupCosineAnnealingLR(optimizer, 41 | warmup_epochs=args.warmup_steps, 42 | max_epochs=total_steps) 43 | elif args.lr_schedule == 'cosine_anneal': 44 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps) 45 | # if args.checkpoint is not None: 46 | # scheduler.step(epoch=start_epoch) 47 | elif args.lr_schedule == 'poly': 48 | def lambdas(epoch): 49 | return (1 - float(epoch) / float(total_steps)) ** 0.9 50 | 51 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambdas) 52 | else: 53 | raise NotImplementedError 54 | return scheduler 55 | -------------------------------------------------------------------------------- /utils/print_args.py: -------------------------------------------------------------------------------- 1 | def print_args(args): 2 | print("\033[1m" + "Basic Config" + "\033[0m") 3 | print(f' {"Task Name:":<20}{args.task_name:<20}{"Is Training:":<20}{args.is_training:<20}') 4 | print(f' {"Model ID:":<20}{args.model_id:<20}{"Model:":<20}{args.model:<20}') 5 | print() 6 | 7 | print("\033[1m" + "Data Loader" + 
"\033[0m") 8 | print(f' {"Data:":<20}{args.data:<20}{"Root Path:":<20}{args.root_path:<20}') 9 | print(f' {"Data Path:":<20}{args.data_path:<20}{"Features:":<20}{args.features:<20}') 10 | print(f' {"Target:":<20}{args.target:<20}{"Freq:":<20}{args.freq:<20}') 11 | print(f' {"Checkpoints:":<20}{args.checkpoints:<20}') 12 | print() 13 | 14 | if args.task_name in ['long_term_forecast', 'short_term_forecast']: 15 | print("\033[1m" + "Forecasting Task" + "\033[0m") 16 | print(f' {"Seq Len:":<20}{args.seq_len:<20}{"Label Len:":<20}{args.label_len:<20}') 17 | print(f' {"Pred Len:":<20}{args.pred_len:<20}{"Seasonal Patterns:":<20}{args.seasonal_patterns:<20}') 18 | print(f' {"Inverse:":<20}{args.inverse:<20}') 19 | print() 20 | 21 | if args.task_name == 'imputation': 22 | print("\033[1m" + "Imputation Task" + "\033[0m") 23 | print(f' {"Mask Rate:":<20}{args.mask_rate:<20}') 24 | print() 25 | 26 | if args.task_name == 'anomaly_detection': 27 | print("\033[1m" + "Anomaly Detection Task" + "\033[0m") 28 | print(f' {"Anomaly Ratio:":<20}{args.anomaly_ratio:<20}') 29 | print() 30 | 31 | print("\033[1m" + "Model Parameters" + "\033[0m") 32 | print(f' {"Top k:":<20}{args.top_k:<20}{"Num Kernels:":<20}{args.num_kernels:<20}') 33 | print(f' {"Enc In:":<20}{args.enc_in:<20}{"Dec In:":<20}{args.dec_in:<20}') 34 | print(f' {"C Out:":<20}{args.c_out:<20}{"d model:":<20}{args.d_model:<20}') 35 | print(f' {"n heads:":<20}{args.n_heads:<20}{"e layers:":<20}{args.e_layers:<20}') 36 | print(f' {"d layers:":<20}{args.d_layers:<20}{"d FF:":<20}{args.d_ff:<20}') 37 | print(f' {"Moving Avg:":<20}{args.moving_avg:<20}{"Factor:":<20}{args.factor:<20}') 38 | print(f' {"Distil:":<20}{args.distil:<20}{"Dropout:":<20}{args.dropout:<20}') 39 | print(f' {"Embed:":<20}{args.embed:<20}{"Activation:":<20}{args.activation:<20}') 40 | print(f' {"Output Attention:":<20}{args.output_attention:<20}') 41 | print() 42 | 43 | print("\033[1m" + "Run Parameters" + "\033[0m") 44 | print(f' {"Num 
Workers:":<20}{args.num_workers:<20}{"Itr:":<20}{args.itr:<20}') 45 | print(f' {"Train Epochs:":<20}{args.train_epochs:<20}{"Batch Size:":<20}{args.batch_size:<20}') 46 | print(f' {"Patience:":<20}{args.patience:<20}{"Learning Rate:":<20}{args.learning_rate:<20}') 47 | print(f' {"Des:":<20}{args.des:<20}{"Loss:":<20}{args.loss:<20}') 48 | print(f' {"Lradj:":<20}{args.lradj:<20}{"Use Amp:":<20}{args.use_amp:<20}') 49 | print() 50 | 51 | print("\033[1m" + "GPU" + "\033[0m") 52 | print(f' {"Use GPU:":<20}{args.use_gpu:<20}{"GPU:":<20}{args.gpu:<20}') 53 | print(f' {"Use Multi GPU:":<20}{args.use_multi_gpu:<20}{"Devices:":<20}{args.devices:<20}') 54 | print() 55 | 56 | print("\033[1m" + "De-stationary Projector Params" + "\033[0m") 57 | p_hidden_dims_str = ', '.join(map(str, args.p_hidden_dims)) 58 | print(f' {"P Hidden Dims:":<20}{p_hidden_dims_str:<20}{"P Hidden Layers:":<20}{args.p_hidden_layers:<20}') 59 | print() 60 | -------------------------------------------------------------------------------- /utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | # From: gluonts/src/gluonts/time_feature/_base.py 2 | # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"). 5 | # You may not use this file except in compliance with the License. 6 | # A copy of the License is located at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # or in the "license" file accompanying this file. This file is distributed 11 | # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 12 | # express or implied. See the License for the specific language governing 13 | # permissions and limitations under the License. 
from typing import List

import numpy as np
import pandas as pd
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset


class TimeFeature:
    """Base class for calendar features computed from a ``pd.DatetimeIndex``.

    Subclasses implement ``__call__`` to return one value per timestamp,
    encoded as a value between [-0.5, 0.5].
    """

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        pass

    def __repr__(self):
        return self.__class__.__name__ + "()"


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.second / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.minute / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Day of week (Monday=0 ... Sunday=6) encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.dayofweek / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.day - 1) / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.dayofyear - 1) / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.month - 1) / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """ISO week of year (1..53) encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # ``isocalendar().week`` is a pandas Series (nullable integer dtype)
        # indexed by the timestamps. Convert it to a plain float ndarray so the
        # result matches the annotation and stacks cleanly with the other
        # features. Missing timestamps (NaT) become NaN instead of raising.
        week = index.isocalendar().week.to_numpy(dtype=np.float64, na_value=np.nan)
        return (week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
        Anything accepted by ``pandas.tseries.frequencies.to_offset`` works.

    Returns
    -------
    A fresh list of feature instances. The offset is matched against the
    table below with ``isinstance`` in insertion order, and the first match wins.

    Raises
    ------
    RuntimeError
        If the parsed offset type is not in the table.
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)


def time_features(dates, freq='h'):
    """Compute every time feature appropriate for ``freq`` on ``dates``.

    Args:
        dates: a ``pd.DatetimeIndex`` of timestamps.
        freq: frequency string (or offset) passed to
            ``time_features_from_frequency_str``.

    Returns:
        Array of shape ``(n_features, len(dates))``, with one row per feature.
        Frequencies that have no features (e.g. yearly) yield an empty
        ``(0, len(dates))`` array. ``np.vstack`` would raise on an empty list.
    """
    features = [feat(dates) for feat in time_features_from_frequency_str(freq)]
    if not features:
        return np.empty((0, len(dates)))
    return np.vstack(features)

# --------------------------------------------------------------------------------
# /utils/tools.py:
# --------------------------------------------------------------------------------
import os
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
import math
from random import shuffle, choice
import subprocess

# Non-interactive backend so figures can be saved without a display.
plt.switch_backend('agg')


def set_seed(seed):
    """Seed Python's ``random``, PyTorch and NumPy global RNGs with ``seed``."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group according to ``args.lradj``.

    Supported schedules:
        'type1'  -- halve ``args.learning_rate`` every epoch (epoch 1 uses the base LR).
        'type2'  -- fixed LR values applied only at epochs 2, 4, 6, 8, 10, 15, 20.
        'cosine' -- cosine decay from ``args.learning_rate`` over ``args.train_epochs``.

    The optimizer is only touched when ``epoch`` has an entry in the schedule.

    Raises:
        ValueError: if ``args.lradj`` is not one of the schedules above.
            Previously this surfaced as an ``UnboundLocalError``.
    """
    if args.lradj == 'type1':
        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
    elif args.lradj == 'type2':
        lr_adjust = {
            2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
            10: 5e-7, 15: 1e-7, 20: 5e-8
        }
    elif args.lradj == "cosine":
        lr_adjust = {epoch: args.learning_rate / 2 * (1 + math.cos(epoch / args.train_epochs * math.pi))}
    else:
        raise ValueError(f"Unsupported lradj schedule: {args.lradj!r}")

    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Updating learning rate to {}'.format(lr))


class EarlyStopping:
    """Track validation loss and flag when training should stop.

    ``early_stop`` becomes True after ``patience`` consecutive calls in which
    the score (``-val_loss``) fails to reach ``best_score + delta``. On every
    improvement, ``model.state_dict()`` is saved to ``<path>/checkpoint.pth``.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.inf``: the ``np.Inf`` alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, path):
        """Record ``val_loss``; checkpoint ``model`` into ``path`` if it improved."""
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model.state_dict()`` to ``<path>/checkpoint.pth`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), os.path.join(path, 'checkpoint.pth'))
        self.val_loss_min = val_loss


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    # Missing attributes read as None (dict.get) rather than raising AttributeError.
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class StandardScaler():
    """Affine scaler using a precomputed ``mean`` and ``std``."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Return ``(data - mean) / std``."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Return ``data * std + mean`` (inverse of ``transform``)."""
        return (data * self.std) + self.mean


def visual(true, preds=None, name='./pic/test.pdf'):
    """
    Results visualization

    Plots ``true`` (and ``preds`` if given) and saves the figure to ``name``.
    The figure is closed afterwards so repeated calls do not accumulate open
    figures.
    """
    fig = plt.figure()
    plt.plot(true, label='GroundTruth', linewidth=2)
    if preds is not None:
        plt.plot(preds, label='Prediction', linewidth=2)
    plt.legend()
    plt.savefig(name, bbox_inches='tight')
    plt.close(fig)


def adjustment(gt, pred):
    """Point-adjust ``pred`` for anomaly-detection scoring.

    When ``pred`` first hits a ground-truth anomaly segment (a run of 1s in
    ``gt``), the whole segment is marked as detected in ``pred``.
    ``pred`` is modified in place and returned together with ``gt``.

    NOTE(review): the backward sweep ``range(i, 0, -1)`` never visits index 0,
    so a segment starting at index 0 is not back-filled there. This is left
    as-is to keep reported metrics unchanged.
    """
    anomaly_state = False
    for i in range(len(gt)):
        if gt[i] == 1 and pred[i] == 1 and not anomaly_state:
            anomaly_state = True
            # Back-fill the start of the current anomaly segment.
            for j in range(i, 0, -1):
                if gt[j] == 0:
                    break
                else:
                    if pred[j] == 0:
                        pred[j] = 1
            # Forward-fill the rest of the current anomaly segment.
            for j in range(i, len(gt)):
                if gt[j] == 0:
                    break
                else:
                    if pred[j] == 0:
                        pred[j] = 1
        elif gt[i] == 0:
            anomaly_state = False
        if anomaly_state:
            pred[i] = 1
    return gt, pred


def cal_accuracy(y_pred, y_true):
    """Fraction of positions where ``y_pred == y_true``."""
    return np.mean(y_pred == y_true)


def get_gpu_usage():
    """Parse ``nvidia-smi`` output and return the first GPU's used memory.

    Returns ``(used_MiB - 10) / 1000`` for the first output line containing
    ``'MiB /'``, or ``None`` if no such line is found.
    NOTE(review): the ``- 10`` offset presumably discounts idle usage and the
    ``/ 1000`` gives an approximate GB figure. Confirm the intent.

    Raises:
        FileNotFoundError: if the ``nvidia-smi`` executable is not available.
    """
    result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE)
    output = result.stdout.decode('utf-8')
    lines = output.split('\n')

    for line in lines:
        if 'MiB /' in line:
            gpu_usage = line.split('MiB / ')[0]
            gpu_usage = gpu_usage.split('|')[-1]
            gpu_usage = gpu_usage.strip()

            return (int(gpu_usage) - 10) / 1000
    return None

# --------------------------------------------------------------------------------