├── .gitignore ├── README.md ├── data ├── data_loader.py ├── data_loader_delay.py ├── data_loader_mix.py └── data_loader_v3.py ├── exp ├── __init__.py ├── exp_basic.py ├── exp_cross_former.py ├── exp_data_visualize.py ├── exp_derpp.py ├── exp_dlinear.py ├── exp_er.py ├── exp_fedformer.py ├── exp_fsnet.py ├── exp_fsnet_d3a.py ├── exp_fsnet_time.py ├── exp_large.py ├── exp_naive.py ├── exp_naive_time.py ├── exp_nomem.py ├── exp_ogd.py ├── exp_onenet_d3a.py ├── exp_onenet_egd.py ├── exp_onenet_fsnet.py ├── exp_onenet_gate.py ├── exp_onenet_linear_regression.py ├── exp_onenet_minus.py ├── exp_onenet_tcn.py ├── exp_onenet_tcn_minus.py ├── exp_onenet_weight.py ├── exp_patch.py ├── exp_patch_fsnet.py ├── exp_patch_fsnet_time.py ├── exp_patch_tcn.py ├── exp_patch_tcn_er.py ├── exp_patch_tcn_learn_linear.py ├── exp_patch_tcn_moe.py ├── exp_patch_tcn_time.py └── exp_ts2vec.py ├── framework.png ├── layers ├── AutoCorrelation.py ├── Autoformer_EncDec.py ├── Embed.py ├── FourierCorrelation.py ├── MultiWaveletCorrelation.py ├── PatchTST_backbone.py ├── PatchTST_layers.py ├── RevIN.py ├── SelfAttention_Family.py ├── Transformer_EncDec.py └── utils.py ├── main.py ├── models ├── Autoformer.py ├── DLinear.py ├── FEDformer.py ├── Informer.py ├── PatchTST.py ├── __init__.py ├── attn.py ├── cross_models │ ├── attn.py │ ├── cross_decoder.py │ ├── cross_embed.py │ ├── cross_encoder.py │ └── cross_former.py ├── decoder.py ├── embed.py ├── encoder.py ├── model.py └── ts2vec │ ├── __init__.py │ ├── dev.py │ ├── dilated_conv.py │ ├── encoder.py │ ├── fsnet.py │ ├── fsnet_.py │ ├── losses.py │ ├── ncca.py │ ├── ncca_.py │ └── nomem.py ├── onenet_result.png ├── requirement.txt ├── run.sh ├── run_d3a.sh ├── teaser_d3a.png └── utils ├── Adbfgs.py ├── __init__.py ├── augmentations.py ├── buffer.py ├── detector.py ├── masking.py ├── metrics.py ├── timefeatures.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .DS_Store? 3 | .idea 4 | *.pyc 5 | **/__pycache__/ 6 | checkpoints/ 7 | results/ 8 | results1/ 9 | imgs/* 10 | *.pdf 11 | *.out 12 | *.csv -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # (NeurIPS 2023) OneNet: Enhancing Time Series Forecasting Models under Concept Drift by Online Ensembling 2 | 3 | This codebase is the official implementation of [`OneNet: Enhancing Time Series Forecasting Models under Concept Drift by Online Ensembling`](https://arxiv.org/abs/2309.12659) (**NeurIPS 2023**) and [`Addressing Concept Shift in Online Time Series Forecasting: Detect-then-Adapt`](https://arxiv.org/abs/2403.14949) 4 | 5 | 6 | ## 🔥 Update 7 | * [2023-09-22]: ⭐️ Paper online. Check out [Detect-then-Adapt](https://arxiv.org/abs/2403.14949) for details. 8 | * [2023-09-22]: ⭐️ Paper online. Check out [OneNet](https://arxiv.org/abs/2309.12659) for details. 9 | * [2023-09-20]: 🚀🚀 Codes released. 10 | 11 | 12 | ## Introduction for OneNet 13 | 14 | Online updating of time series forecasting models aims to address the **concept drifting problem** by efficiently updating forecasting models based on streaming data. Many algorithms are designed for online time series forecasting, with some exploiting **cross-variable dependency** while others assume **independence among variables**. Given every data assumption has its own pros and cons in online time series modeling, we propose **On**line **e**nsembling **Net**work (**OneNet**). 
It dynamically updates and combines two models, one focusing on dependency across the time dimension and the other on cross-variable dependency. Our method incorporates a reinforcement-learning-based approach into the traditional online convex programming framework, allowing the two models to be combined linearly with dynamically adjusted weights; a minimal sketch of this weight-update step is shown after the list below. **OneNet** addresses the main shortcoming of classical online learning methods, which tend to be slow in adapting to concept drift. Empirical results show that OneNet reduces online forecasting error by more than $50$% compared to the State-Of-The-Art (SOTA) method.

![OneNet](framework.png)

1) The proposed OneNet-TCN (online ensembling of TCN and Time-TCN) surpasses most of the competing baselines across various forecasting horizons.
2) When the combined branches are stronger, for example when OneNet combines FSNet and Time-FSNet, it achieves much better performance than OneNet-TCN. In other words, OneNet can integrate any advanced online forecasting method or representation learning structure to enhance the robustness of the model.
3) The average MSE and MAE of OneNet are significantly better than those of either branch (FSNet or Time-TCN) alone, which underscores the importance of online ensembling.
4) OneNet achieves faster and better convergence than other methods.
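Below is a minimal, illustrative sketch of the online-convex-programming part of this ensembling: the final forecast is a convex combination of the two branches, and after every online step the combination weights are updated with an exponentiated-gradient rule so that the branch with the lower recent loss receives more weight. The function and variable names are illustrative only (they are not the repository API), and the sketch omits the reinforcement-learning-based weight correction described in the paper; see the `exp/exp_onenet_*.py` experiment files for the actual implementations.

```python
import numpy as np

def egd_update(weights, losses, lr=1.0):
    # Exponentiated gradient step on the simplex: branches with lower loss
    # get exponentially more weight; renormalize so the weights sum to 1.
    w = weights * np.exp(-lr * losses)
    return w / w.sum()

# Toy online loop with two branches (e.g., a cross-time and a cross-variable model).
weights = np.array([0.5, 0.5])
for step in range(3):
    branch_preds = np.array([1.2, 0.8])       # stand-in per-branch forecasts
    target = 1.0                              # stand-in ground truth revealed after forecasting
    ensemble = float(weights @ branch_preds)  # convex combination of the two branches
    branch_losses = (branch_preds - target) ** 2
    weights = egd_update(weights, branch_losses, lr=1.0)
    print(f"step {step}: ensemble={ensemble:.3f}, weights={weights}")
```

In the full method, this slow EGD-style update is complemented by a faster reinforcement-learning-based adjustment of the weights, which is what allows OneNet to react quickly when the concept drifts.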
## Introduction for Detect-then-Adapt
While numerous algorithms have been developed for online forecasting, most of them focus on model design and updating. In practice, many of these methods struggle with continuous performance regression in the face of concept drifts accumulated over time. $D^3A$ first detects drifting concepts and then aggressively adapts the current model to the drifted concept for rapid adaptation. Our empirical studies across six datasets demonstrate the effectiveness of $D^3A$ in improving model adaptation capability. Notably, compared to a simple Temporal Convolutional Network (TCN) baseline, $D^3A$ reduces the average Mean Squared Error (MSE) by $43.9$%. For the state-of-the-art (SOTA) model, the MSE is reduced by $33.3$%.

![Detect-then-Adapt](teaser_d3a.png)

1) **A concept drift detection framework:** Our framework monitors drift in the loss distribution, aiming to predict the occurrence of concept drift. The detector guides model updating, enhancing model robustness and AI safety, particularly in high-risk tasks.

2) **A more realistic evaluation setting:** We observe that previous benchmarks often presume a substantial overlap with the forecasting target during testing. In this paper, we advocate evaluating online time series forecasting models with delayed feedback, which provides a more realistic and challenging assessment.

## Requirements

- python == 3.7.3
- pytorch == 1.8.0
- matplotlib == 3.1.1
- numpy == 1.19.4
- pandas == 0.25.1
- scikit_learn == 0.21.3
- tqdm == 4.62.3
- einops == 0.4.0

## Benchmarking

### 1. Data preparation

We follow the same data formatting as the Informer repo (https://github.com/zhouhaoyi/Informer2020), which also hosts the raw data.
Please put all raw data (csv) files in the ```./data``` folder.

### 2. Run experiments

To replicate our results on the ETT, ECL, Traffic, and WTH datasets, run
```
sh run.sh
```

To replicate our results for $D^3A$, run
```
sh run_d3a.sh
```

### 3. Arguments

**Method:** You can specify the training strategy or backbone (see Section 4 below) via the ```--method``` argument.

**Dataset:** Our implementation currently supports the following datasets: Electricity Transformer - ETT (including ETTh1, ETTh2, ETTm1, and ETTm2), ECL, Traffic, and WTH. You can specify the dataset via the ```--data``` argument.

**Other arguments:** Other useful arguments for experiments are:
- ```--test_bsz```: batch size used for testing; must be set to **1** for online learning,
- ```--seq_len```: length of the look-back window, set to **60** by default,
- ```--pred_len```: length of the forecast window, set to **1** for online learning.

**D3A arguments:** Additional arguments for the $D^3A$ experiments are:

- `--sleep_interval`: corresponds to \( l_w \) in our paper, the window size of the drift detector.
- `--sleep_epochs`: the number of epochs for which the model is fully fine-tuned once a drift is detected; set to **20** by default.
- `--online_adjust`: the regularization weight \( \lambda \) in our paper applied after a drift is detected; set to **0.5** by default.
- `--offline_adjust`: at each step, the algorithm samples previous data and augments it for regularization; this regularization weight is set to **0.5** by default.
- `--alpha_d`: the predefined confidence level for triggering concept drift; set to **0.003** by default.

### 4. Baselines

**Backbones:** Our implementation supports the following backbones in Table 1:

- patch: PatchTST for online time series forecasting
- fedformer: FEDformer for online time series forecasting
- dlinear: DLinear for online time series forecasting
- cross_former: Crossformer for online time series forecasting
- naive_time: the proposed Time-TCN for online time series forecasting


**Ablations:** Our online learning and ensembling ablation baselines in Table 4:
- fsnet_plus_time: simple averaging
- onenet_gate: gating mechanism
- onenet_linear_regression: Linear Regression (LR)
- onenet_egd: Exponentiated Gradient Descent (EGD)
- onenet_weight: reinforcement learning to learn the weights directly (RL-W)

**Algorithms:** Our implementation supports the following training strategies in Tables 2 and 3:
- ogd: OGD training
- large: OGD training with a large backbone
- er: experience replay
- derpp: dark experience replay
- nomem: FSNet without the associative memory
- naive: FSNet without both the memory and the adapter; directly trains the adaptation coefficients
- fsnet: the FSNet framework
- fsnet_d3a: FSNet with the Detect-then-Adapt framework
- fsnet_time: Cross-Time FSNet
- onenet_minus: the proposed OneNet- in Section 4
- onenet_tcn: the proposed OneNet with a TCN backbone
- onenet_fsnet: the proposed OneNet
- onenet_d3a: the proposed OneNet with the Detect-then-Adapt framework


## License

This source code is released under the MIT license, included [here](LICENSE).
124 | 125 | ### Citation 126 | If you find this repo useful, please consider citing: 127 | ``` 128 | @inproceedings{ 129 | zhang2023onenet, 130 | title={OneNet: Enhancing Time Series Forecasting Models under Concept Drift by Online Ensembling}, 131 | author={YiFan Zhang and Qingsong Wen and Xue Wang and Weiqi Chen and Liang Sun and Zhang Zhang and Liang Wang and Rong Jin and Tieniu Tan}, 132 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 133 | year={2023} 134 | } 135 | 136 | @misc{zhang2024addressing, 137 | title={Addressing Concept Shift in Online Time Series Forecasting: Detect-then-Adapt}, 138 | author={YiFan Zhang and Weiqi Chen and Zhaoyang Zhu and Dalin Qin and Liang Sun and Xue Wang and Qingsong Wen and Zhang Zhang and Liang Wang and Rong Jin}, 139 | year={2024}, 140 | eprint={2403.14949}, 141 | archivePrefix={arXiv}, 142 | primaryClass={cs.LG} 143 | } 144 | ``` 145 | -------------------------------------------------------------------------------- /data/data_loader_mix.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | from torchvision.datasets import ImageFolder 5 | import bisect 6 | import torch 7 | from typing import TypeVar 8 | 9 | T_co = TypeVar('T_co', covariant=True) 10 | class ConcatDataset(Dataset[T_co]): 11 | r"""Dataset as a concatenation of multiple datasets. 12 | 13 | This class is useful to assemble different existing datasets. 14 | 15 | Arguments: 16 | datasets (sequence): List of datasets to be concatenated 17 | """ 18 | 19 | @staticmethod 20 | def cumsum(sequence): 21 | r, s = [], 0 22 | for e in sequence: 23 | l = len(e) 24 | r.append(l + s) 25 | s += l 26 | return r 27 | 28 | def __init__(self, datasets) -> None: 29 | super(ConcatDataset, self).__init__() 30 | assert len(datasets) > 0, 'datasets should not be an empty iterable' # type: ignore 31 | self.datasets = list(datasets) 32 | self.cumulative_sizes = self.cumsum(self.datasets) 33 | self.domain_label = torch.randint(low=0, high=len(datasets), size=(self.cummulative_sizes[-1],)) 34 | 35 | def __len__(self): 36 | return self.cumulative_sizes[-1] 37 | 38 | def __getitem__(self, idx): 39 | if idx < 0: 40 | if -idx > len(self): 41 | raise ValueError("absolute value of index should not exceed dataset length") 42 | idx = len(self) + idx 43 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 44 | if dataset_idx == 0: 45 | sample_idx = idx 46 | else: 47 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 48 | seq_x, seq_y, seq_x_mark, seq_y_mark = self.datasets[dataset_idx].indices[sample_idx] 49 | return seq_x, seq_y, seq_x_mark, seq_y_mark 50 | 51 | @property 52 | def cummulative_sizes(self): 53 | return self.cumulative_sizes -------------------------------------------------------------------------------- /exp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzhang114/OneNet/076188ad4c1b28cbda8cb4b1e231499b233e1437/exp/__init__.py -------------------------------------------------------------------------------- /exp/exp_basic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | class Exp_Basic(object): 6 | def __init__(self, args): 7 | self.args = args 8 | self.device = self._acquire_device() 9 | self.model = self._build_model().to(self.device) 10 | 11 | def 
_build_model(self): 12 | raise NotImplementedError 13 | return None 14 | 15 | def _acquire_device(self): 16 | if self.args.use_gpu: 17 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.args.gpu) if not self.args.use_multi_gpu else self.args.devices 18 | device = torch.device('cuda:{}'.format(self.args.gpu)) 19 | print('Use GPU: cuda:{}'.format(self.args.gpu)) 20 | else: 21 | device = torch.device('cpu') 22 | print('Use CPU') 23 | return device 24 | 25 | def _get_data(self): 26 | pass 27 | 28 | def vali(self): 29 | pass 30 | 31 | def train(self): 32 | pass 33 | 34 | def test(self): 35 | pass 36 | -------------------------------------------------------------------------------- /exp/exp_data_visualize.py: -------------------------------------------------------------------------------- 1 | from data.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred 2 | from exp.exp_basic import Exp_Basic 3 | 4 | import numpy as np 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | from matplotlib.backends.backend_pdf import PdfPages 8 | import pdb 9 | import numpy as np 10 | from einops import rearrange 11 | from collections import OrderedDict 12 | import time 13 | import torch 14 | import torch.nn as nn 15 | from torch import optim 16 | from torch.utils.data import DataLoader 17 | from collections import defaultdict 18 | from sklearn.linear_model import Ridge 19 | from sklearn.model_selection import GridSearchCV, train_test_split 20 | 21 | import os 22 | import time 23 | from pathlib import Path 24 | 25 | import warnings 26 | warnings.filterwarnings('ignore') 27 | 28 | 29 | __all__ = ['PatchTST'] 30 | 31 | # Cell 32 | from typing import Callable, Optional 33 | import torch 34 | from torch import nn 35 | from torch import Tensor 36 | import torch.nn.functional as F 37 | import numpy as np 38 | from utils.augmentations import Augmenter 39 | import math 40 | 41 | from layers.PatchTST_backbone import PatchTST_backbone 42 | from layers.PatchTST_layers import series_decomp 43 | 44 | class MMD_loss(nn.Module): 45 | def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5): 46 | super(MMD_loss, self).__init__() 47 | self.kernel_num = kernel_num 48 | self.kernel_mul = kernel_mul 49 | self.fix_sigma = None 50 | self.kernel_type = kernel_type 51 | 52 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 53 | n_samples = int(source.size()[0]) + int(target.size()[0]) 54 | total = torch.cat([source, target], dim=0) 55 | total0 = total.unsqueeze(0).expand( 56 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 57 | total1 = total.unsqueeze(1).expand( 58 | int(total.size(0)), int(total.size(0)), int(total.size(1))) 59 | L2_distance = ((total0 - total1) ** 2).sum(2) 60 | if fix_sigma: 61 | bandwidth = fix_sigma 62 | else: 63 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 64 | bandwidth /= kernel_mul ** (kernel_num // 2) 65 | bandwidth_list = [bandwidth * (kernel_mul ** i) 66 | for i in range(kernel_num)] 67 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) 68 | for bandwidth_temp in bandwidth_list] 69 | return sum(kernel_val) 70 | 71 | def linear_mmd2(self, f_of_X, f_of_Y): 72 | loss = 0.0 73 | delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0) 74 | loss = delta.dot(delta.T) 75 | return loss 76 | 77 | def forward(self, source, target): 78 | if self.kernel_type == 'linear': 79 | return self.linear_mmd2(source, target) 80 | elif self.kernel_type == 'rbf': 81 | batch_size = 
int(source.size()[0]) 82 | kernels = self.guassian_kernel( 83 | source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 84 | with torch.no_grad(): 85 | XX = torch.mean(kernels[:batch_size, :batch_size]) 86 | YY = torch.mean(kernels[batch_size:, batch_size:]) 87 | XY = torch.mean(kernels[:batch_size, batch_size:]) 88 | YX = torch.mean(kernels[batch_size:, :batch_size]) 89 | loss = torch.mean(XX + YY - XY - YX) 90 | torch.cuda.empty_cache() 91 | return loss 92 | 93 | class Exp_TS2VecSupervised(Exp_Basic): 94 | def __init__(self, args): 95 | self.args = args 96 | self.input_channels_dim = args.enc_in 97 | self.device = self._acquire_device() 98 | self.online = args.online_learning 99 | assert self.online in ['none', 'full', 'regressor'] 100 | self.n_inner = args.n_inner 101 | self.opt_str = args.opt 102 | self.augmenter = None 103 | self.aug = args.aug 104 | self.MMD_loss = MMD_loss(kernel_mul=2.0, kernel_num=5) 105 | 106 | 107 | def _get_data(self, flag): 108 | args = self.args 109 | 110 | data_dict_ = { 111 | 'ETTh1': Dataset_ETT_hour, 112 | 'ETTh2': Dataset_ETT_hour, 113 | 'ETTm1': Dataset_ETT_minute, 114 | 'ETTm2': Dataset_ETT_minute, 115 | 'WTH': Dataset_Custom, 116 | 'ECL': Dataset_Custom, 117 | 'Solar': Dataset_Custom, 118 | 'custom': Dataset_Custom, 119 | } 120 | data_dict = defaultdict(lambda: Dataset_Custom, data_dict_) 121 | Data = data_dict[self.args.data] 122 | timeenc = 2 123 | 124 | if flag == 'test': 125 | shuffle_flag = False; 126 | drop_last = False; 127 | batch_size = args.test_bsz; 128 | freq = args.freq 129 | elif flag == 'val': 130 | shuffle_flag = False; 131 | drop_last = False; 132 | batch_size = args.batch_size; 133 | freq = args.detail_freq 134 | elif flag == 'pred': 135 | shuffle_flag = False; 136 | drop_last = False; 137 | batch_size = 1; 138 | freq = args.detail_freq 139 | Data = Dataset_Pred 140 | else: 141 | shuffle_flag = True; 142 | drop_last = True; 143 | batch_size = args.batch_size; 144 | freq = args.freq 145 | 146 | data_set = Data( 147 | root_path=args.root_path, 148 | data_path=args.data_path, 149 | flag=flag, 150 | size=[args.seq_len, args.label_len, args.pred_len], 151 | features=args.features, 152 | target=args.target, 153 | inverse=args.inverse, 154 | timeenc=timeenc, 155 | freq=freq, 156 | cols=args.cols 157 | ) 158 | print(flag, len(data_set)) 159 | data_loader = DataLoader( 160 | data_set, 161 | batch_size=batch_size, 162 | shuffle=shuffle_flag, 163 | num_workers=args.num_workers, 164 | drop_last=drop_last) 165 | 166 | return data_set, data_loader 167 | 168 | def train(self, setting): 169 | train_data, train_loader = self._get_data(flag='train') 170 | vali_data, vali_loader = self._get_data(flag='val') 171 | test_data, test_loader = self._get_data(flag='test') 172 | 173 | train_x, l_tr = train_data.data_x, train_data.data_x.shape[0] 174 | val_x, l_val = vali_data.data_x, vali_data.data_x.shape[0] 175 | test_x, l_te = test_data.data_x, test_data.data_x.shape[0] 176 | repeat = 10 177 | 178 | train_x_ = torch.from_numpy(train_x) 179 | test_x_ = torch.from_numpy(test_x) 180 | idx_train = np.arange(l_tr) 181 | idx_te = np.arange(l_te) 182 | l_idx = min(l_tr, l_te) 183 | mmd = 0 184 | for i in range(repeat): 185 | idx_tr = np.random.choice(idx_train, l_idx, replace=False) 186 | idx_te = np.random.choice(idx_te, l_idx, replace=False) 187 | 188 | mmd += self.MMD_loss(train_x_[idx_tr], test_x_[idx_te]) 189 | print('mmd distance is', mmd / repeat) 190 | exit() 191 | # step=5 192 | # channels = 4 193 | # name = 
f"{self.args.data}_channel{channels}_step{step}" 194 | # x = np.arange(train_x.shape[0]) 195 | # x_val = np.arange(train_x.shape[0], train_x.shape[0]+val_x.shape[0]) 196 | # x_te = np.arange(train_x.shape[0]+val_x.shape[0], train_x.shape[0]+test_x.shape[0]+val_x.shape[0]) 197 | # style_dict = { 198 | # '0':dict(linestyle='-', marker='o',markersize=0.1,color='#dd7e6b'), 199 | # '1':dict(linestyle='-',marker='o',markersize=0.1,color='#b6d7a8'), 200 | # '2':dict(linestyle='-',marker='o',markersize=0.1,color='#f9cb9c'), 201 | # '3':dict(linestyle='-',marker='o',markersize=0.1,color='#a4c2f4'), 202 | # '4':dict(linestyle='-',marker='o',markersize=0.1,color='#b4a7d6') 203 | # } 204 | # style_dict_te = { 205 | # '0':dict(linestyle='--', marker='+',markersize=0.1,color='#dd7e6b'), 206 | # '1':dict(linestyle='--',marker='+',markersize=0.1,color='#b6d7a8'), 207 | # '2':dict(linestyle='--',marker='+',markersize=0.1,color='#f9cb9c'), 208 | # '3':dict(linestyle='--',marker='+',markersize=0.1,color='#a4c2f4'), 209 | # '4':dict(linestyle='--',marker='+',markersize=0.1,color='#b4a7d6') 210 | # } 211 | 212 | # for i in range(channels-1, channels): 213 | # tr_x, te_x, v_x = train_x[:,i], test_x[:,i], val_x[:, i] 214 | # plt.plot(x[::step], tr_x[::step], **style_dict[str(i%5)], label=f'Train {i}') 215 | # plt.plot(x_val[::step], v_x[::step], **style_dict[str((i+1)%5)], label=f'Val {i}') 216 | # plt.plot(x_te[::step], te_x[::step], **style_dict_te[str((i+2)%5)], label=f'Test {i}') 217 | # plt.ylabel('Value')#, fontdict=font_y) 218 | # plt.xlabel('Time step')#, fontdict=font_y) 219 | # #plt.xlabel('Treatment Selection Bias', fontdict=font_y) 220 | # plt.legend() 221 | # plt.savefig(f'imgs/{name}.jpg') 222 | # exit() 223 | -------------------------------------------------------------------------------- /exp/exp_ts2vec.py: -------------------------------------------------------------------------------- 1 | from data.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred 2 | from exp.exp_basic import Exp_Basic 3 | from models.ts2vec.encoder import TSEncoder 4 | from models.ts2vec.losses import hierarchical_contrastive_loss 5 | from utils.metrics import metric 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch import optim 12 | from torch.utils.data import DataLoader 13 | 14 | from sklearn.linear_model import Ridge 15 | from sklearn.model_selection import GridSearchCV, train_test_split 16 | 17 | import os 18 | import time 19 | 20 | import warnings 21 | warnings.filterwarnings('ignore') 22 | 23 | 24 | def take_per_row(A, indx, num_elem): 25 | all_indx = indx[:,None] + np.arange(num_elem) 26 | return A[torch.arange(all_indx.shape[0])[:,None], all_indx] 27 | 28 | 29 | def fit_ridge(train_features, train_y, valid_features, valid_y, MAX_SAMPLES=100000): 30 | # If the training set is too large, subsample MAX_SAMPLES examples 31 | if train_features.shape[0] > MAX_SAMPLES: 32 | split = train_test_split( 33 | train_features, train_y, 34 | train_size=MAX_SAMPLES, random_state=0 35 | ) 36 | train_features = split[0] 37 | train_y = split[2] 38 | if valid_features.shape[0] > MAX_SAMPLES: 39 | split = train_test_split( 40 | valid_features, valid_y, 41 | train_size=MAX_SAMPLES, random_state=0 42 | ) 43 | valid_features = split[0] 44 | valid_y = split[2] 45 | alphas = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000] 46 | valid_results = [] 47 | for alpha in alphas: 48 | lr = Ridge(alpha=alpha).fit(train_features, train_y) 49 | valid_pred = 
lr.predict(valid_features) 50 | score = np.sqrt(((valid_pred - valid_y) ** 2).mean()) + np.abs(valid_pred - valid_y).mean() 51 | valid_results.append(score) 52 | best_alpha = alphas[np.argmin(valid_results)] 53 | 54 | lr = Ridge(alpha=best_alpha) 55 | lr.fit(train_features, train_y) 56 | return lr 57 | 58 | 59 | class Exp_TS2Vec(Exp_Basic): 60 | def __init__(self, args): 61 | self.args = args 62 | self.device = self._acquire_device() 63 | self._net = TSEncoder(input_dims=self.args.enc_in + 7, output_dims=self.args.c_out, 64 | hidden_dims=64, depth=10).to(self.device) 65 | self.net = torch.optim.swa_utils.AveragedModel(self._net) 66 | self.net.update_parameters(self._net) 67 | self.temporal_unit = 0 68 | 69 | def _get_data(self, flag, downstream=False): 70 | args = self.args 71 | 72 | data_dict = { 73 | 'ETTh1': Dataset_ETT_hour, 74 | 'ETTh2': Dataset_ETT_hour, 75 | 'ETTm1': Dataset_ETT_minute, 76 | 'ETTm2': Dataset_ETT_minute, 77 | 'WTH': Dataset_Custom, 78 | 'ECL': Dataset_Custom, 79 | 'Solar': Dataset_Custom, 80 | 'custom': Dataset_Custom, 81 | } 82 | Data = data_dict[self.args.data] 83 | timeenc = 2 84 | 85 | if flag in ('test', 'val') or downstream: 86 | shuffle_flag = False; 87 | drop_last = False; 88 | batch_size = args.batch_size; 89 | freq = args.freq 90 | elif flag == 'pred': 91 | shuffle_flag = False; 92 | drop_last = False; 93 | batch_size = 1; 94 | freq = args.detail_freq 95 | Data = Dataset_Pred 96 | else: 97 | shuffle_flag = True; 98 | drop_last = True; 99 | batch_size = args.batch_size; 100 | freq = args.freq 101 | 102 | if flag == 'train' and not downstream: 103 | data_set = Data( 104 | root_path=args.root_path, 105 | data_path=args.data_path, 106 | flag=flag, 107 | size=[3000, 0, 0], 108 | features=args.features, 109 | target=args.target, 110 | inverse=args.inverse, 111 | timeenc=timeenc, 112 | freq=freq, 113 | cols=args.cols 114 | ) 115 | else: 116 | data_set = Data( 117 | root_path=args.root_path, 118 | data_path=args.data_path, 119 | flag=flag, 120 | size=[args.seq_len, args.label_len, args.pred_len], 121 | features=args.features, 122 | target=args.target, 123 | inverse=args.inverse, 124 | timeenc=timeenc, 125 | freq=freq, 126 | cols=args.cols 127 | ) 128 | print(flag, len(data_set)) 129 | data_loader = DataLoader( 130 | data_set, 131 | batch_size=batch_size, 132 | shuffle=shuffle_flag, 133 | num_workers=args.num_workers, 134 | drop_last=drop_last) 135 | 136 | return data_set, data_loader 137 | 138 | def train(self, setting): 139 | train_data, train_loader = self._get_data(flag='train', downstream=False) 140 | 141 | path = os.path.join(self.args.checkpoints, setting) 142 | if not os.path.exists(path): 143 | os.makedirs(path) 144 | 145 | time_now = time.time() 146 | 147 | train_steps = len(train_loader) 148 | optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.args.learning_rate) 149 | 150 | for epoch in range(self.args.train_epochs): 151 | iter_count = 0 152 | train_loss = [] 153 | 154 | self.net.train() 155 | epoch_time = time.time() 156 | for i, (batch_x, _, batch_x_mark, _) in enumerate(train_loader): 157 | iter_count += 1 158 | 159 | x = torch.cat([batch_x.float(), batch_x_mark.float()], dim=-1).to(self.device) 160 | 161 | ts_l = x.size(1) 162 | crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l + 1) 163 | crop_left = np.random.randint(ts_l - crop_l + 1) 164 | crop_right = crop_left + crop_l 165 | crop_eleft = np.random.randint(crop_left + 1) 166 | crop_eright = np.random.randint(low=crop_right, high=ts_l + 1) 167 | crop_offset = 
np.random.randint(low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0)) 168 | 169 | optimizer.zero_grad() 170 | 171 | out1 = self._net(take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft)) 172 | out1 = out1[:, -crop_l:] 173 | 174 | out2 = self._net(take_per_row(x, crop_offset + crop_left, crop_eright - crop_left)) 175 | out2 = out2[:, :crop_l] 176 | 177 | loss = hierarchical_contrastive_loss( 178 | out1, 179 | out2, 180 | temporal_unit=self.temporal_unit 181 | ) 182 | 183 | loss.backward() 184 | optimizer.step() 185 | self.net.update_parameters(self._net) 186 | 187 | train_loss.append(loss.item()) 188 | if (i+1) % 100==0: 189 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) 190 | speed = (time.time()-time_now)/iter_count 191 | left_time = speed*((self.args.train_epochs - epoch)*train_steps - i) 192 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) 193 | iter_count = 0 194 | time_now = time.time() 195 | 196 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) 197 | train_loss = np.average(train_loss) 198 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f}".format( 199 | epoch + 1, train_steps, train_loss)) 200 | 201 | 202 | def test(self, setting): 203 | _, train_loader = self._get_data(flag='train', downstream=True) 204 | _, vali_loader = self._get_data(flag='val', downstream=True) 205 | _, test_loader = self._get_data(flag='test', downstream=True) 206 | 207 | self.net.eval() 208 | 209 | train_repr = [] 210 | valid_repr = [] 211 | test_repr = [] 212 | 213 | train_data = [] 214 | valid_data = [] 215 | test_data = [] 216 | 217 | for i, (batch_x, batch_y, batch_x_mark, _) in enumerate(train_loader): 218 | x = torch.cat([batch_x_mark.float(), batch_x.float()], dim=-1) 219 | x = x.to(self.device) 220 | out = self.net(x)[:, -1].detach().cpu() 221 | train_repr.append(out) 222 | train_data.append(batch_y.view(batch_y.size(0), -1)) 223 | 224 | for i, (batch_x, batch_y, batch_x_mark, _) in enumerate(vali_loader): 225 | x = torch.cat([batch_x_mark.float(), batch_x.float()], dim=-1) 226 | x = x.to(self.device) 227 | out = self.net(x)[:, -1].detach().cpu() 228 | valid_repr.append(out) 229 | valid_data.append(batch_y.view(batch_y.size(0), -1)) 230 | 231 | for i, (batch_x, batch_y, batch_x_mark, _) in enumerate(test_loader): 232 | x = torch.cat([batch_x_mark.float(), batch_x.float()], dim=-1) 233 | x = x.to(self.device) 234 | out = self.net(x)[:, -1].detach().cpu() 235 | test_repr.append(out) 236 | test_data.append(batch_y.view(batch_y.size(0), -1)) 237 | 238 | train_repr = torch.cat(train_repr, dim=0).numpy() 239 | train_data = torch.cat(train_data, dim=0).numpy() 240 | valid_repr = torch.cat(valid_repr, dim=0).numpy() 241 | valid_data = torch.cat(valid_data, dim=0).numpy() 242 | test_repr = torch.cat(test_repr, dim=0).numpy() 243 | test_data = torch.cat(test_data, dim=0).numpy() 244 | 245 | lr = fit_ridge(train_repr, train_data, valid_repr, valid_data, MAX_SAMPLES=100000) 246 | test_pred = lr.predict(test_repr) 247 | 248 | # result save 249 | folder_path = './results/' + setting + '/' 250 | if not os.path.exists(folder_path): 251 | os.makedirs(folder_path) 252 | mae, mse, rmse, mape, mspe = metric(test_pred, test_data) 253 | 254 | print('mse:{}, mae:{}'.format(mse, mae)) 255 | 256 | np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe])) 257 | 258 | return -------------------------------------------------------------------------------- /framework.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzhang114/OneNet/076188ad4c1b28cbda8cb4b1e231499b233e1437/framework.png -------------------------------------------------------------------------------- /layers/AutoCorrelation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import math 7 | from math import sqrt 8 | import os 9 | 10 | 11 | class AutoCorrelation(nn.Module): 12 | """ 13 | AutoCorrelation Mechanism with the following two phases: 14 | (1) period-based dependencies discovery 15 | (2) time delay aggregation 16 | This block can replace the self-attention family mechanism seamlessly. 17 | """ 18 | def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): 19 | super(AutoCorrelation, self).__init__() 20 | self.factor = factor 21 | self.scale = scale 22 | self.mask_flag = mask_flag 23 | self.output_attention = output_attention 24 | self.dropout = nn.Dropout(attention_dropout) 25 | 26 | def time_delay_agg_training(self, values, corr): 27 | """ 28 | SpeedUp version of Autocorrelation (a batch-normalization style design) 29 | This is for the training phase. 30 | """ 31 | head = values.shape[1] 32 | channel = values.shape[2] 33 | length = values.shape[3] 34 | # find top k 35 | top_k = int(self.factor * math.log(length)) 36 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 37 | index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] 38 | weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) 39 | # update corr 40 | tmp_corr = torch.softmax(weights, dim=-1) 41 | # aggregation 42 | tmp_values = values 43 | delays_agg = torch.zeros_like(values).float() 44 | for i in range(top_k): 45 | pattern = torch.roll(tmp_values, -int(index[i]), -1) 46 | delays_agg = delays_agg + pattern * \ 47 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 48 | return delays_agg 49 | 50 | def time_delay_agg_inference(self, values, corr): 51 | """ 52 | SpeedUp version of Autocorrelation (a batch-normalization style design) 53 | This is for the inference phase. 
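The values tensor is tiled once along the time axis so that rolling by a delay can be done with a single gather on the doubled sequence; the top-k delayed copies (delays chosen by correlation strength) are then summed with softmax-normalized correlation weights.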
54 | """ 55 | batch = values.shape[0] 56 | head = values.shape[1] 57 | channel = values.shape[2] 58 | length = values.shape[3] 59 | # index init 60 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 61 | # find top k 62 | top_k = int(self.factor * math.log(length)) 63 | mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) 64 | weights = torch.topk(mean_value, top_k, dim=-1)[0] 65 | delay = torch.topk(mean_value, top_k, dim=-1)[1] 66 | # update corr 67 | tmp_corr = torch.softmax(weights, dim=-1) 68 | # aggregation 69 | tmp_values = values.repeat(1, 1, 1, 2) 70 | delays_agg = torch.zeros_like(values).float() 71 | for i in range(top_k): 72 | tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) 73 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 74 | delays_agg = delays_agg + pattern * \ 75 | (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) 76 | return delays_agg 77 | 78 | def time_delay_agg_full(self, values, corr): 79 | """ 80 | Standard version of Autocorrelation 81 | """ 82 | batch = values.shape[0] 83 | head = values.shape[1] 84 | channel = values.shape[2] 85 | length = values.shape[3] 86 | # index init 87 | init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).cuda() 88 | # find top k 89 | top_k = int(self.factor * math.log(length)) 90 | weights = torch.topk(corr, top_k, dim=-1)[0] 91 | delay = torch.topk(corr, top_k, dim=-1)[1] 92 | # update corr 93 | tmp_corr = torch.softmax(weights, dim=-1) 94 | # aggregation 95 | tmp_values = values.repeat(1, 1, 1, 2) 96 | delays_agg = torch.zeros_like(values).float() 97 | for i in range(top_k): 98 | tmp_delay = init_index + delay[..., i].unsqueeze(-1) 99 | pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) 100 | delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) 101 | return delays_agg 102 | 103 | def forward(self, queries, keys, values, attn_mask): 104 | B, L, H, E = queries.shape 105 | _, S, _, D = values.shape 106 | if L > S: 107 | zeros = torch.zeros_like(queries[:, :(L - S), :]).float() 108 | values = torch.cat([values, zeros], dim=1) 109 | keys = torch.cat([keys, zeros], dim=1) 110 | else: 111 | values = values[:, :L, :, :] 112 | keys = keys[:, :L, :, :] 113 | 114 | # period-based dependencies 115 | q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) 116 | k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) 117 | res = q_fft * torch.conj(k_fft) 118 | corr = torch.fft.irfft(res, dim=-1) 119 | 120 | # time delay agg 121 | if self.training: 122 | V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 123 | else: 124 | V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) 125 | 126 | if self.output_attention: 127 | return (V.contiguous(), corr.permute(0, 3, 1, 2)) 128 | else: 129 | return (V.contiguous(), None) 130 | 131 | 132 | class AutoCorrelationLayer(nn.Module): 133 | def __init__(self, correlation, d_model, n_heads, d_keys=None, 134 | d_values=None): 135 | super(AutoCorrelationLayer, self).__init__() 136 | 137 | d_keys = d_keys or (d_model // n_heads) 138 | d_values = d_values or (d_model // n_heads) 139 | 140 | self.inner_correlation = correlation 141 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 142 | self.key_projection = nn.Linear(d_model, 
d_keys * n_heads) 143 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 144 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 145 | self.n_heads = n_heads 146 | 147 | def forward(self, queries, keys, values, attn_mask): 148 | B, L, _ = queries.shape 149 | _, S, _ = keys.shape 150 | H = self.n_heads 151 | 152 | queries = self.query_projection(queries).view(B, L, H, -1) 153 | keys = self.key_projection(keys).view(B, S, H, -1) 154 | values = self.value_projection(values).view(B, S, H, -1) 155 | 156 | out, attn = self.inner_correlation( 157 | queries, 158 | keys, 159 | values, 160 | attn_mask 161 | ) 162 | out = out.view(B, L, -1) 163 | 164 | return self.out_projection(out), attn -------------------------------------------------------------------------------- /layers/Autoformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from layers.SelfAttention_Family import FullAttention 6 | 7 | 8 | class my_Layernorm(nn.Module): 9 | """ 10 | Special designed layernorm for the seasonal part 11 | """ 12 | def __init__(self, channels): 13 | super(my_Layernorm, self).__init__() 14 | self.layernorm = nn.LayerNorm(channels) 15 | 16 | def forward(self, x): 17 | x_hat = self.layernorm(x) 18 | bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) 19 | return x_hat - bias 20 | 21 | 22 | class moving_avg(nn.Module): 23 | """ 24 | Moving average block to highlight the trend of time series 25 | """ 26 | def __init__(self, kernel_size, stride): 27 | super(moving_avg, self).__init__() 28 | self.kernel_size = kernel_size 29 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 30 | 31 | def forward(self, x): 32 | # padding on the both ends of time series 33 | front = x[:, 0:1, :].repeat(1, self.kernel_size - 1-math.floor((self.kernel_size - 1) // 2), 1) 34 | end = x[:, -1:, :].repeat(1, math.floor((self.kernel_size - 1) // 2), 1) 35 | x = torch.cat([front, x, end], dim=1) 36 | x = self.avg(x.permute(0, 2, 1)) 37 | x = x.permute(0, 2, 1) 38 | return x 39 | 40 | 41 | class series_decomp(nn.Module): 42 | """ 43 | Series decomposition block 44 | """ 45 | def __init__(self, kernel_size): 46 | super(series_decomp, self).__init__() 47 | self.moving_avg = moving_avg(kernel_size, stride=1) 48 | 49 | def forward(self, x): 50 | moving_mean = self.moving_avg(x) 51 | res = x - moving_mean 52 | return res, moving_mean 53 | 54 | 55 | class series_decomp_multi(nn.Module): 56 | """ 57 | Series decomposition block 58 | """ 59 | def __init__(self, kernel_size): 60 | super(series_decomp_multi, self).__init__() 61 | self.moving_avg = [moving_avg(kernel, stride=1) for kernel in kernel_size] 62 | self.layer = torch.nn.Linear(1, len(kernel_size)) 63 | 64 | def forward(self, x): 65 | moving_mean=[] 66 | for func in self.moving_avg: 67 | moving_avg = func(x) 68 | moving_mean.append(moving_avg.unsqueeze(-1)) 69 | moving_mean=torch.cat(moving_mean,dim=-1) 70 | moving_mean = torch.sum(moving_mean*nn.Softmax(-1)(self.layer(x.unsqueeze(-1))),dim=-1) 71 | res = x - moving_mean 72 | return res, moving_mean 73 | 74 | 75 | class FourierDecomp(nn.Module): 76 | def __init__(self): 77 | super(FourierDecomp, self).__init__() 78 | pass 79 | 80 | def forward(self, x): 81 | x_ft = torch.fft.rfft(x, dim=-1) 82 | 83 | 84 | class EncoderLayer(nn.Module): 85 | """ 86 | Autoformer encoder layer with the progressive decomposition architecture 87 | """ 88 | 
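    # Progressive decomposition: AutoCorrelation attention with a residual
    # connection, series decomposition to strip the trend, a 1x1-conv
    # feed-forward block, then a second decomposition; only the seasonal
    # component is passed on to the next layer.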
def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): 89 | super(EncoderLayer, self).__init__() 90 | d_ff = d_ff or 4 * d_model 91 | self.attention = attention 92 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 93 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 94 | 95 | if isinstance(moving_avg, list): 96 | self.decomp1 = series_decomp_multi(moving_avg) 97 | self.decomp2 = series_decomp_multi(moving_avg) 98 | else: 99 | self.decomp1 = series_decomp(moving_avg) 100 | self.decomp2 = series_decomp(moving_avg) 101 | 102 | self.dropout = nn.Dropout(dropout) 103 | self.activation = F.relu if activation == "relu" else F.gelu 104 | 105 | def forward(self, x, attn_mask=None): 106 | new_x, attn = self.attention( 107 | x, x, x, 108 | attn_mask=attn_mask 109 | ) 110 | x = x + self.dropout(new_x) 111 | x, _ = self.decomp1(x) 112 | y = x 113 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 114 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 115 | res, _ = self.decomp2(x + y) 116 | return res, attn 117 | 118 | 119 | class Encoder(nn.Module): 120 | """ 121 | Autoformer encoder 122 | """ 123 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 124 | super(Encoder, self).__init__() 125 | self.attn_layers = nn.ModuleList(attn_layers) 126 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 127 | self.norm = norm_layer 128 | 129 | def forward(self, x, attn_mask=None): 130 | attns = [] 131 | if self.conv_layers is not None: 132 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 133 | x, attn = attn_layer(x, attn_mask=attn_mask) 134 | x = conv_layer(x) 135 | attns.append(attn) 136 | x, attn = self.attn_layers[-1](x) 137 | attns.append(attn) 138 | else: 139 | for attn_layer in self.attn_layers: 140 | x, attn = attn_layer(x, attn_mask=attn_mask) 141 | attns.append(attn) 142 | 143 | if self.norm is not None: 144 | x = self.norm(x) 145 | 146 | return x, attns 147 | 148 | 149 | class DecoderLayer(nn.Module): 150 | """ 151 | Autoformer decoder layer with the progressive decomposition architecture 152 | """ 153 | def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None, 154 | moving_avg=25, dropout=0.1, activation="relu"): 155 | super(DecoderLayer, self).__init__() 156 | d_ff = d_ff or 4 * d_model 157 | self.self_attention = self_attention 158 | self.cross_attention = cross_attention 159 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) 160 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) 161 | 162 | if isinstance(moving_avg, list): 163 | self.decomp1 = series_decomp_multi(moving_avg) 164 | self.decomp2 = series_decomp_multi(moving_avg) 165 | self.decomp3 = series_decomp_multi(moving_avg) 166 | else: 167 | self.decomp1 = series_decomp(moving_avg) 168 | self.decomp2 = series_decomp(moving_avg) 169 | self.decomp3 = series_decomp(moving_avg) 170 | 171 | self.dropout = nn.Dropout(dropout) 172 | self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1, 173 | padding_mode='circular', bias=False) 174 | self.activation = F.relu if activation == "relu" else F.gelu 175 | 176 | def forward(self, x, cross, x_mask=None, cross_mask=None): 177 | x = x + self.dropout(self.self_attention( 178 | x, x, x, 179 | attn_mask=x_mask 180 | )[0]) 181 | 182 | x, trend1 = 
self.decomp1(x) 183 | x = x + self.dropout(self.cross_attention( 184 | x, cross, cross, 185 | attn_mask=cross_mask 186 | )[0]) 187 | 188 | x, trend2 = self.decomp2(x) 189 | y = x 190 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 191 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 192 | x, trend3 = self.decomp3(x + y) 193 | 194 | residual_trend = trend1 + trend2 + trend3 195 | residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) 196 | return x, residual_trend 197 | 198 | 199 | class Decoder(nn.Module): 200 | """ 201 | Autoformer encoder 202 | """ 203 | def __init__(self, layers, norm_layer=None, projection=None): 204 | super(Decoder, self).__init__() 205 | self.layers = nn.ModuleList(layers) 206 | self.norm = norm_layer 207 | self.projection = projection 208 | 209 | def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): 210 | for layer in self.layers: 211 | x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 212 | trend = trend + residual_trend 213 | 214 | if self.norm is not None: 215 | x = self.norm(x) 216 | 217 | if self.projection is not None: 218 | x = self.projection(x) 219 | return x, trend 220 | -------------------------------------------------------------------------------- /layers/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm 5 | import math 6 | 7 | 8 | class PositionalEmbedding(nn.Module): 9 | def __init__(self, d_model, max_len=5000): 10 | super(PositionalEmbedding, self).__init__() 11 | # Compute the positional encodings once in log space. 12 | pe = torch.zeros(max_len, d_model).float() 13 | pe.require_grad = False 14 | 15 | position = torch.arange(0, max_len).float().unsqueeze(1) 16 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 17 | 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | 21 | pe = pe.unsqueeze(0) 22 | self.register_buffer('pe', pe) 23 | 24 | def forward(self, x): 25 | return self.pe[:, :x.size(1)] 26 | 27 | 28 | class TokenEmbedding(nn.Module): 29 | def __init__(self, c_in, d_model): 30 | super(TokenEmbedding, self).__init__() 31 | padding = 1 if torch.__version__ >= '1.5.0' else 2 32 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 33 | kernel_size=3, padding=padding, padding_mode='circular', bias=False) 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv1d): 36 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') 37 | 38 | def forward(self, x): 39 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) 40 | return x 41 | 42 | 43 | class FixedEmbedding(nn.Module): 44 | def __init__(self, c_in, d_model): 45 | super(FixedEmbedding, self).__init__() 46 | 47 | w = torch.zeros(c_in, d_model).float() 48 | w.require_grad = False 49 | 50 | position = torch.arange(0, c_in).float().unsqueeze(1) 51 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 52 | 53 | w[:, 0::2] = torch.sin(position * div_term) 54 | w[:, 1::2] = torch.cos(position * div_term) 55 | 56 | self.emb = nn.Embedding(c_in, d_model) 57 | self.emb.weight = nn.Parameter(w, requires_grad=False) 58 | 59 | def forward(self, x): 60 | return self.emb(x).detach() 61 | 62 | 63 | class TemporalEmbedding(nn.Module): 64 | def __init__(self, d_model, embed_type='fixed', freq='h'): 
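        # One embedding table per calendar field: month, day, weekday, hour,
        # and minute (the latter only when freq == 't'). With embed_type='fixed'
        # the tables are frozen sinusoidal encodings; otherwise they are
        # learnable nn.Embedding layers. The forward pass sums the per-field embeddings.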
65 | super(TemporalEmbedding, self).__init__() 66 | 67 | minute_size = 4 68 | hour_size = 24 69 | weekday_size = 7 70 | day_size = 32 71 | month_size = 13 72 | 73 | Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding 74 | if freq == 't': 75 | self.minute_embed = Embed(minute_size, d_model) 76 | self.hour_embed = Embed(hour_size, d_model) 77 | self.weekday_embed = Embed(weekday_size, d_model) 78 | self.day_embed = Embed(day_size, d_model) 79 | self.month_embed = Embed(month_size, d_model) 80 | 81 | def forward(self, x): 82 | x = x.long() 83 | 84 | minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0. 85 | hour_x = self.hour_embed(x[:, :, 3]) 86 | weekday_x = self.weekday_embed(x[:, :, 2]) 87 | day_x = self.day_embed(x[:, :, 1]) 88 | month_x = self.month_embed(x[:, :, 0]) 89 | 90 | return hour_x + weekday_x + day_x + month_x + minute_x 91 | 92 | 93 | class TimeFeatureEmbedding(nn.Module): 94 | def __init__(self, d_model, embed_type='timeF', freq='h'): 95 | super(TimeFeatureEmbedding, self).__init__() 96 | 97 | freq_map = {'h': 7, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} 98 | d_inp = freq_map[freq] 99 | self.embed = nn.Linear(d_inp, d_model, bias=False) 100 | 101 | def forward(self, x): 102 | return self.embed(x) 103 | 104 | 105 | class DataEmbedding(nn.Module): 106 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 107 | super(DataEmbedding, self).__init__() 108 | 109 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 110 | self.position_embedding = PositionalEmbedding(d_model=d_model) 111 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 112 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 113 | d_model=d_model, embed_type=embed_type, freq=freq) 114 | self.dropout = nn.Dropout(p=dropout) 115 | 116 | def forward(self, x, x_mark): 117 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) 118 | return self.dropout(x) 119 | 120 | 121 | class DataEmbedding_wo_pos(nn.Module): 122 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 123 | super(DataEmbedding_wo_pos, self).__init__() 124 | 125 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 126 | self.position_embedding = PositionalEmbedding(d_model=d_model) 127 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 128 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 129 | d_model=d_model, embed_type=embed_type, freq=freq) 130 | self.dropout = nn.Dropout(p=dropout) 131 | 132 | def forward(self, x, x_mark): 133 | x = self.value_embedding(x) + self.temporal_embedding(x_mark) 134 | return self.dropout(x) 135 | 136 | class DataEmbedding_wo_pos_temp(nn.Module): 137 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 138 | super(DataEmbedding_wo_pos_temp, self).__init__() 139 | 140 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 141 | self.position_embedding = PositionalEmbedding(d_model=d_model) 142 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 143 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 144 | d_model=d_model, embed_type=embed_type, freq=freq) 145 | self.dropout = nn.Dropout(p=dropout) 146 | 147 | def forward(self, x, x_mark): 148 | x = self.value_embedding(x) 149 | return self.dropout(x) 150 | 151 | class DataEmbedding_wo_temp(nn.Module): 152 | def 
__init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 153 | super(DataEmbedding_wo_temp, self).__init__() 154 | 155 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 156 | self.position_embedding = PositionalEmbedding(d_model=d_model) 157 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, 158 | freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( 159 | d_model=d_model, embed_type=embed_type, freq=freq) 160 | self.dropout = nn.Dropout(p=dropout) 161 | 162 | def forward(self, x, x_mark): 163 | x = self.value_embedding(x) + self.position_embedding(x) 164 | return self.dropout(x) -------------------------------------------------------------------------------- /layers/FourierCorrelation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # author=maziqing 3 | # email=maziqing.mzq@alibaba-inc.com 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def get_frequency_modes(seq_len, modes=64, mode_select_method='random'): 11 | """ 12 | get modes on frequency domain: 13 | 'random' means sampling randomly; 14 | 'else' means sampling the lowest modes; 15 | """ 16 | modes = min(modes, seq_len//2) 17 | if mode_select_method == 'random': 18 | index = list(range(0, seq_len // 2)) 19 | np.random.shuffle(index) 20 | index = index[:modes] 21 | else: 22 | index = list(range(0, modes)) 23 | index.sort() 24 | return index 25 | 26 | 27 | # ########## fourier layer ############# 28 | class FourierBlock(nn.Module): 29 | def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'): 30 | super(FourierBlock, self).__init__() 31 | print('fourier enhanced block used!') 32 | """ 33 | 1D Fourier block. It performs representation learning on frequency domain, 34 | it does FFT, linear transform, and Inverse FFT. 35 | """ 36 | # get modes on frequency domain 37 | self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method) 38 | print('modes={}, index={}'.format(modes, self.index)) 39 | 40 | self.scale = (1 / (in_channels * out_channels)) 41 | self.weights1 = nn.Parameter( 42 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.cfloat)) 43 | 44 | # Complex multiplication 45 | def compl_mul1d(self, input, weights): 46 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 47 | return torch.einsum("bhi,hio->bho", input, weights) 48 | 49 | def forward(self, q, k, v, mask): 50 | # size = [B, L, H, E] 51 | B, L, H, E = q.shape 52 | x = q.permute(0, 2, 3, 1) 53 | # Compute Fourier coefficients 54 | x_ft = torch.fft.rfft(x, dim=-1) 55 | # Perform Fourier neural operations 56 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) 57 | for wi, i in enumerate(self.index): 58 | out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi]) 59 | # Return to time domain 60 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 61 | return (x, None) 62 | 63 | 64 | # ########## Fourier Cross Former #################### 65 | class FourierCrossAttention(nn.Module): 66 | def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random', 67 | activation='tanh', policy=0): 68 | super(FourierCrossAttention, self).__init__() 69 | print(' fourier enhanced cross attention used!') 70 | """ 71 | 1D Fourier Cross Attention layer. 
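    Queries and keys are restricted to the frequency modes selected by get_frequency_modes (random modes or the lowest-frequency ones), so attention is computed on a small set of Fourier coefficients.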
It does FFT, linear transform, attention mechanism and Inverse FFT. 72 | """ 73 | self.activation = activation 74 | self.in_channels = in_channels 75 | self.out_channels = out_channels 76 | # get modes for queries and keys (& values) on frequency domain 77 | self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method) 78 | self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method) 79 | 80 | print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q)) 81 | print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv)) 82 | 83 | self.scale = (1 / (in_channels * out_channels)) 84 | self.weights1 = nn.Parameter( 85 | self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat)) 86 | 87 | # Complex multiplication 88 | def compl_mul1d(self, input, weights): 89 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 90 | return torch.einsum("bhi,hio->bho", input, weights) 91 | 92 | def forward(self, q, k, v, mask): 93 | # size = [B, L, H, E] 94 | B, L, H, E = q.shape 95 | xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] 96 | xk = k.permute(0, 2, 3, 1) 97 | xv = v.permute(0, 2, 3, 1) 98 | 99 | # Compute Fourier coefficients 100 | xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) 101 | xq_ft = torch.fft.rfft(xq, dim=-1) 102 | for i, j in enumerate(self.index_q): 103 | xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] 104 | xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat) 105 | xk_ft = torch.fft.rfft(xk, dim=-1) 106 | for i, j in enumerate(self.index_kv): 107 | xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] 108 | 109 | # perform attention mechanism on frequency domain 110 | xqk_ft = (torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)) 111 | if self.activation == 'tanh': 112 | xqk_ft = xqk_ft.tanh() 113 | elif self.activation == 'softmax': 114 | xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) 115 | xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) 116 | else: 117 | raise Exception('{} actiation function is not implemented'.format(self.activation)) 118 | xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_) 119 | xqkvw = torch.einsum("bhex,heox->bhox", xqkv_ft, self.weights1) 120 | out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) 121 | for i, j in enumerate(self.index_q): 122 | out_ft[:, :, :, j] = xqkvw[:, :, :, i] 123 | # Return to time domain 124 | out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) 125 | return (out, None) 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /layers/PatchTST_layers.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding'] 2 | 3 | import torch 4 | from torch import nn 5 | import math 6 | 7 | class Transpose(nn.Module): 8 | def __init__(self, *dims, contiguous=False): 9 | super().__init__() 10 | self.dims, self.contiguous = dims, contiguous 11 | def forward(self, x): 12 | if self.contiguous: return x.transpose(*self.dims).contiguous() 13 | else: return x.transpose(*self.dims) 14 | 15 | 16 | def get_activation_fn(activation): 17 | if callable(activation): return activation() 18 | elif activation.lower() == 
"relu": return nn.ReLU() 19 | elif activation.lower() == "gelu": return nn.GELU() 20 | raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable') 21 | 22 | 23 | # decomposition 24 | 25 | class moving_avg(nn.Module): 26 | """ 27 | Moving average block to highlight the trend of time series 28 | """ 29 | def __init__(self, kernel_size, stride): 30 | super(moving_avg, self).__init__() 31 | self.kernel_size = kernel_size 32 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 33 | 34 | def forward(self, x): 35 | # padding on the both ends of time series 36 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 37 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 38 | x = torch.cat([front, x, end], dim=1) 39 | x = self.avg(x.permute(0, 2, 1)) 40 | x = x.permute(0, 2, 1) 41 | return x 42 | 43 | 44 | class series_decomp(nn.Module): 45 | """ 46 | Series decomposition block 47 | """ 48 | def __init__(self, kernel_size): 49 | super(series_decomp, self).__init__() 50 | self.moving_avg = moving_avg(kernel_size, stride=1) 51 | 52 | def forward(self, x): 53 | moving_mean = self.moving_avg(x) 54 | res = x - moving_mean 55 | return res, moving_mean 56 | 57 | 58 | 59 | # pos_encoding 60 | 61 | def PositionalEncoding(q_len, d_model, normalize=True): 62 | pe = torch.zeros(q_len, d_model) 63 | position = torch.arange(0, q_len).unsqueeze(1) 64 | div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) 65 | pe[:, 0::2] = torch.sin(position * div_term) 66 | pe[:, 1::2] = torch.cos(position * div_term) 67 | if normalize: 68 | pe = pe - pe.mean() 69 | pe = pe / (pe.std() * 10) 70 | return pe 71 | 72 | SinCosPosEncoding = PositionalEncoding 73 | 74 | def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False): 75 | x = .5 if exponential else 1 76 | i = 0 77 | for i in range(100): 78 | cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1 79 | pv(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}', verbose) 80 | if abs(cpe.mean()) <= eps: break 81 | elif cpe.mean() > eps: x += .001 82 | else: x -= .001 83 | i += 1 84 | if normalize: 85 | cpe = cpe - cpe.mean() 86 | cpe = cpe / (cpe.std() * 10) 87 | return cpe 88 | 89 | def Coord1dPosEncoding(q_len, exponential=False, normalize=True): 90 | cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1) 91 | if normalize: 92 | cpe = cpe - cpe.mean() 93 | cpe = cpe / (cpe.std() * 10) 94 | return cpe 95 | 96 | def positional_encoding(pe, learn_pe, q_len, d_model): 97 | # Positional encoding 98 | if pe == None: 99 | W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe 100 | nn.init.uniform_(W_pos, -0.02, 0.02) 101 | learn_pe = False 102 | elif pe == 'zero': 103 | W_pos = torch.empty((q_len, 1)) 104 | nn.init.uniform_(W_pos, -0.02, 0.02) 105 | elif pe == 'zeros': 106 | W_pos = torch.empty((q_len, d_model)) 107 | nn.init.uniform_(W_pos, -0.02, 0.02) 108 | elif pe == 'normal' or pe == 'gauss': 109 | W_pos = torch.zeros((q_len, 1)) 110 | torch.nn.init.normal_(W_pos, mean=0.0, std=0.1) 111 | elif pe == 'uniform': 112 | W_pos = torch.zeros((q_len, 1)) 113 | nn.init.uniform_(W_pos, a=0.0, b=0.1) 114 | elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True) 115 | elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True) 116 | elif pe == 'lin2d': 
W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True) 117 | elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True) 118 | elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True) 119 | else: raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \ 120 | 'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)") 121 | return nn.Parameter(W_pos, requires_grad=learn_pe) -------------------------------------------------------------------------------- /layers/RevIN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RevIN(nn.Module): 5 | def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False, last_n=1): 6 | """ 7 | :param num_features: the number of features or channels 8 | :param eps: a value added for numerical stability 9 | :param affine: if True, RevIN has learnable affine parameters 10 | """ 11 | super(RevIN, self).__init__() 12 | self.num_features = num_features 13 | self.eps = eps 14 | self.affine = affine 15 | self.subtract_last = subtract_last 16 | self.last_n = last_n 17 | if self.affine: 18 | self._init_params() 19 | 20 | def forward(self, x, mode:str): 21 | if mode == 'norm': 22 | self._get_statistics(x) 23 | x = self._normalize(x) 24 | elif mode == 'denorm': 25 | x = self._denormalize(x) 26 | else: raise NotImplementedError 27 | return x 28 | 29 | def _init_params(self): 30 | # initialize RevIN params: (C,) 31 | self.affine_weight = nn.Parameter(torch.ones(self.num_features)) 32 | self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) 33 | 34 | def _get_statistics(self, x): 35 | dim2reduce = tuple(range(1, x.ndim-1)) 36 | if self.subtract_last: 37 | last_n = min(self.last_n, x.shape[1]) 38 | self.last = x[:,-last_n:,:] 39 | if last_n == 1: 40 | self.last.unsqueeze(1) 41 | else: 42 | self.last = torch.mean(self.last, dim=dim2reduce, keepdim=True).detach() 43 | else: 44 | self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() 45 | self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() 46 | 47 | def _normalize(self, x): 48 | if self.subtract_last: 49 | x = x - self.last 50 | else: 51 | x = x - self.mean 52 | x = x / self.stdev 53 | if self.affine: 54 | x = x * self.affine_weight 55 | x = x + self.affine_bias 56 | return x 57 | 58 | def _denormalize(self, x): 59 | if self.affine: 60 | x = x - self.affine_bias 61 | x = x / (self.affine_weight + self.eps*self.eps) 62 | x = x * self.stdev 63 | if self.subtract_last: 64 | x = x + self.last 65 | else: 66 | x = x + self.mean 67 | return x -------------------------------------------------------------------------------- /layers/SelfAttention_Family.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import matplotlib.pyplot as plt 6 | 7 | import numpy as np 8 | import math 9 | from math import sqrt 10 | from utils.masking import TriangularCausalMask, ProbMask 11 | import os 12 | 13 | 14 | class FullAttention(nn.Module): 15 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 16 | super(FullAttention, self).__init__() 17 | self.scale = scale 18 | self.mask_flag = mask_flag 19 | self.output_attention = output_attention 20 | self.dropout = 
nn.Dropout(attention_dropout) 21 | 22 | def forward(self, queries, keys, values, attn_mask): 23 | B, L, H, E = queries.shape 24 | _, S, _, D = values.shape 25 | scale = self.scale or 1. / sqrt(E) 26 | 27 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 28 | 29 | if self.mask_flag: 30 | if attn_mask is None: 31 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 32 | 33 | scores.masked_fill_(attn_mask.mask, -np.inf) 34 | 35 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 36 | V = torch.einsum("bhls,bshd->blhd", A, values) 37 | 38 | if self.output_attention: 39 | return (V.contiguous(), A) 40 | else: 41 | return (V.contiguous(), None) 42 | 43 | 44 | class ProbAttention(nn.Module): 45 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 46 | super(ProbAttention, self).__init__() 47 | self.factor = factor 48 | self.scale = scale 49 | self.mask_flag = mask_flag 50 | self.output_attention = output_attention 51 | self.dropout = nn.Dropout(attention_dropout) 52 | 53 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 54 | # Q [B, H, L, D] 55 | B, H, L_K, E = K.shape 56 | _, _, L_Q, _ = Q.shape 57 | 58 | # calculate the sampled Q_K 59 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 60 | index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q 61 | K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] 62 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 63 | 64 | # find the Top_k query with sparisty measurement 65 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 66 | M_top = M.topk(n_top, sorted=False)[1] 67 | 68 | # use the reduced Q to calculate Q_K 69 | Q_reduce = Q[torch.arange(B)[:, None, None], 70 | torch.arange(H)[None, :, None], 71 | M_top, :] # factor*ln(L_q) 72 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 73 | 74 | return Q_K, M_top 75 | 76 | def _get_initial_context(self, V, L_Q): 77 | B, H, L_V, D = V.shape 78 | if not self.mask_flag: 79 | # V_sum = V.sum(dim=-2) 80 | V_sum = V.mean(dim=-2) 81 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() 82 | else: # use mask 83 | assert (L_Q == L_V) # requires that L_Q == L_V, i.e. 
for self-attention only 84 | contex = V.cumsum(dim=-2) 85 | return contex 86 | 87 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 88 | B, H, L_V, D = V.shape 89 | 90 | if self.mask_flag: 91 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 92 | scores.masked_fill_(attn_mask.mask, -np.inf) 93 | 94 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 95 | 96 | context_in[torch.arange(B)[:, None, None], 97 | torch.arange(H)[None, :, None], 98 | index, :] = torch.matmul(attn, V).type_as(context_in) 99 | if self.output_attention: 100 | attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) 101 | attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn 102 | return (context_in, attns) 103 | else: 104 | return (context_in, None) 105 | 106 | def forward(self, queries, keys, values, attn_mask): 107 | B, L_Q, H, D = queries.shape 108 | _, L_K, _, _ = keys.shape 109 | 110 | queries = queries.transpose(2, 1) 111 | keys = keys.transpose(2, 1) 112 | values = values.transpose(2, 1) 113 | 114 | U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 115 | u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 116 | 117 | U_part = U_part if U_part < L_K else L_K 118 | u = u if u < L_Q else L_Q 119 | 120 | scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 121 | 122 | # add scale factor 123 | scale = self.scale or 1. / sqrt(D) 124 | if scale is not None: 125 | scores_top = scores_top * scale 126 | # get the context 127 | context = self._get_initial_context(values, L_Q) 128 | # update the context with selected top_k queries 129 | context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) 130 | 131 | return context.contiguous(), attn 132 | 133 | 134 | class AttentionLayer(nn.Module): 135 | def __init__(self, attention, d_model, n_heads, d_keys=None, 136 | d_values=None): 137 | super(AttentionLayer, self).__init__() 138 | 139 | d_keys = d_keys or (d_model // n_heads) 140 | d_values = d_values or (d_model // n_heads) 141 | 142 | self.inner_attention = attention 143 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 144 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 145 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 146 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 147 | self.n_heads = n_heads 148 | 149 | def forward(self, queries, keys, values, attn_mask): 150 | B, L, _ = queries.shape 151 | _, S, _ = keys.shape 152 | H = self.n_heads 153 | 154 | queries = self.query_projection(queries).view(B, L, H, -1) 155 | keys = self.key_projection(keys).view(B, S, H, -1) 156 | values = self.value_projection(values).view(B, S, H, -1) 157 | 158 | out, attn = self.inner_attention( 159 | queries, 160 | keys, 161 | values, 162 | attn_mask 163 | ) 164 | out = out.view(B, L, -1) 165 | 166 | return self.out_projection(out), attn -------------------------------------------------------------------------------- /layers/Transformer_EncDec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvLayer(nn.Module): 7 | def __init__(self, c_in): 8 | super(ConvLayer, self).__init__() 9 | self.downConv = nn.Conv1d(in_channels=c_in, 10 | out_channels=c_in, 11 | kernel_size=3, 12 | padding=2, 13 | padding_mode='circular') 14 | self.norm = nn.BatchNorm1d(c_in) 
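# NOTE (editor annotation, not part of the original file): this Conv1d -> BatchNorm1d -> ELU ->
# MaxPool1d stack is the Informer-style "distilling" block inserted between encoder attention
# layers (see the `configs.distil` branch in models/Informer.py); the stride-2 max pooling
# roughly halves the sequence length at each stage.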
15 | self.activation = nn.ELU() 16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 17 | 18 | def forward(self, x): 19 | x = self.downConv(x.permute(0, 2, 1)) 20 | x = self.norm(x) 21 | x = self.activation(x) 22 | x = self.maxPool(x) 23 | x = x.transpose(1, 2) 24 | return x 25 | 26 | 27 | class EncoderLayer(nn.Module): 28 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 29 | super(EncoderLayer, self).__init__() 30 | d_ff = d_ff or 4 * d_model 31 | self.attention = attention 32 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 33 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 34 | self.norm1 = nn.LayerNorm(d_model) 35 | self.norm2 = nn.LayerNorm(d_model) 36 | self.dropout = nn.Dropout(dropout) 37 | self.activation = F.relu if activation == "relu" else F.gelu 38 | 39 | def forward(self, x, attn_mask=None): 40 | new_x, attn = self.attention( 41 | x, x, x, 42 | attn_mask=attn_mask 43 | ) 44 | x = x + self.dropout(new_x) 45 | 46 | y = x = self.norm1(x) 47 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 48 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 49 | 50 | return self.norm2(x + y), attn 51 | 52 | 53 | class Encoder(nn.Module): 54 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 55 | super(Encoder, self).__init__() 56 | self.attn_layers = nn.ModuleList(attn_layers) 57 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None 58 | self.norm = norm_layer 59 | 60 | def forward(self, x, attn_mask=None): 61 | # x [B, L, D] 62 | attns = [] 63 | if self.conv_layers is not None: 64 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 65 | x, attn = attn_layer(x, attn_mask=attn_mask) 66 | x = conv_layer(x) 67 | attns.append(attn) 68 | x, attn = self.attn_layers[-1](x) 69 | attns.append(attn) 70 | else: 71 | for attn_layer in self.attn_layers: 72 | x, attn = attn_layer(x, attn_mask=attn_mask) 73 | attns.append(attn) 74 | 75 | if self.norm is not None: 76 | x = self.norm(x) 77 | 78 | return x, attns 79 | 80 | 81 | class DecoderLayer(nn.Module): 82 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 83 | dropout=0.1, activation="relu"): 84 | super(DecoderLayer, self).__init__() 85 | d_ff = d_ff or 4 * d_model 86 | self.self_attention = self_attention 87 | self.cross_attention = cross_attention 88 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 89 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 90 | self.norm1 = nn.LayerNorm(d_model) 91 | self.norm2 = nn.LayerNorm(d_model) 92 | self.norm3 = nn.LayerNorm(d_model) 93 | self.dropout = nn.Dropout(dropout) 94 | self.activation = F.relu if activation == "relu" else F.gelu 95 | 96 | def forward(self, x, cross, x_mask=None, cross_mask=None): 97 | x = x + self.dropout(self.self_attention( 98 | x, x, x, 99 | attn_mask=x_mask 100 | )[0]) 101 | x = self.norm1(x) 102 | 103 | x = x + self.dropout(self.cross_attention( 104 | x, cross, cross, 105 | attn_mask=cross_mask 106 | )[0]) 107 | 108 | y = x = self.norm2(x) 109 | y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) 110 | y = self.dropout(self.conv2(y).transpose(-1, 1)) 111 | 112 | return self.norm3(x + y) 113 | 114 | 115 | class Decoder(nn.Module): 116 | def __init__(self, layers, norm_layer=None, projection=None): 117 | super(Decoder, self).__init__() 118 | self.layers = nn.ModuleList(layers) 119 | self.norm = 
norm_layer 120 | self.projection = projection 121 | 122 | def forward(self, x, cross, x_mask=None, cross_mask=None): 123 | for layer in self.layers: 124 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 125 | 126 | if self.norm is not None: 127 | x = self.norm(x) 128 | 129 | if self.projection is not None: 130 | x = self.projection(x) 131 | return x -------------------------------------------------------------------------------- /models/Autoformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # author=maziqing 3 | # email=maziqing.mzq@alibaba-inc.com 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from layers.Embed import DataEmbedding, DataEmbedding_wo_pos 10 | from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer 11 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp 12 | 13 | 14 | class Model(nn.Module): 15 | """ 16 | Autoformer is the first method to achieve the series-wise connection, 17 | with inherent O(LlogL) complexity 18 | """ 19 | def __init__(self, configs): 20 | super(Model, self).__init__() 21 | self.seq_len = configs.seq_len 22 | self.label_len = configs.label_len 23 | self.pred_len = configs.pred_len 24 | self.output_attention = configs.output_attention 25 | 26 | # Decomp 27 | kernel_size = configs.moving_avg 28 | self.decomp = series_decomp(kernel_size) 29 | 30 | # Embedding 31 | # The series-wise connection inherently contains the sequential information. 32 | # Thus, we can discard the position embedding of transformers. 33 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 34 | configs.dropout) 35 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 36 | configs.dropout) 37 | 38 | # Encoder 39 | self.encoder = Encoder( 40 | [ 41 | EncoderLayer( 42 | AutoCorrelationLayer( 43 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 44 | output_attention=configs.output_attention), 45 | configs.d_model, configs.n_heads), 46 | configs.d_model, 47 | configs.d_ff, 48 | moving_avg=configs.moving_avg, 49 | dropout=configs.dropout, 50 | activation=configs.activation 51 | ) for l in range(configs.e_layers) 52 | ], 53 | norm_layer=my_Layernorm(configs.d_model) 54 | ) 55 | # Decoder 56 | self.decoder = Decoder( 57 | [ 58 | DecoderLayer( 59 | AutoCorrelationLayer( 60 | AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout, 61 | output_attention=False), 62 | configs.d_model, configs.n_heads), 63 | AutoCorrelationLayer( 64 | AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, 65 | output_attention=False), 66 | configs.d_model, configs.n_heads), 67 | configs.d_model, 68 | configs.c_out, 69 | configs.d_ff, 70 | moving_avg=configs.moving_avg, 71 | dropout=configs.dropout, 72 | activation=configs.activation, 73 | ) 74 | for l in range(configs.d_layers) 75 | ], 76 | norm_layer=my_Layernorm(configs.d_model), 77 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 78 | ) 79 | 80 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 81 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 82 | # decomp init 83 | mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) 84 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device) 85 | seasonal_init, trend_init = 
self.decomp(x_enc) 86 | # decoder input 87 | trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) 88 | seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1) 89 | # enc 90 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 91 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 92 | # dec 93 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 94 | seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, 95 | trend=trend_init) 96 | # final 97 | dec_out = trend_part + seasonal_part 98 | 99 | if self.output_attention: 100 | return dec_out[:, -self.pred_len:, :], attns 101 | else: 102 | return dec_out[:, -self.pred_len:, :] -------------------------------------------------------------------------------- /models/DLinear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class moving_avg(nn.Module): 7 | """ 8 | Moving average block to highlight the trend of time series 9 | """ 10 | def __init__(self, kernel_size, stride): 11 | super(moving_avg, self).__init__() 12 | self.kernel_size = kernel_size 13 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 14 | 15 | def forward(self, x): 16 | # padding on the both ends of time series 17 | front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) 18 | end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) 19 | x = torch.cat([front, x, end], dim=1) 20 | x = self.avg(x.permute(0, 2, 1)) 21 | x = x.permute(0, 2, 1) 22 | return x 23 | 24 | 25 | class series_decomp(nn.Module): 26 | """ 27 | Series decomposition block 28 | """ 29 | def __init__(self, kernel_size): 30 | super(series_decomp, self).__init__() 31 | self.moving_avg = moving_avg(kernel_size, stride=1) 32 | 33 | def forward(self, x): 34 | moving_mean = self.moving_avg(x) 35 | res = x - moving_mean 36 | return res, moving_mean 37 | 38 | class Model(nn.Module): 39 | """ 40 | Decomposition-Linear 41 | """ 42 | def __init__(self, configs): 43 | super(Model, self).__init__() 44 | self.seq_len = configs.seq_len 45 | self.pred_len = configs.pred_len 46 | 47 | # Decompsition Kernel Size 48 | kernel_size = 25 49 | self.decompsition = series_decomp(kernel_size) 50 | self.individual = configs.individual 51 | self.channels = configs.enc_in 52 | 53 | if self.individual: 54 | self.Linear_Seasonal = nn.ModuleList() 55 | self.Linear_Trend = nn.ModuleList() 56 | 57 | for i in range(self.channels): 58 | self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len)) 59 | self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len)) 60 | 61 | # Use this two lines if you want to visualize the weights 62 | # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 63 | # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 64 | else: 65 | self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len) 66 | self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len) 67 | 68 | # Use this two lines if you want to visualize the weights 69 | # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 70 | # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) 71 | 72 | def forward(self, x): 73 | # x: [Batch, Input length, Channel] 74 | seasonal_init, 
trend_init = self.decompsition(x) 75 | seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1) 76 | if self.individual: 77 | seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device) 78 | trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device) 79 | for i in range(self.channels): 80 | seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:]) 81 | trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:]) 82 | else: 83 | seasonal_output = self.Linear_Seasonal(seasonal_init) 84 | trend_output = self.Linear_Trend(trend_init) 85 | 86 | x = seasonal_output + trend_output 87 | return x.permute(0,2,1) # to [Batch, Output length, Channel] -------------------------------------------------------------------------------- /models/FEDformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layers.Embed import DataEmbedding, DataEmbedding_wo_pos 5 | from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer 6 | from layers.FourierCorrelation import FourierBlock, FourierCrossAttention 7 | from layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform 8 | from layers.SelfAttention_Family import FullAttention, ProbAttention 9 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp, series_decomp_multi 10 | import math 11 | import numpy as np 12 | 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class Model(nn.Module): 18 | """ 19 | FEDformer performs the attention mechanism on frequency domain and achieved O(N) complexity 20 | """ 21 | def __init__(self, configs): 22 | super(Model, self).__init__() 23 | self.version = configs.version 24 | self.mode_select = configs.mode_select 25 | self.modes = configs.modes 26 | self.seq_len = configs.seq_len 27 | self.label_len = configs.label_len 28 | self.pred_len = configs.pred_len 29 | self.output_attention = configs.output_attention 30 | 31 | # Decomp 32 | kernel_size = configs.moving_avg 33 | if isinstance(kernel_size, list): 34 | self.decomp = series_decomp_multi(kernel_size) 35 | else: 36 | self.decomp = series_decomp(kernel_size) 37 | 38 | # Embedding 39 | # The series-wise connection inherently contains the sequential information. 40 | # Thus, we can discard the position embedding of transformers. 
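# NOTE (editor annotation, not part of the original file): DataEmbedding_wo_pos, as in the
# Autoformer/FEDformer reference implementations, sums only the value (token) embedding and the
# temporal (time-feature) embedding -- no positional-encoding term is added, which is what the
# comment above refers to.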
41 | self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, 42 | configs.dropout) 43 | self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, 44 | configs.dropout) 45 | 46 | if configs.version == 'Wavelets': 47 | encoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=configs.L, base=configs.base) 48 | decoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=configs.L, base=configs.base) 49 | decoder_cross_att = MultiWaveletCross(in_channels=configs.d_model, 50 | out_channels=configs.d_model, 51 | seq_len_q=self.seq_len // 2 + self.pred_len, 52 | seq_len_kv=self.seq_len, 53 | modes=configs.modes, 54 | ich=configs.d_model, 55 | base=configs.base, 56 | activation=configs.cross_activation) 57 | else: 58 | encoder_self_att = FourierBlock(in_channels=configs.d_model, 59 | out_channels=configs.d_model, 60 | seq_len=self.seq_len, 61 | modes=configs.modes, 62 | mode_select_method=configs.mode_select) 63 | decoder_self_att = FourierBlock(in_channels=configs.d_model, 64 | out_channels=configs.d_model, 65 | seq_len=self.seq_len//2+self.pred_len, 66 | modes=configs.modes, 67 | mode_select_method=configs.mode_select) 68 | decoder_cross_att = FourierCrossAttention(in_channels=configs.d_model, 69 | out_channels=configs.d_model, 70 | seq_len_q=self.seq_len//2+self.pred_len, 71 | seq_len_kv=self.seq_len, 72 | modes=configs.modes, 73 | mode_select_method=configs.mode_select) 74 | # Encoder 75 | enc_modes = int(min(configs.modes, configs.seq_len//2)) 76 | dec_modes = int(min(configs.modes, (configs.seq_len//2+configs.pred_len)//2)) 77 | print('enc_modes: {}, dec_modes: {}'.format(enc_modes, dec_modes)) 78 | 79 | self.encoder = Encoder( 80 | [ 81 | EncoderLayer( 82 | AutoCorrelationLayer( 83 | encoder_self_att, 84 | configs.d_model, configs.n_heads), 85 | 86 | configs.d_model, 87 | configs.d_ff, 88 | moving_avg=configs.moving_avg, 89 | dropout=configs.dropout, 90 | activation=configs.activation 91 | ) for l in range(configs.e_layers) 92 | ], 93 | norm_layer=my_Layernorm(configs.d_model) 94 | ) 95 | # Decoder 96 | self.decoder = Decoder( 97 | [ 98 | DecoderLayer( 99 | AutoCorrelationLayer( 100 | decoder_self_att, 101 | configs.d_model, configs.n_heads), 102 | AutoCorrelationLayer( 103 | decoder_cross_att, 104 | configs.d_model, configs.n_heads), 105 | configs.d_model, 106 | configs.c_out, 107 | configs.d_ff, 108 | moving_avg=configs.moving_avg, 109 | dropout=configs.dropout, 110 | activation=configs.activation, 111 | ) 112 | for l in range(configs.d_layers) 113 | ], 114 | norm_layer=my_Layernorm(configs.d_model), 115 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 116 | ) 117 | 118 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 119 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 120 | # decomp init 121 | mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) 122 | zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]]).to(device) # cuda() 123 | seasonal_init, trend_init = self.decomp(x_enc) # [bsz, l, c] 124 | # decoder input 125 | trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) 126 | seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len)) 127 | # enc 128 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 129 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 130 | # dec 131 | dec_out = self.dec_embedding(seasonal_init, x_mark_dec) 132 | seasonal_part, 
trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, 133 | trend=trend_init) 134 | # final 135 | dec_out = trend_part + seasonal_part 136 | 137 | if self.output_attention: 138 | return dec_out[:, -self.pred_len:, :], attns 139 | else: 140 | return dec_out[:, -self.pred_len:, :] # [B, L, D] 141 | 142 | 143 | if __name__ == '__main__': 144 | class Configs(object): 145 | ab = 0 146 | modes = 32 147 | mode_select = 'random' 148 | # version = 'Fourier' 149 | version = 'Wavelets' 150 | moving_avg = [12, 24] 151 | L = 1 152 | base = 'legendre' 153 | cross_activation = 'tanh' 154 | seq_len = 96 155 | label_len = 48 156 | pred_len = 96 157 | output_attention = True 158 | enc_in = 7 159 | dec_in = 7 160 | d_model = 16 161 | embed = 'timeF' 162 | dropout = 0.05 163 | freq = 'h' 164 | factor = 1 165 | n_heads = 8 166 | d_ff = 16 167 | e_layers = 2 168 | d_layers = 1 169 | c_out = 7 170 | activation = 'gelu' 171 | wavelet = 0 172 | 173 | configs = Configs() 174 | model = Model(configs) 175 | 176 | print('parameter number is {}'.format(sum(p.numel() for p in model.parameters()))) 177 | enc = torch.randn([3, configs.seq_len, 7]) 178 | enc_mark = torch.randn([3, configs.seq_len, 4]) 179 | 180 | dec = torch.randn([3, configs.seq_len//2+configs.pred_len, 7]) 181 | dec_mark = torch.randn([3, configs.seq_len//2+configs.pred_len, 4]) 182 | out = model.forward(enc, enc_mark, dec, dec_mark) 183 | print(out) 184 | -------------------------------------------------------------------------------- /models/Informer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils.masking import TriangularCausalMask, ProbMask 5 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer 6 | from layers.SelfAttention_Family import FullAttention, ProbAttention, AttentionLayer 7 | from layers.Embed import DataEmbedding 8 | import numpy as np 9 | 10 | 11 | class Model(nn.Module): 12 | """ 13 | Informer with Propspare attention in O(LlogL) complexity 14 | """ 15 | def __init__(self, configs): 16 | super(Model, self).__init__() 17 | self.pred_len = configs.pred_len 18 | self.output_attention = configs.output_attention 19 | 20 | # Embedding 21 | self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, 22 | configs.dropout) 23 | self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, 24 | configs.dropout) 25 | 26 | # Encoder 27 | self.encoder = Encoder( 28 | [ 29 | EncoderLayer( 30 | AttentionLayer( 31 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, 32 | output_attention=configs.output_attention), 33 | configs.d_model, configs.n_heads), 34 | configs.d_model, 35 | configs.d_ff, 36 | dropout=configs.dropout, 37 | activation=configs.activation 38 | ) for l in range(configs.e_layers) 39 | ], 40 | [ 41 | ConvLayer( 42 | configs.d_model 43 | ) for l in range(configs.e_layers - 1) 44 | ] if configs.distil else None, 45 | norm_layer=torch.nn.LayerNorm(configs.d_model) 46 | ) 47 | # Decoder 48 | self.decoder = Decoder( 49 | [ 50 | DecoderLayer( 51 | AttentionLayer( 52 | ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), 53 | configs.d_model, configs.n_heads), 54 | AttentionLayer( 55 | ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), 56 | configs.d_model, 
configs.n_heads), 57 | configs.d_model, 58 | configs.d_ff, 59 | dropout=configs.dropout, 60 | activation=configs.activation, 61 | ) 62 | for l in range(configs.d_layers) 63 | ], 64 | norm_layer=torch.nn.LayerNorm(configs.d_model), 65 | projection=nn.Linear(configs.d_model, configs.c_out, bias=True) 66 | ) 67 | 68 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 69 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 70 | 71 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 72 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 73 | 74 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 75 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) 76 | 77 | if self.output_attention: 78 | return dec_out[:, -self.pred_len:, :], attns 79 | else: 80 | return dec_out[:, -self.pred_len:, :] 81 | -------------------------------------------------------------------------------- /models/PatchTST.py: -------------------------------------------------------------------------------- 1 | __all__ = ['PatchTST'] 2 | 3 | # Cell 4 | from typing import Callable, Optional 5 | import torch 6 | from torch import nn 7 | from torch import Tensor 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | from layers.PatchTST_backbone import PatchTST_backbone 12 | from layers.PatchTST_layers import series_decomp 13 | 14 | 15 | class Model(nn.Module): 16 | def __init__(self, configs, max_seq_len:Optional[int]=1024, d_k:Optional[int]=None, d_v:Optional[int]=None, norm:str='BatchNorm', attn_dropout:float=0., 17 | act:str="gelu", key_padding_mask:bool='auto',padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, 18 | pre_norm:bool=False, store_attn:bool=False, pe:str='zeros', learn_pe:bool=True, pretrain_head:bool=False, head_type = 'flatten', verbose:bool=False, **kwargs): 19 | 20 | super().__init__() 21 | 22 | # load parameters 23 | c_in = configs.enc_in 24 | context_window = configs.seq_len 25 | target_window = configs.pred_len 26 | 27 | n_layers = configs.e_layers 28 | n_heads = configs.n_heads 29 | d_model = configs.d_model 30 | d_ff = configs.d_ff 31 | dropout = configs.dropout 32 | fc_dropout = configs.fc_dropout 33 | head_dropout = configs.head_dropout 34 | 35 | individual = configs.individual 36 | 37 | patch_len = configs.patch_len 38 | stride = configs.stride 39 | padding_patch = configs.padding_patch 40 | 41 | revin = configs.revin 42 | affine = configs.affine 43 | subtract_last = configs.subtract_last 44 | 45 | decomposition = configs.decomposition 46 | kernel_size = configs.kernel_size 47 | 48 | 49 | # model 50 | self.decomposition = decomposition 51 | if self.decomposition: 52 | self.decomp_module = series_decomp(kernel_size) 53 | self.model_trend = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, 54 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, 55 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, 56 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, 57 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, 58 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, 59 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, 60 | subtract_last=subtract_last, verbose=verbose, **kwargs) 61 | self.model_res 
= PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, 62 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, 63 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, 64 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, 65 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, 66 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, 67 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, 68 | subtract_last=subtract_last, verbose=verbose, **kwargs) 69 | else: 70 | self.model = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, 71 | max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, 72 | n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, 73 | dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, 74 | attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, 75 | pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, 76 | pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, 77 | subtract_last=subtract_last, verbose=verbose, **kwargs) 78 | 79 | 80 | def forward(self, x): # x: [Batch, Input length, Channel] 81 | if self.decomposition: 82 | res_init, trend_init = self.decomp_module(x) 83 | res_init, trend_init = res_init.permute(0,2,1), trend_init.permute(0,2,1) # x: [Batch, Channel, Input length] 84 | res = self.model_res(res_init) 85 | trend = self.model_trend(trend_init) 86 | x = res + trend 87 | x = x.permute(0,2,1) # x: [Batch, Input length, Channel] 88 | else: 89 | x = x.permute(0,2,1) # x: [Batch, Channel, Input length] 90 | x = self.model(x) 91 | x = x.permute(0,2,1) # x: [Batch, Input length, Channel] 92 | return x -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzhang114/OneNet/076188ad4c1b28cbda8cb4b1e231499b233e1437/models/__init__.py -------------------------------------------------------------------------------- /models/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | from math import sqrt 8 | from utils.masking import TriangularCausalMask, ProbMask 9 | 10 | class FullAttention(nn.Module): 11 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 12 | super(FullAttention, self).__init__() 13 | self.scale = scale 14 | self.mask_flag = mask_flag 15 | self.output_attention = output_attention 16 | self.dropout = nn.Dropout(attention_dropout) 17 | 18 | def forward(self, queries, keys, values, attn_mask): 19 | B, L, H, E = queries.shape 20 | _, S, _, D = values.shape 21 | scale = self.scale or 1./sqrt(E) 22 | 23 | scores = torch.einsum("blhe,bshe->bhls", queries, keys) 24 | if self.mask_flag: 25 | if attn_mask is None: 26 | attn_mask = TriangularCausalMask(B, L, device=queries.device) 27 | 28 | scores.masked_fill_(attn_mask.mask, -np.inf) 29 | 
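# NOTE (editor annotation, not part of the original file): the next two lines are standard
# scaled dot-product attention -- softmax over the key dimension of (scale * Q K^T), with
# scale defaulting to 1/sqrt(E), followed by a weighted sum of the values.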
30 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 31 | V = torch.einsum("bhls,bshd->blhd", A, values) 32 | 33 | if self.output_attention: 34 | return (V.contiguous(), A) 35 | else: 36 | return (V.contiguous(), None) 37 | 38 | class ProbAttention(nn.Module): 39 | def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): 40 | super(ProbAttention, self).__init__() 41 | self.factor = factor 42 | self.scale = scale 43 | self.mask_flag = mask_flag 44 | self.output_attention = output_attention 45 | self.dropout = nn.Dropout(attention_dropout) 46 | 47 | def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) 48 | # Q [B, H, L, D] 49 | B, H, L_K, E = K.shape 50 | _, _, L_Q, _ = Q.shape 51 | 52 | # calculate the sampled Q_K 53 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) 54 | index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q 55 | K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] 56 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 57 | 58 | # find the Top_k query with sparisty measurement 59 | M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) 60 | M_top = M.topk(n_top, sorted=False)[1] 61 | 62 | # use the reduced Q to calculate Q_K 63 | Q_reduce = Q[torch.arange(B)[:, None, None], 64 | torch.arange(H)[None, :, None], 65 | M_top, :] # factor*ln(L_q) 66 | Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k 67 | 68 | return Q_K, M_top 69 | 70 | def _get_initial_context(self, V, L_Q): 71 | B, H, L_V, D = V.shape 72 | if not self.mask_flag: 73 | # V_sum = V.sum(dim=-2) 74 | V_sum = V.mean(dim=-2) 75 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() 76 | else: # use mask 77 | assert(L_Q == L_V) # requires that L_Q == L_V, i.e. 
for self-attention only 78 | contex = V.cumsum(dim=-2) 79 | return contex 80 | 81 | def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): 82 | B, H, L_V, D = V.shape 83 | 84 | if self.mask_flag: 85 | attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) 86 | scores.masked_fill_(attn_mask.mask, -np.inf) 87 | 88 | attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) 89 | 90 | context_in[torch.arange(B)[:, None, None], 91 | torch.arange(H)[None, :, None], 92 | index, :] = torch.matmul(attn, V).type_as(context_in) 93 | if self.output_attention: 94 | attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device) 95 | attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn 96 | return (context_in, attns) 97 | else: 98 | return (context_in, None) 99 | 100 | def forward(self, queries, keys, values, attn_mask): 101 | B, L_Q, H, D = queries.shape 102 | _, L_K, _, _ = keys.shape 103 | 104 | queries = queries.transpose(2,1) 105 | keys = keys.transpose(2,1) 106 | values = values.transpose(2,1) 107 | 108 | U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) 109 | u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 110 | 111 | U_part = U_part if U_partbhls", queries, keys) 24 | A = self.dropout(torch.softmax(scale * scores, dim=-1)) 25 | V = torch.einsum("bhls,bshd->blhd", A, values) 26 | 27 | return V.contiguous() 28 | 29 | 30 | class AttentionLayer(nn.Module): 31 | ''' 32 | The Multi-head Self-Attention (MSA) Layer 33 | ''' 34 | def __init__(self, d_model, n_heads, d_keys=None, d_values=None, mix=True, dropout = 0.1): 35 | super(AttentionLayer, self).__init__() 36 | 37 | d_keys = d_keys or (d_model//n_heads) 38 | d_values = d_values or (d_model//n_heads) 39 | 40 | self.inner_attention = FullAttention(scale=None, attention_dropout = dropout) 41 | self.query_projection = nn.Linear(d_model, d_keys * n_heads) 42 | self.key_projection = nn.Linear(d_model, d_keys * n_heads) 43 | self.value_projection = nn.Linear(d_model, d_values * n_heads) 44 | self.out_projection = nn.Linear(d_values * n_heads, d_model) 45 | self.n_heads = n_heads 46 | self.mix = mix 47 | 48 | def forward(self, queries, keys, values): 49 | B, L, _ = queries.shape 50 | _, S, _ = keys.shape 51 | H = self.n_heads 52 | 53 | queries = self.query_projection(queries).view(B, L, H, -1) 54 | keys = self.key_projection(keys).view(B, S, H, -1) 55 | values = self.value_projection(values).view(B, S, H, -1) 56 | 57 | out = self.inner_attention( 58 | queries, 59 | keys, 60 | values, 61 | ) 62 | if self.mix: 63 | out = out.transpose(2,1).contiguous() 64 | out = out.view(B, L, -1) 65 | 66 | return self.out_projection(out) 67 | 68 | class TwoStageAttentionLayer(nn.Module): 69 | ''' 70 | The Two Stage Attention (TSA) Layer 71 | input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model] 72 | ''' 73 | def __init__(self, seg_num, factor, d_model, n_heads, d_ff = None, dropout=0.1): 74 | super(TwoStageAttentionLayer, self).__init__() 75 | d_ff = d_ff or 4*d_model 76 | self.time_attention = AttentionLayer(d_model, n_heads, dropout = dropout) 77 | self.dim_sender = AttentionLayer(d_model, n_heads, dropout = dropout) 78 | self.dim_receiver = AttentionLayer(d_model, n_heads, dropout = dropout) 79 | self.router = nn.Parameter(torch.randn(seg_num, factor, d_model)) 80 | 81 | self.dropout = nn.Dropout(dropout) 82 | 83 | self.norm1 = nn.LayerNorm(d_model) 84 | self.norm2 = nn.LayerNorm(d_model) 85 | self.norm3 = nn.LayerNorm(d_model) 86 
| self.norm4 = nn.LayerNorm(d_model) 87 | 88 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_ff), 89 | nn.GELU(), 90 | nn.Linear(d_ff, d_model)) 91 | self.MLP2 = nn.Sequential(nn.Linear(d_model, d_ff), 92 | nn.GELU(), 93 | nn.Linear(d_ff, d_model)) 94 | 95 | def forward(self, x): 96 | #Cross Time Stage: Directly apply MSA to each dimension 97 | batch = x.shape[0] 98 | time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model') 99 | time_enc = self.time_attention( 100 | time_in, time_in, time_in 101 | ) 102 | dim_in = time_in + self.dropout(time_enc) 103 | dim_in = self.norm1(dim_in) 104 | dim_in = dim_in + self.dropout(self.MLP1(dim_in)) 105 | dim_in = self.norm2(dim_in) 106 | 107 | #Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection 108 | dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b = batch) 109 | batch_router = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat = batch) 110 | dim_buffer = self.dim_sender(batch_router, dim_send, dim_send) 111 | dim_receive = self.dim_receiver(dim_send, dim_buffer, dim_buffer) 112 | dim_enc = dim_send + self.dropout(dim_receive) 113 | dim_enc = self.norm3(dim_enc) 114 | dim_enc = dim_enc + self.dropout(self.MLP2(dim_enc)) 115 | dim_enc = self.norm4(dim_enc) 116 | 117 | final_out = rearrange(dim_enc, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b = batch) 118 | 119 | return final_out 120 | -------------------------------------------------------------------------------- /models/cross_models/cross_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from models.cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer 6 | 7 | class DecoderLayer(nn.Module): 8 | ''' 9 | The decoder layer of Crossformer, each layer will make a prediction at its scale 10 | ''' 11 | def __init__(self, seg_len, d_model, n_heads, d_ff=None, dropout=0.1, out_seg_num = 10, factor = 10): 12 | super(DecoderLayer, self).__init__() 13 | self.self_attention = TwoStageAttentionLayer(out_seg_num, factor, d_model, n_heads, \ 14 | d_ff, dropout) 15 | self.cross_attention = AttentionLayer(d_model, n_heads, dropout = dropout) 16 | self.norm1 = nn.LayerNorm(d_model) 17 | self.norm2 = nn.LayerNorm(d_model) 18 | self.dropout = nn.Dropout(dropout) 19 | self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model), 20 | nn.GELU(), 21 | nn.Linear(d_model, d_model)) 22 | self.linear_pred = nn.Linear(d_model, seg_len) 23 | 24 | def forward(self, x, cross): 25 | ''' 26 | x: the output of last decoder layer 27 | cross: the output of the corresponding encoder layer 28 | ''' 29 | 30 | batch = x.shape[0] 31 | x = self.self_attention(x) 32 | x = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model') 33 | 34 | cross = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model') 35 | tmp = self.cross_attention( 36 | x, cross, cross, 37 | ) 38 | x = x + self.dropout(tmp) 39 | y = x = self.norm1(x) 40 | y = self.MLP1(y) 41 | dec_output = self.norm2(x+y) 42 | 43 | dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b = batch) 44 | layer_predict = self.linear_pred(dec_output) 45 | layer_predict = rearrange(layer_predict, 'b out_d seg_num seg_len -> b (out_d seg_num) seg_len') 46 | 
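# NOTE (editor annotation, not part of the original file): dec_output, with shape
# [b, ts_d, seg_dec_num, d_model], is fed to the next decoder layer, while layer_predict
# (flattened over the variable and segment dimensions) is summed across layers in
# Decoder.forward to form the final multi-scale prediction.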
47 | return dec_output, layer_predict 48 | 49 | class Decoder(nn.Module): 50 | ''' 51 | The decoder of Crossformer, making the final prediction by adding up predictions at each scale 52 | ''' 53 | def __init__(self, seg_len, d_layers, d_model, n_heads, d_ff, dropout,\ 54 | router=False, out_seg_num = 10, factor=10): 55 | super(Decoder, self).__init__() 56 | 57 | self.router = router 58 | self.decode_layers = nn.ModuleList() 59 | for i in range(d_layers): 60 | self.decode_layers.append(DecoderLayer(seg_len, d_model, n_heads, d_ff, dropout, \ 61 | out_seg_num, factor)) 62 | 63 | def forward(self, x, cross): 64 | final_predict = None 65 | i = 0 66 | 67 | ts_d = x.shape[1] 68 | for layer in self.decode_layers: 69 | cross_enc = cross[i] 70 | x, layer_predict = layer(x, cross_enc) 71 | if final_predict is None: 72 | final_predict = layer_predict 73 | else: 74 | final_predict = final_predict + layer_predict 75 | i += 1 76 | 77 | final_predict = rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d = ts_d) 78 | 79 | return final_predict 80 | 81 | -------------------------------------------------------------------------------- /models/cross_models/cross_embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | 6 | import math 7 | 8 | class DSW_embedding(nn.Module): 9 | def __init__(self, seg_len, d_model): 10 | super(DSW_embedding, self).__init__() 11 | self.seg_len = seg_len 12 | 13 | self.linear = nn.Linear(seg_len, d_model) 14 | 15 | def forward(self, x): 16 | batch, ts_len, ts_dim = x.shape 17 | 18 | x_segment = rearrange(x, 'b (seg_num seg_len) d -> (b d seg_num) seg_len', seg_len = self.seg_len) 19 | x_embed = self.linear(x_segment) 20 | x_embed = rearrange(x_embed, '(b d seg_num) d_model -> b d seg_num d_model', b = batch, d = ts_dim) 21 | 22 | return x_embed -------------------------------------------------------------------------------- /models/cross_models/cross_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | from models.cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer 6 | from math import ceil 7 | 8 | class SegMerging(nn.Module): 9 | ''' 10 | Segment Merging Layer. 
11 | The adjacent `win_size' segments in each dimension will be merged into one segment to 12 | get representation of a coarser scale 13 | we set win_size = 2 in our paper 14 | ''' 15 | def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm): 16 | super().__init__() 17 | self.d_model = d_model 18 | self.win_size = win_size 19 | self.linear_trans = nn.Linear(win_size * d_model, d_model) 20 | self.norm = norm_layer(win_size * d_model) 21 | 22 | def forward(self, x): 23 | """ 24 | x: B, ts_d, L, d_model 25 | """ 26 | batch_size, ts_d, seg_num, d_model = x.shape 27 | pad_num = seg_num % self.win_size 28 | if pad_num != 0: 29 | pad_num = self.win_size - pad_num 30 | x = torch.cat((x, x[:, :, -pad_num:, :]), dim = -2) 31 | 32 | seg_to_merge = [] 33 | for i in range(self.win_size): 34 | seg_to_merge.append(x[:, :, i::self.win_size, :]) 35 | x = torch.cat(seg_to_merge, -1) # [B, ts_d, seg_num/win_size, win_size*d_model] 36 | 37 | x = self.norm(x) 38 | x = self.linear_trans(x) 39 | 40 | return x 41 | 42 | class scale_block(nn.Module): 43 | ''' 44 | We can use one segment merging layer followed by multiple TSA layers in each scale 45 | the parameter `depth' determines the number of TSA layers used in each scale 46 | We set depth = 1 in the paper 47 | ''' 48 | def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout, \ 49 | seg_num = 10, factor=10): 50 | super(scale_block, self).__init__() 51 | 52 | if (win_size > 1): 53 | self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm) 54 | else: 55 | self.merge_layer = None 56 | 57 | self.encode_layers = nn.ModuleList() 58 | 59 | for i in range(depth): 60 | self.encode_layers.append(TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, \ 61 | d_ff, dropout)) 62 | 63 | def forward(self, x): 64 | _, ts_dim, _, _ = x.shape 65 | 66 | if self.merge_layer is not None: 67 | x = self.merge_layer(x) 68 | 69 | for layer in self.encode_layers: 70 | x = layer(x) 71 | 72 | return x 73 | 74 | class Encoder(nn.Module): 75 | ''' 76 | The Encoder of Crossformer. 
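It stacks `e_blocks` scale_blocks: the first keeps the input segment resolution, and each
later block merges every `win_size` adjacent segments (SegMerging) before applying TSA layers,
so forward() returns the representation of the series at every scale.
(Editor annotation, not part of the original docstring.)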
77 | ''' 78 | def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, dropout, 79 | in_seg_num = 10, factor=10): 80 | super(Encoder, self).__init__() 81 | self.encode_blocks = nn.ModuleList() 82 | 83 | self.encode_blocks.append(scale_block(1, d_model, n_heads, d_ff, block_depth, dropout,\ 84 | in_seg_num, factor)) 85 | for i in range(1, e_blocks): 86 | self.encode_blocks.append(scale_block(win_size, d_model, n_heads, d_ff, block_depth, dropout,\ 87 | ceil(in_seg_num/win_size**i), factor)) 88 | 89 | def forward(self, x): 90 | encode_x = [] 91 | encode_x.append(x) 92 | 93 | for block in self.encode_blocks: 94 | x = block(x) 95 | encode_x.append(x) 96 | 97 | return encode_x -------------------------------------------------------------------------------- /models/cross_models/cross_former.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat 5 | 6 | from models.cross_models.cross_encoder import Encoder 7 | from models.cross_models.cross_decoder import Decoder 8 | from models.cross_models.attn import FullAttention, AttentionLayer, TwoStageAttentionLayer 9 | from models.cross_models.cross_embed import DSW_embedding 10 | 11 | from math import ceil 12 | 13 | class Crossformer(nn.Module): 14 | def __init__(self, data_dim, in_len, out_len, seg_len, win_size = 4, 15 | factor=10, d_model=512, d_ff = 1024, n_heads=8, e_layers=3, 16 | dropout=0.0, baseline = False, device=torch.device('cuda:0')): 17 | super(Crossformer, self).__init__() 18 | self.data_dim = data_dim 19 | self.in_len = in_len 20 | self.out_len = out_len 21 | self.seg_len = seg_len 22 | self.merge_win = win_size 23 | 24 | self.baseline = baseline 25 | 26 | self.device = device 27 | 28 | # The padding operation to handle invisible sgemnet length 29 | self.pad_in_len = ceil(1.0 * in_len / seg_len) * seg_len 30 | self.pad_out_len = ceil(1.0 * out_len / seg_len) * seg_len 31 | self.in_len_add = self.pad_in_len - self.in_len 32 | 33 | # Embedding 34 | self.enc_value_embedding = DSW_embedding(seg_len, d_model) 35 | self.enc_pos_embedding = nn.Parameter(torch.randn(1, data_dim, (self.pad_in_len // seg_len), d_model)) 36 | self.pre_norm = nn.LayerNorm(d_model) 37 | 38 | # Encoder 39 | self.encoder = Encoder(e_layers, win_size, d_model, n_heads, d_ff, block_depth = 1, \ 40 | dropout = dropout,in_seg_num = (self.pad_in_len // seg_len), factor = factor) 41 | 42 | # Decoder 43 | self.dec_pos_embedding = nn.Parameter(torch.randn(1, data_dim, (self.pad_out_len // seg_len), d_model)) 44 | self.decoder = Decoder(seg_len, e_layers + 1, d_model, n_heads, d_ff, dropout, \ 45 | out_seg_num = (self.pad_out_len // seg_len), factor = factor) 46 | 47 | def forward(self, x_seq): 48 | if (self.baseline): 49 | base = x_seq.mean(dim = 1, keepdim = True) 50 | else: 51 | base = 0 52 | batch_size = x_seq.shape[0] 53 | if (self.in_len_add != 0): 54 | x_seq = torch.cat((x_seq[:, :1, :].expand(-1, self.in_len_add, -1), x_seq), dim = 1) 55 | 56 | x_seq = self.enc_value_embedding(x_seq) 57 | x_seq += self.enc_pos_embedding 58 | x_seq = self.pre_norm(x_seq) 59 | 60 | enc_out = self.encoder(x_seq) 61 | 62 | dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat = batch_size) 63 | predict_y = self.decoder(dec_in, enc_out) 64 | 65 | 66 | return base + predict_y[:, :self.out_len, :] -------------------------------------------------------------------------------- /models/decoder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class DecoderLayer(nn.Module): 6 | def __init__(self, self_attention, cross_attention, d_model, d_ff=None, 7 | dropout=0.1, activation="relu"): 8 | super(DecoderLayer, self).__init__() 9 | d_ff = d_ff or 4*d_model 10 | self.self_attention = self_attention 11 | self.cross_attention = cross_attention 12 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 13 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 14 | self.norm1 = nn.LayerNorm(d_model) 15 | self.norm2 = nn.LayerNorm(d_model) 16 | self.norm3 = nn.LayerNorm(d_model) 17 | self.dropout = nn.Dropout(dropout) 18 | self.activation = F.relu if activation == "relu" else F.gelu 19 | 20 | def forward(self, x, cross, x_mask=None, cross_mask=None): 21 | x = x + self.dropout(self.self_attention( 22 | x, x, x, 23 | attn_mask=x_mask 24 | )[0]) 25 | x = self.norm1(x) 26 | 27 | x = x + self.dropout(self.cross_attention( 28 | x, cross, cross, 29 | attn_mask=cross_mask 30 | )[0]) 31 | 32 | y = x = self.norm2(x) 33 | y = self.dropout(self.activation(self.conv1(y.transpose(-1,1)))) 34 | y = self.dropout(self.conv2(y).transpose(-1,1)) 35 | 36 | return self.norm3(x+y) 37 | 38 | class Decoder(nn.Module): 39 | def __init__(self, layers, norm_layer=None): 40 | super(Decoder, self).__init__() 41 | self.layers = nn.ModuleList(layers) 42 | self.norm = norm_layer 43 | 44 | def forward(self, x, cross, x_mask=None, cross_mask=None): 45 | for layer in self.layers: 46 | x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) 47 | 48 | if self.norm is not None: 49 | x = self.norm(x) 50 | 51 | return x -------------------------------------------------------------------------------- /models/embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | class PositionalEmbedding(nn.Module): 8 | def __init__(self, d_model, max_len=5000): 9 | super(PositionalEmbedding, self).__init__() 10 | # Compute the positional encodings once in log space. 
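# Standard sinusoidal table: pe[pos, 2i] = sin(pos / 10000^(2i / d_model)) and pe[pos, 2i + 1] = cos(pos / 10000^(2i / d_model)).
# Note that `require_grad' on the line below is a typo for `requires_grad'; it is harmless here because the table is
# registered as a buffer and is never updated by the optimizer anyway.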
11 | pe = torch.zeros(max_len, d_model).float() 12 | pe.require_grad = False 13 | 14 | position = torch.arange(0, max_len).float().unsqueeze(1) 15 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 16 | 17 | pe[:, 0::2] = torch.sin(position * div_term) 18 | pe[:, 1::2] = torch.cos(position * div_term) 19 | 20 | pe = pe.unsqueeze(0) 21 | self.register_buffer('pe', pe) 22 | 23 | def forward(self, x): 24 | return self.pe[:, :x.size(1)] 25 | 26 | class TokenEmbedding(nn.Module): 27 | def __init__(self, c_in, d_model): 28 | super(TokenEmbedding, self).__init__() 29 | padding = 1 if torch.__version__>='1.5.0' else 2 30 | self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, 31 | kernel_size=3, padding=padding, padding_mode='circular') 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv1d): 34 | nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='leaky_relu') 35 | 36 | def forward(self, x): 37 | x = self.tokenConv(x.permute(0, 2, 1)).transpose(1,2) 38 | return x 39 | 40 | class FixedEmbedding(nn.Module): 41 | def __init__(self, c_in, d_model): 42 | super(FixedEmbedding, self).__init__() 43 | 44 | w = torch.zeros(c_in, d_model).float() 45 | w.require_grad = False 46 | 47 | position = torch.arange(0, c_in).float().unsqueeze(1) 48 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 49 | 50 | w[:, 0::2] = torch.sin(position * div_term) 51 | w[:, 1::2] = torch.cos(position * div_term) 52 | 53 | self.emb = nn.Embedding(c_in, d_model) 54 | self.emb.weight = nn.Parameter(w, requires_grad=False) 55 | 56 | def forward(self, x): 57 | return self.emb(x).detach() 58 | 59 | class TemporalEmbedding(nn.Module): 60 | def __init__(self, d_model, embed_type='fixed', freq='h'): 61 | super(TemporalEmbedding, self).__init__() 62 | 63 | minute_size = 4; hour_size = 24 64 | weekday_size = 7; day_size = 32; month_size = 13 65 | 66 | Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding 67 | if freq=='t': 68 | self.minute_embed = Embed(minute_size, d_model) 69 | self.hour_embed = Embed(hour_size, d_model) 70 | self.weekday_embed = Embed(weekday_size, d_model) 71 | self.day_embed = Embed(day_size, d_model) 72 | self.month_embed = Embed(month_size, d_model) 73 | 74 | def forward(self, x): 75 | x = x.long() 76 | 77 | minute_x = self.minute_embed(x[:,:,4]) if hasattr(self, 'minute_embed') else 0. 
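# Remaining time features, using the column layout assumed by this embedding: index 3 = hour, 2 = weekday, 1 = day of month, 0 = month.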
78 | hour_x = self.hour_embed(x[:,:,3]) 79 | weekday_x = self.weekday_embed(x[:,:,2]) 80 | day_x = self.day_embed(x[:,:,1]) 81 | month_x = self.month_embed(x[:,:,0]) 82 | 83 | return hour_x + weekday_x + day_x + month_x + minute_x 84 | 85 | class TimeFeatureEmbedding(nn.Module): 86 | def __init__(self, d_model, embed_type='timeF', freq='h'): 87 | super(TimeFeatureEmbedding, self).__init__() 88 | 89 | freq_map = {'h':4, 't':5, 's':6, 'm':1, 'a':1, 'w':2, 'd':3, 'b':3} 90 | d_inp = freq_map[freq] 91 | self.embed = nn.Linear(d_inp, d_model) 92 | 93 | def forward(self, x): 94 | return self.embed(x) 95 | 96 | class DataEmbedding(nn.Module): 97 | def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): 98 | super(DataEmbedding, self).__init__() 99 | 100 | self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) 101 | self.position_embedding = PositionalEmbedding(d_model=d_model) 102 | self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) 103 | 104 | self.dropout = nn.Dropout(p=dropout) 105 | 106 | def forward(self, x, x_mark): 107 | x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark) 108 | 109 | return self.dropout(x) -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ConvLayer(nn.Module): 6 | def __init__(self, c_in): 7 | super(ConvLayer, self).__init__() 8 | padding = 1 if torch.__version__>='1.5.0' else 2 9 | self.downConv = nn.Conv1d(in_channels=c_in, 10 | out_channels=c_in, 11 | kernel_size=3, 12 | padding=padding, 13 | padding_mode='circular') 14 | self.norm = nn.BatchNorm1d(c_in) 15 | self.activation = nn.ELU() 16 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 17 | 18 | def forward(self, x): 19 | x = self.downConv(x.permute(0, 2, 1)) 20 | x = self.norm(x) 21 | x = self.activation(x) 22 | x = self.maxPool(x) 23 | x = x.transpose(1,2) 24 | return x 25 | 26 | class EncoderLayer(nn.Module): 27 | def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): 28 | super(EncoderLayer, self).__init__() 29 | d_ff = d_ff or 4*d_model 30 | self.attention = attention 31 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) 32 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) 33 | self.norm1 = nn.LayerNorm(d_model) 34 | self.norm2 = nn.LayerNorm(d_model) 35 | self.dropout = nn.Dropout(dropout) 36 | self.activation = F.relu if activation == "relu" else F.gelu 37 | 38 | def forward(self, x, attn_mask=None): 39 | # x [B, L, D] 40 | # x = x + self.dropout(self.attention( 41 | # x, x, x, 42 | # attn_mask = attn_mask 43 | # )) 44 | new_x, attn = self.attention( 45 | x, x, x, 46 | attn_mask = attn_mask 47 | ) 48 | x = x + self.dropout(new_x) 49 | 50 | y = x = self.norm1(x) 51 | y = self.dropout(self.activation(self.conv1(y.transpose(-1,1)))) 52 | y = self.dropout(self.conv2(y).transpose(-1,1)) 53 | 54 | return self.norm2(x+y), attn 55 | 56 | class Encoder(nn.Module): 57 | def __init__(self, attn_layers, conv_layers=None, norm_layer=None): 58 | super(Encoder, self).__init__() 59 | self.attn_layers = nn.ModuleList(attn_layers) 60 | self.conv_layers = nn.ModuleList(conv_layers) if conv_layers 
is not None else None 61 | self.norm = norm_layer 62 | 63 | def forward(self, x, attn_mask=None): 64 | # x [B, L, D] 65 | attns = [] 66 | if self.conv_layers is not None: 67 | for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): 68 | x, attn = attn_layer(x, attn_mask=attn_mask) 69 | x = conv_layer(x) 70 | attns.append(attn) 71 | x, attn = self.attn_layers[-1](x, attn_mask=attn_mask) 72 | attns.append(attn) 73 | else: 74 | for attn_layer in self.attn_layers: 75 | x, attn = attn_layer(x, attn_mask=attn_mask) 76 | attns.append(attn) 77 | 78 | if self.norm is not None: 79 | x = self.norm(x) 80 | 81 | return x, attns 82 | 83 | class EncoderStack(nn.Module): 84 | def __init__(self, encoders, inp_lens): 85 | super(EncoderStack, self).__init__() 86 | self.encoders = nn.ModuleList(encoders) 87 | self.inp_lens = inp_lens 88 | 89 | def forward(self, x, attn_mask=None): 90 | # x [B, L, D] 91 | x_stack = []; attns = [] 92 | for i_len, encoder in zip(self.inp_lens, self.encoders): 93 | inp_len = x.shape[1]//(2**i_len) 94 | x_s, attn = encoder(x[:, -inp_len:, :]) 95 | x_stack.append(x_s); attns.append(attn) 96 | x_stack = torch.cat(x_stack, -2) 97 | 98 | return x_stack, attns 99 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pdb 5 | from utils.masking import TriangularCausalMask, ProbMask 6 | from models.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack 7 | from models.decoder import Decoder, DecoderLayer 8 | from models.attn import FullAttention, ProbAttention, AttentionLayer 9 | from models.embed import DataEmbedding 10 | 11 | class Informer(nn.Module): 12 | def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len, 13 | factor=5, d_model=512, n_heads=8, e_layers=3, d_layers=2, d_ff=512, 14 | dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu', 15 | output_attention = False, distil=True, mix=True, 16 | device=torch.device('cuda:0')): 17 | super(Informer, self).__init__() 18 | self.pred_len = out_len 19 | self.attn = attn 20 | self.output_attention = output_attention 21 | 22 | # Encoding 23 | self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout) 24 | self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout) 25 | # Attention 26 | Attn = ProbAttention if attn=='prob' else FullAttention 27 | # Encoder 28 | self.encoder = Encoder( 29 | [ 30 | EncoderLayer( 31 | AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention), 32 | d_model, n_heads, mix=False), 33 | d_model, 34 | d_ff, 35 | dropout=dropout, 36 | activation=activation 37 | ) for l in range(e_layers) 38 | ], 39 | [ 40 | ConvLayer( 41 | d_model 42 | ) for l in range(e_layers-1) 43 | ] if distil else None, 44 | norm_layer=torch.nn.LayerNorm(d_model) 45 | ) 46 | # Decoder 47 | self.decoder = Decoder( 48 | [ 49 | DecoderLayer( 50 | AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False), 51 | d_model, n_heads, mix=mix), 52 | AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False), 53 | d_model, n_heads, mix=False), 54 | d_model, 55 | d_ff, 56 | dropout=dropout, 57 | activation=activation, 58 | ) 59 | for l in range(d_layers) 60 | ], 61 | norm_layer=torch.nn.LayerNorm(d_model) 62 | ) 63 | # self.end_conv1 = 
nn.Conv1d(in_channels=label_len+out_len, out_channels=out_len, kernel_size=1, bias=True) 64 | # self.end_conv2 = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=1, bias=True) 65 | self.projection = nn.Linear(d_model, c_out, bias=True) 66 | 67 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 68 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 69 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 70 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 71 | 72 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 73 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) 74 | 75 | dec_out = self.projection(dec_out) 76 | 77 | # dec_out = self.end_conv1(dec_out) 78 | # dec_out = self.end_conv2(dec_out.transpose(2,1)).transpose(1,2) 79 | if self.output_attention: 80 | return dec_out[:,-self.pred_len:,:], attns 81 | else: 82 | return dec_out[:,-self.pred_len:,:] # [B, L, D] 83 | 84 | 85 | class InformerStack(nn.Module): 86 | def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len, 87 | factor=5, d_model=512, n_heads=8, e_layers=[3,2,1], d_layers=2, d_ff=512, 88 | dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu', 89 | output_attention = False, distil=True, mix=True, 90 | device=torch.device('cuda:0')): 91 | super(InformerStack, self).__init__() 92 | self.pred_len = out_len 93 | self.attn = attn 94 | self.output_attention = output_attention 95 | 96 | # Encoding 97 | self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout) 98 | self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout) 99 | # Attention 100 | Attn = ProbAttention if attn=='prob' else FullAttention 101 | # Encoder 102 | 103 | inp_lens = list(range(len(e_layers))) # [0,1,2,...] 
you can customize here 104 | encoders = [ 105 | Encoder( 106 | [ 107 | EncoderLayer( 108 | AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention), 109 | d_model, n_heads, mix=False), 110 | d_model, 111 | d_ff, 112 | dropout=dropout, 113 | activation=activation 114 | ) for l in range(el) 115 | ], 116 | [ 117 | ConvLayer( 118 | d_model 119 | ) for l in range(el-1) 120 | ] if distil else None, 121 | norm_layer=torch.nn.LayerNorm(d_model) 122 | ) for el in e_layers] 123 | self.encoder = EncoderStack(encoders, inp_lens) 124 | # Decoder 125 | self.decoder = Decoder( 126 | [ 127 | DecoderLayer( 128 | AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False), 129 | d_model, n_heads, mix=mix), 130 | AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False), 131 | d_model, n_heads, mix=False), 132 | d_model, 133 | d_ff, 134 | dropout=dropout, 135 | activation=activation, 136 | ) 137 | for l in range(d_layers) 138 | ], 139 | norm_layer=torch.nn.LayerNorm(d_model) 140 | ) 141 | # self.end_conv1 = nn.Conv1d(in_channels=label_len+out_len, out_channels=out_len, kernel_size=1, bias=True) 142 | # self.end_conv2 = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=1, bias=True) 143 | self.projection = nn.Linear(d_model, c_out, bias=True) 144 | 145 | def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, 146 | enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): 147 | enc_out = self.enc_embedding(x_enc, x_mark_enc) 148 | enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) 149 | 150 | dec_out = self.dec_embedding(x_dec, x_mark_dec) 151 | dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) 152 | dec_out = self.projection(dec_out) 153 | 154 | # dec_out = self.end_conv1(dec_out) 155 | # dec_out = self.end_conv2(dec_out.transpose(2,1)).transpose(1,2) 156 | if self.output_attention: 157 | return dec_out[:,-self.pred_len:,:], attns 158 | else: 159 | return dec_out[:,-self.pred_len:,:] # [B, L, D] 160 | -------------------------------------------------------------------------------- /models/ts2vec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzhang114/OneNet/076188ad4c1b28cbda8cb4b1e231499b233e1437/models/ts2vec/__init__.py -------------------------------------------------------------------------------- /models/ts2vec/dev.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import pdb 6 | 7 | from itertools import chain 8 | 9 | class SamePadConv(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1, gamma=0.9): 11 | super().__init__() 12 | self.receptive_field = (kernel_size - 1) * dilation + 1 13 | padding = self.receptive_field // 2 14 | self.conv = nn.Conv1d( 15 | in_channels, out_channels, kernel_size, 16 | padding=padding, 17 | dilation=dilation, 18 | groups=groups, bias=False 19 | ) 20 | self.bias = torch.nn.Parameter(torch.zeros([out_channels]),requires_grad=True) 21 | self.padding=padding 22 | self.dilation = dilation 23 | self.kernel_size= kernel_size 24 | 25 | self.grad_dim, self.shape = [], [] 26 | for p in self.conv.parameters(): 27 | self.grad_dim.append(p.numel()) 28 | self.shape.append(p.size()) 29 | self.dim = sum(self.grad_dim) 30 | 31 | self.n_chunks = in_channels 32 | 
self.chunk_in_d = self.dim // self.n_chunks 33 | self.chunk_out_d = int(in_channels*kernel_size// self.n_chunks) 34 | self.grads = torch.Tensor(sum(self.grad_dim)).fill_(0).cuda() 35 | nh=64 36 | self.controller = nn.Sequential(nn.Linear(self.chunk_in_d, nh), nn.SiLU()) 37 | self.calib_w = nn.Linear(nh, self.chunk_out_d) 38 | self.calib_b = nn.Linear(nh, out_channels//in_channels) 39 | self.calib_f = nn.Linear(nh, out_channels//in_channels) 40 | 41 | #self.calib_w = torch.nn.Parameter(torch.ones(out_channels, in_channels,1), requires_grad = True) 42 | #self.calib_b = torch.nn.Parameter(torch.zeros([out_channels]), requires_grad = True) 43 | #self.calib_f = torch.nn.Parameter(torch.ones(1,out_channels,1), requires_grad = True) 44 | 45 | self.remove = 1 if self.receptive_field % 2 == 0 else 0 46 | self.gamma = gamma 47 | 48 | def ctrl_params(self): 49 | c_iter = chain(self.controller.parameters(), self.calib_w.parameters(), 50 | self.calib_b.parameters(), self.calib_f.parameters()) 51 | for p in c_iter: 52 | yield p 53 | 54 | def store_grad(self): 55 | #print('storing grad') 56 | grad = self.conv.weight.grad.data.clone() 57 | grad = nn.functional.normalize(grad) 58 | grad = grad.view(-1) 59 | 60 | self.grads = self.gamma * self.grads + (1-self.gamma) * grad 61 | 62 | def fw_chunks(self): 63 | x = self.grads.view(self.n_chunks, -1) 64 | rep = self.controller(x) 65 | w = self.calib_w(rep) 66 | b = self.calib_b(rep) 67 | f = self.calib_f(rep) 68 | f = f.view(-1).unsqueeze(0).unsqueeze(2) 69 | 70 | return w.unsqueeze(0) ,b.view(-1),f 71 | 72 | def forward(self, x): 73 | w,b,f = self.fw_chunks() 74 | d0, d1 = self.conv.weight.shape[1:] 75 | 76 | cw = self.conv.weight * w 77 | #cw = self.conv.weight 78 | try: 79 | conv_out = F.conv1d(x, cw, padding=self.padding, dilation=self.dilation, bias = self.bias * b) 80 | out = f * conv_out 81 | except: pdb.set_trace() 82 | return out 83 | 84 | def representation(self, x): 85 | out = self.conv(x) 86 | if self.remove > 0: 87 | out = out[:, :, : -self.remove] 88 | return out 89 | 90 | def _forward(self, x): 91 | out = self.conv(x) 92 | if self.remove > 0: 93 | out = out[:, :, : -self.remove] 94 | return out 95 | 96 | class ConvBlock(nn.Module): 97 | def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False, gamma=0.9): 98 | super().__init__() 99 | self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation, gamma=gamma) 100 | self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation, gamma=gamma) 101 | self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None 102 | 103 | def ctrl_params(self): 104 | c_iter = chain(self.conv1.controller.parameters(), self.conv1.calib_w.parameters(), 105 | self.conv1.calib_b.parameters(), self.conv1.calib_f.parameters(), 106 | self.conv2.controller.parameters(), self.conv2.calib_w.parameters(), 107 | self.conv2.calib_b.parameters(), self.conv2.calib_f.parameters()) 108 | 109 | return c_iter 110 | 111 | 112 | 113 | def forward(self, x): 114 | residual = x if self.projector is None else self.projector(x) 115 | x = F.gelu(x) 116 | x = self.conv1(x) 117 | x = F.gelu(x) 118 | x = self.conv2(x) 119 | return x + residual 120 | 121 | class DilatedConvEncoder(nn.Module): 122 | def __init__(self, in_channels, channels, kernel_size, gamma=0.9): 123 | super().__init__() 124 | self.net = nn.Sequential(*[ 125 | ConvBlock( 126 | channels[i-1] if i > 0 else in_channels, 127 | channels[i], 128 | kernel_size=kernel_size, 129 | 
dilation=2**i, 130 | final=(i == len(channels)-1), gamma=gamma 131 | ) 132 | for i in range(len(channels)) 133 | ]) 134 | def ctrl_params(self): 135 | ctrl = [] 136 | for l in self.net: 137 | ctrl.append(l.ctrl_params()) 138 | c = chain(*ctrl) 139 | for p in c: 140 | yield p 141 | def forward(self, x): 142 | return self.net(x) 143 | -------------------------------------------------------------------------------- /models/ts2vec/dilated_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class SamePadConv(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1): 8 | super().__init__() 9 | self.receptive_field = (kernel_size - 1) * dilation + 1 10 | padding = self.receptive_field // 2 11 | self.conv = nn.Conv1d( 12 | in_channels, out_channels, kernel_size, 13 | padding=padding, 14 | dilation=dilation, 15 | groups=groups 16 | ) 17 | self.remove = 1 if self.receptive_field % 2 == 0 else 0 18 | 19 | def forward(self, x): 20 | out = self.conv(x) 21 | if self.remove > 0: 22 | out = out[:, :, : -self.remove] 23 | return out 24 | 25 | class ConvBlock(nn.Module): 26 | def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False): 27 | super().__init__() 28 | self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation) 29 | self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation) 30 | self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None 31 | 32 | def forward(self, x): 33 | residual = x if self.projector is None else self.projector(x) 34 | x = F.gelu(x) 35 | x = self.conv1(x) 36 | x = F.gelu(x) 37 | x = self.conv2(x) 38 | return x + residual 39 | 40 | class DilatedConvEncoder(nn.Module): 41 | def __init__(self, in_channels, channels, kernel_size): 42 | super().__init__() 43 | self.net = nn.Sequential(*[ 44 | ConvBlock( 45 | channels[i-1] if i > 0 else in_channels, 46 | channels[i], 47 | kernel_size=kernel_size, 48 | dilation=2**i, 49 | final=(i == len(channels)-1) 50 | ) 51 | for i in range(len(channels)) 52 | ]) 53 | 54 | def forward(self, x): 55 | return self.net(x) 56 | -------------------------------------------------------------------------------- /models/ts2vec/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.fft as fft 8 | 9 | import numpy as np 10 | from einops import rearrange, reduce, repeat 11 | 12 | from .dilated_conv import DilatedConvEncoder 13 | 14 | 15 | def generate_continuous_mask(B, T, n=5, l=0.1): 16 | res = torch.full((B, T), True, dtype=torch.bool) 17 | if isinstance(n, float): 18 | n = int(n * T) 19 | n = max(min(n, T // 2), 1) 20 | 21 | if isinstance(l, float): 22 | l = int(l * T) 23 | l = max(l, 1) 24 | 25 | for i in range(B): 26 | for _ in range(n): 27 | t = np.random.randint(T-l+1) 28 | res[i, t:t+l] = False 29 | return res 30 | 31 | def generate_binomial_mask(B, T, p=0.5): 32 | return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool) 33 | 34 | class TSEncoder(nn.Module): 35 | def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'): 36 | super().__init__() 37 | self.input_dims = input_dims 38 | self.output_dims = 
output_dims 39 | self.hidden_dims = hidden_dims 40 | self.mask_mode = mask_mode 41 | self.input_fc = nn.Linear(input_dims, hidden_dims) 42 | self.feature_extractor = DilatedConvEncoder( 43 | hidden_dims, 44 | [hidden_dims] * depth + [output_dims], 45 | kernel_size=3 46 | ) 47 | self.repr_dropout = nn.Dropout(p=0.1) 48 | 49 | # [64] * 10 + [320] = [64, 64, 64, 64, 64, 64, 64, 64, 64 ,64, 320] = 11 items 50 | # for i in range(len(...)) -> 0, 1, ..., 10 51 | 52 | def forward(self, x, mask=None): # x: B x T x input_dims 53 | nan_mask = ~x.isnan().any(axis=-1) 54 | x[~nan_mask] = 0 55 | x = self.input_fc(x) # B x T x Ch 56 | 57 | # generate & apply mask 58 | if mask is None: 59 | if self.training: 60 | mask = self.mask_mode 61 | else: 62 | mask = 'all_true' 63 | 64 | if mask == 'binomial': 65 | mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device) 66 | elif mask == 'continuous': 67 | mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device) 68 | elif mask == 'all_true': 69 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 70 | elif mask == 'all_false': 71 | mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) 72 | elif mask == 'mask_last': 73 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 74 | mask[:, -1] = False 75 | 76 | mask &= nan_mask 77 | x[~mask] = 0 78 | 79 | # conv encoder 80 | x = x.transpose(1, 2) # B x Ch x T 81 | x = self.repr_dropout(self.feature_extractor(x)) # B x Co x T 82 | x = x.transpose(1, 2) # B x T x Co 83 | 84 | return x 85 | 86 | 87 | class BandedFourierLayer(nn.Module): 88 | 89 | def __init__(self, in_channels, out_channels, band, num_bands, freq_mixing=False, bias=True, length=201): 90 | super().__init__() 91 | 92 | self.length = length 93 | self.total_freqs = (self.length // 2) + 1 94 | 95 | self.in_channels = in_channels 96 | self.out_channels = out_channels 97 | 98 | self.freq_mixing = freq_mixing 99 | 100 | self.band = band # zero indexed 101 | self.num_bands = num_bands 102 | 103 | self.num_freqs = self.total_freqs // self.num_bands + (self.total_freqs % self.num_bands if self.band == self.num_bands - 1 else 0) 104 | 105 | self.start = self.band * (self.total_freqs // self.num_bands) 106 | self.end = self.start + self.num_freqs 107 | 108 | 109 | # case: from other frequencies 110 | if self.freq_mixing: 111 | self.weight = nn.Parameter(torch.empty((self.num_freqs, in_channels, self.total_freqs, out_channels), dtype=torch.cfloat)) 112 | else: 113 | self.weight = nn.Parameter(torch.empty((self.num_freqs, in_channels, out_channels), dtype=torch.cfloat)) 114 | if bias: 115 | self.bias = nn.Parameter(torch.empty((self.num_freqs, out_channels), dtype=torch.cfloat)) 116 | else: 117 | self.bias = None 118 | self.reset_parameters() 119 | 120 | def forward(self, input): 121 | # input - b t d 122 | b, t, _ = input.shape 123 | input_fft = fft.rfft(input, dim=1) 124 | output_fft = torch.zeros(b, t // 2 + 1, self.out_channels, device=input.device, dtype=torch.cfloat) 125 | output_fft[:, self.start:self.end] = self._forward(input_fft) 126 | return fft.irfft(output_fft, n=input.size(1), dim=1) 127 | 128 | def _forward(self, input): 129 | if self.freq_mixing: 130 | output = torch.einsum('bai,tiao->bto', input, self.weight) 131 | else: 132 | output = torch.einsum('bti,tio->bto', input[:, self.start:self.end], self.weight) 133 | if self.bias is None: 134 | return output 135 | return output + self.bias 136 | 137 | def reset_parameters(self) -> None: 138 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 139 
| if self.bias is not None: 140 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 141 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 142 | nn.init.uniform_(self.bias, -bound, bound) 143 | 144 | 145 | class GlobalLocalMultiscaleTSEncoder(nn.Module): 146 | 147 | def __init__(self, input_dims, output_dims, 148 | kernels: List[int], 149 | num_bands: int, 150 | freq_mixing: bool, 151 | length: int, 152 | mode = 0, 153 | hidden_dims=64, depth=10, mask_mode='binomial'): 154 | super().__init__() 155 | 156 | self.mode = mode 157 | 158 | self.input_dims = input_dims 159 | self.output_dims = output_dims 160 | self.hidden_dims = hidden_dims 161 | self.mask_mode = mask_mode 162 | self.input_fc = nn.Linear(input_dims, hidden_dims) 163 | self.feature_extractor = DilatedConvEncoder( 164 | hidden_dims, 165 | [hidden_dims] * depth + [output_dims], 166 | kernel_size=3 167 | ) 168 | 169 | self.kernels = kernels 170 | self.num_bands = num_bands 171 | 172 | self.convs = nn.ModuleList( 173 | [nn.Conv1d(output_dims, output_dims//2, k, padding=k-1) for k in kernels] 174 | ) 175 | self.fouriers = nn.ModuleList( 176 | [BandedFourierLayer(output_dims, output_dims//2, b, num_bands, 177 | freq_mixing=freq_mixing, length=length) for b in range(num_bands)] 178 | ) 179 | 180 | def forward(self, x, tcn_output=False, mask='all_true'): # x: B x T x input_dims 181 | nan_mask = ~x.isnan().any(axis=-1) 182 | x[~nan_mask] = 0 183 | x = self.input_fc(x) # B x T x Ch 184 | 185 | # generate & apply mask 186 | if mask is None: 187 | if self.training: 188 | mask = self.mask_mode 189 | else: 190 | mask = 'all_true' 191 | 192 | if mask == 'binomial': 193 | mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device) 194 | elif mask == 'continuous': 195 | mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device) 196 | elif mask == 'all_true': 197 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 198 | elif mask == 'all_false': 199 | mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) 200 | elif mask == 'mask_last': 201 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 202 | mask[:, -1] = False 203 | 204 | mask &= nan_mask 205 | x[~mask] = 0 206 | 207 | # conv encoder 208 | x = x.transpose(1, 2) # B x Ch x T 209 | x = self.feature_extractor(x) # B x Co x T 210 | 211 | if tcn_output: 212 | return x.transpose(1, 2) 213 | 214 | if len(self.kernels) == 0: 215 | local_multiscale = None 216 | else: 217 | local_multiscale = [] 218 | for idx, mod in enumerate(self.convs): 219 | out = mod(x) # b d t 220 | if self.kernels[idx] != 1: 221 | out = out[..., :-(self.kernels[idx] - 1)] 222 | local_multiscale.append(out.transpose(1, 2)) # b t d 223 | local_multiscale = reduce( 224 | rearrange(local_multiscale, 'list b t d -> list b t d'), 225 | 'list b t d -> b t d', 'mean' 226 | ) 227 | 228 | x = x.transpose(1, 2) # B x T x Co 229 | 230 | if self.num_bands == 0: 231 | global_multiscale = None 232 | else: 233 | global_multiscale = [] 234 | for mod in self.fouriers: 235 | out = mod(x) # b t d 236 | global_multiscale.append(out) 237 | 238 | global_multiscale = global_multiscale[0] 239 | 240 | return torch.cat([local_multiscale, global_multiscale], dim=-1) 241 | -------------------------------------------------------------------------------- /models/ts2vec/fsnet_.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from itertools import chain 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 
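# FSNet-style adaptive convolution: each SamePadConv keeps a slow EMA (`grads', rate gamma) and a fast EMA
# (`f_grads', rate f_gamma) of its normalized weight gradient. The slow EMA is mapped through a small controller
# to per-chunk calibrations of the convolution weight, bias and output; when the two EMAs strongly disagree
# (cosine similarity below -tau), the layer reads from and writes to the associative memory `W' to reuse
# calibrations stored earlier.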
7 | 8 | 9 | def normalize(W): 10 | W_norm = torch.norm(W) 11 | W_norm = torch.relu(W_norm - 1) + 1 12 | W = W / W_norm 13 | return W 14 | 15 | 16 | class SamePadConv(nn.Module): 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1, gamma=0.9, device=None): 19 | super().__init__() 20 | 21 | self.device = device if device else torch.device('cuda:0') 22 | self.receptive_field = (kernel_size - 1) * dilation + 1 23 | padding = self.receptive_field // 2 24 | self.conv = nn.Conv1d( 25 | in_channels, out_channels, kernel_size, 26 | padding=padding, 27 | dilation=dilation, 28 | groups=groups, bias=False 29 | ) 30 | self.bias = torch.nn.Parameter(torch.zeros([out_channels]), requires_grad=True) 31 | self.padding = padding 32 | self.dilation = dilation 33 | self.kernel_size = kernel_size 34 | 35 | self.grad_dim, self.shape = [], [] 36 | for p in self.conv.parameters(): 37 | self.grad_dim.append(p.numel()) 38 | self.shape.append(p.size()) 39 | self.dim = sum(self.grad_dim) 40 | self.in_channels = in_channels 41 | self.out_features = out_channels 42 | 43 | self.n_chunks = in_channels 44 | self.chunk_in_d = self.dim // self.n_chunks 45 | self.chunk_out_d = int(in_channels * kernel_size // self.n_chunks) 46 | 47 | self.grads = torch.Tensor(sum(self.grad_dim)).fill_(0).to(self.device) 48 | self.f_grads = torch.Tensor(sum(self.grad_dim)).fill_(0).to(self.device) 49 | nh = 64 50 | self.controller = nn.Sequential(nn.Linear(self.chunk_in_d, nh), nn.SiLU()) 51 | self.calib_w = nn.Linear(nh, self.chunk_out_d) 52 | self.calib_b = nn.Linear(nh, out_channels // in_channels) 53 | self.calib_f = nn.Linear(nh, out_channels // in_channels) 54 | dim = self.n_chunks * (self.chunk_out_d + 2 * out_channels // in_channels) 55 | self.W = nn.Parameter(torch.empty(dim, 32), requires_grad=False) 56 | nn.init.xavier_uniform_(self.W.data) 57 | self.W.data = normalize(self.W.data) 58 | 59 | # self.calib_w = torch.nn.Parameter(torch.ones(out_channels, in_channels,1), requires_grad = True) 60 | # self.calib_b = torch.nn.Parameter(torch.zeros([out_channels]), requires_grad = True) 61 | # self.calib_f = torch.nn.Parameter(torch.ones(1,out_channels,1), requires_grad = True) 62 | 63 | # Determine whether the last output step needs to be removed 64 | self.remove = 1 if self.receptive_field % 2 == 0 else 0 65 | self.gamma = gamma 66 | self.f_gamma = 0.3 67 | self.cos = nn.CosineSimilarity(dim=0, eps=1e-6) 68 | self.trigger = 0 69 | self.tau = 0.75 70 | 71 | def ctrl_params(self): 72 | 73 | c_iter = chain(self.controller.parameters(), self.calib_w.parameters(), 74 | self.calib_b.parameters(), self.calib_f.parameters()) 75 | for p in c_iter: 76 | yield p 77 | 78 | def store_grad(self): 79 | 80 | # Gradient storage logic: maintain slow and fast EMAs of the normalized conv-weight gradient 81 | grad = self.conv.weight.grad.data.clone() 82 | grad = nn.functional.normalize(grad) 83 | grad = grad.view(-1) 84 | self.f_grads = self.f_gamma * self.f_grads + (1 - self.f_gamma) * grad 85 | self.grads = self.gamma * self.grads + (1 - self.gamma) * grad 86 | if not self.training: 87 | e = self.cos(self.f_grads, self.grads) 88 | if e < -self.tau: 89 | self.trigger = 1 90 | 91 | def fw_chunks(self): 92 | 93 | # Chunk-wise calibration logic for the forward pass 94 | x = self.grads.view(self.n_chunks, -1) 95 | rep = self.controller(x) 96 | w = self.calib_w(rep) 97 | b = self.calib_b(rep) 98 | f = self.calib_f(rep) 99 | q = torch.cat([w.view(-1), b.view(-1), f.view(-1)]) 100 | if not hasattr(self, 'q_ema'): 101 | setattr(self, 'q_ema', torch.zeros(*q.size()).float().to(self.device)) 102 | 103 | self.q_ema = self.f_gamma * self.q_ema + (1 - self.f_gamma) * q 104 | q = self.q_ema
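# If a drift trigger fired in store_grad, blend the current calibration with one recalled from memory: attend over
# the columns of `W' with the smoothed query q, combine the top-2 columns into a recalled vector, write that vector
# back into the selected slots (keeping a fraction tau of the old values), and mix it into w, b and f with weight 1 - tau.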
105 | if self.trigger == 1: 106 | dim = w.size(0) 107 | self.trigger = 0 108 | # read 109 | 110 | att = q @ self.W 111 | att = F.softmax(att / 0.5) 112 | 113 | v, idx = torch.topk(att, 2) 114 | ww = torch.index_select(self.W, 1, idx) 115 | old_w = ww @ v.clone().detach().unsqueeze(1).float() 116 | # write memory 117 | s_att = torch.zeros(att.size(0)).to(self.device) 118 | s_att[idx.squeeze().long()] = v.squeeze() 119 | W = old_w @ s_att.unsqueeze(0) 120 | mask = torch.ones(W.size()).to(self.device) 121 | mask[:, idx.squeeze().long()] = self.tau 122 | self.W.data = mask * self.W.data + (1 - mask) * W 123 | self.W.data = normalize(self.W.data) 124 | # retrieve 125 | ll = torch.split(old_w, dim) 126 | nw, nb, nf = w.size(1), b.size(1), f.size(1) 127 | o_w, o_b, o_f = torch.cat(*[ll[:nw]]), torch.cat(*[ll[nw:nw + nb]]), torch.cat(*[ll[-nf:]]) 128 | 129 | try: 130 | w = self.tau * w + (1 - self.tau) * o_w.view(w.size()) 131 | b = self.tau * b + (1 - self.tau) * o_b.view(b.size()) 132 | f = self.tau * f + (1 - self.tau) * o_f.view(f.size()) 133 | except: 134 | pdb.set_trace() 135 | f = f.view(-1).unsqueeze(0).unsqueeze(2) 136 | 137 | return w.unsqueeze(0), b.view(-1), f 138 | 139 | def forward(self, x): 140 | 141 | w, b, f = self.fw_chunks() 142 | 143 | cw = self.conv.weight * w 144 | try: 145 | conv_out = F.conv1d(x, cw, padding=self.padding, dilation=self.dilation, bias=self.bias * b) 146 | out = f * conv_out 147 | except: 148 | pdb.set_trace() 149 | return out 150 | 151 | def representation(self, x): 152 | 153 | out = self.conv(x) 154 | if self.remove > 0: 155 | out = out[:, :, : -self.remove] 156 | return out 157 | 158 | def _forward(self, x): 159 | 160 | out = self.conv(x) 161 | if self.remove > 0: 162 | out = out[:, :, : -self.remove] 163 | return out 164 | 165 | 166 | class ConvBlock(nn.Module): 167 | def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False, gamma=0.9, device=None): 168 | super().__init__() 169 | self.device = device if device else torch.device('cuda:0') 170 | self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation, gamma=gamma, 171 | device=self.device) 172 | self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation, gamma=gamma, 173 | device=self.device) 174 | self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None 175 | 176 | def ctrl_params(self): 177 | c_iter = chain(self.conv1.controller.parameters(), self.conv1.calib_w.parameters(), 178 | self.conv1.calib_b.parameters(), self.conv1.calib_f.parameters(), 179 | self.conv2.controller.parameters(), self.conv2.calib_w.parameters(), 180 | self.conv2.calib_b.parameters(), self.conv2.calib_f.parameters()) 181 | 182 | return c_iter 183 | 184 | def forward(self, x): 185 | residual = x if self.projector is None else self.projector(x) 186 | x = F.gelu(x) 187 | x = self.conv1(x) 188 | x = F.gelu(x) 189 | x = self.conv2(x) 190 | return x + residual 191 | 192 | 193 | class DilatedConvEncoder(nn.Module): 194 | def __init__(self, in_channels, channels, kernel_size, gamma=0.9, device=None): 195 | super().__init__() 196 | self.device = device if device else torch.device('cuda:0') 197 | self.net = nn.Sequential(*[ 198 | ConvBlock( 199 | channels[i - 1] if i > 0 else in_channels, 200 | channels[i], 201 | kernel_size=kernel_size, 202 | dilation=2 ** i, 203 | final=(i == len(channels) - 1), gamma=gamma, 204 | device=self.device 205 | ) 206 | for i in range(len(channels)) 207 | ]) 208 | 209 | def 
ctrl_params(self): 210 | ctrl = [] 211 | for l in self.net: 212 | ctrl.append(l.ctrl_params()) 213 | c = chain(*ctrl) 214 | for p in c: 215 | yield p 216 | 217 | def forward(self, x): 218 | return self.net(x) 219 | -------------------------------------------------------------------------------- /models/ts2vec/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0): 6 | loss = torch.tensor(0., device=z1.device) 7 | d = 0 8 | while z1.size(1) > 1: 9 | if alpha != 0: 10 | loss += alpha * instance_contrastive_loss(z1, z2) 11 | if d >= temporal_unit: 12 | if 1 - alpha != 0: 13 | loss += (1 - alpha) * temporal_contrastive_loss(z1, z2) 14 | d += 1 15 | z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2) 16 | z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2) 17 | if z1.size(1) == 1: 18 | if alpha != 0: 19 | loss += alpha * instance_contrastive_loss(z1, z2) 20 | d += 1 21 | return loss / d 22 | 23 | def instance_contrastive_loss(z1, z2): 24 | B, T = z1.size(0), z1.size(1) 25 | if B == 1: 26 | return z1.new_tensor(0.) 27 | z = torch.cat([z1, z2], dim=0) # 2B x T x C 28 | z = z.transpose(0, 1) # T x 2B x C 29 | sim = torch.matmul(z, z.transpose(1, 2)) # T x 2B x 2B 30 | logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # T x 2B x (2B-1) 31 | logits += torch.triu(sim, diagonal=1)[:, :, 1:] 32 | logits = -F.log_softmax(logits, dim=-1) 33 | 34 | i = torch.arange(B, device=z1.device) 35 | loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2 36 | return loss 37 | 38 | def temporal_contrastive_loss(z1, z2): 39 | B, T = z1.size(0), z1.size(1) 40 | if T == 1: 41 | return z1.new_tensor(0.) 
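# Concatenate the two views along time (B x 2T x C), compute per-instance pairwise similarities, drop the diagonal,
# and apply a cross-entropy loss in which time step t of one view is the positive for time step t of the other view.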
42 | z = torch.cat([z1, z2], dim=1) # B x 2T x C 43 | sim = torch.matmul(z, z.transpose(1, 2)) # B x 2T x 2T 44 | logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # B x 2T x (2T-1) 45 | logits += torch.triu(sim, diagonal=1)[:, :, 1:] 46 | logits = -F.log_softmax(logits, dim=-1) 47 | 48 | t = torch.arange(T, device=z1.device) 49 | loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2 50 | return loss 51 | -------------------------------------------------------------------------------- /models/ts2vec/ncca_.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import pdb 6 | 7 | from itertools import chain 8 | 9 | class SamePadConv(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1, gamma=0.9): 11 | super().__init__() 12 | self.receptive_field = (kernel_size - 1) * dilation + 1 13 | padding = self.receptive_field // 2 14 | self.conv = nn.Conv1d( 15 | in_channels, out_channels, kernel_size, 16 | padding=padding, 17 | dilation=dilation, 18 | groups=groups, bias=False 19 | ) 20 | self.bias = torch.nn.Parameter(torch.zeros([out_channels]),requires_grad=True) 21 | self.padding=padding 22 | self.dilation = dilation 23 | self.kernel_size= kernel_size 24 | 25 | self.grad_dim, self.shape = [], [] 26 | for p in self.conv.parameters(): 27 | self.grad_dim.append(p.numel()) 28 | self.shape.append(p.size()) 29 | self.dim = sum(self.grad_dim) 30 | 31 | self.calib_w = torch.nn.Parameter(torch.ones(out_channels, in_channels,1), requires_grad = True) 32 | self.calib_b = torch.nn.Parameter(torch.zeros([out_channels]), requires_grad = True) 33 | self.calib_f = torch.nn.Parameter(torch.ones(1,out_channels,1), requires_grad = True) 34 | #nn.init.xavier_uniform_(self.calib_w.data) 35 | #nn.init.xavier_uniform_(self.calib_b.data) 36 | #nn.init.xavier_uniform_(self.calib_f.data) 37 | 38 | self.remove = 1 if self.receptive_field % 2 == 0 else 0 39 | self.gamma = gamma 40 | 41 | def ctrl_params(self): 42 | c_iter = chain(self.controller.parameters(), self.calib_w.parameters(), 43 | self.calib_b.parameters(), self.calib_f.parameters()) 44 | for p in c_iter: 45 | yield p 46 | 47 | def forward(self, x): 48 | w,b,f = self.calib_w, self.calib_b, self.calib_f 49 | d0, d1 = self.conv.weight.shape[1:] 50 | 51 | cw = self.conv.weight * w 52 | #cw = self.conv.weight 53 | try: 54 | conv_out = F.conv1d(x, cw, padding=self.padding, dilation=self.dilation, bias = self.bias * b) 55 | out = conv_out + f * conv_out 56 | except: pdb.set_trace() 57 | return out 58 | 59 | def representation(self, x): 60 | out = self.conv(x) 61 | if self.remove > 0: 62 | out = out[:, :, : -self.remove] 63 | return out 64 | 65 | def _forward(self, x): 66 | out = self.conv(x) 67 | if self.remove > 0: 68 | out = out[:, :, : -self.remove] 69 | return out 70 | 71 | class ConvBlock(nn.Module): 72 | def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False, gamma=0.9): 73 | super().__init__() 74 | self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation, gamma=gamma) 75 | self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation, gamma=gamma) 76 | self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None 77 | 78 | def ctrl_params(self): 79 | c_iter = chain(self.conv1.controller.parameters(), self.conv1.calib_w.parameters(), 80 | 
self.conv1.calib_b.parameters(), self.conv1.calib_f.parameters(), 81 | self.conv2.controller.parameters(), self.conv2.calib_w.parameters(), 82 | self.conv2.calib_b.parameters(), self.conv2.calib_f.parameters()) 83 | 84 | return c_iter 85 | 86 | 87 | 88 | def forward(self, x): 89 | residual = x if self.projector is None else self.projector(x) 90 | x = F.gelu(x) 91 | x = self.conv1(x) 92 | x = F.gelu(x) 93 | x = self.conv2(x) 94 | return x + residual 95 | 96 | class DilatedConvEncoder(nn.Module): 97 | def __init__(self, in_channels, channels, kernel_size, gamma=0.9): 98 | super().__init__() 99 | self.net = nn.Sequential(*[ 100 | ConvBlock( 101 | channels[i-1] if i > 0 else in_channels, 102 | channels[i], 103 | kernel_size=kernel_size, 104 | dilation=2**i, 105 | final=(i == len(channels)-1), gamma=gamma 106 | ) 107 | for i in range(len(channels)) 108 | ]) 109 | 110 | def ctrl_params(self): 111 | ctrl = [] 112 | for l in self.net: 113 | ctrl.append(l.ctrl_params()) 114 | c = chain(*ctrl) 115 | for p in c: 116 | yield p 117 | def forward(self, x): 118 | return self.net(x) 119 | -------------------------------------------------------------------------------- /models/ts2vec/nomem.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.fft as fft 8 | 9 | import numpy as np 10 | from einops import rearrange, reduce, repeat 11 | 12 | from .dev import DilatedConvEncoder 13 | 14 | 15 | def generate_continuous_mask(B, T, n=5, l=0.1): 16 | res = torch.full((B, T), True, dtype=torch.bool) 17 | if isinstance(n, float): 18 | n = int(n * T) 19 | n = max(min(n, T // 2), 1) 20 | 21 | if isinstance(l, float): 22 | l = int(l * T) 23 | l = max(l, 1) 24 | 25 | for i in range(B): 26 | for _ in range(n): 27 | t = np.random.randint(T-l+1) 28 | res[i, t:t+l] = False 29 | return res 30 | 31 | def generate_binomial_mask(B, T, p=0.5): 32 | return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool) 33 | 34 | class TSEncoder(nn.Module): 35 | def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial', gamma=0.9): 36 | super().__init__() 37 | self.input_dims = input_dims 38 | self.output_dims = output_dims 39 | self.hidden_dims = hidden_dims 40 | self.mask_mode = mask_mode 41 | self.input_fc = nn.Linear(input_dims, hidden_dims) 42 | self.feature_extractor = DilatedConvEncoder( 43 | hidden_dims, 44 | [hidden_dims] * depth + [output_dims], 45 | kernel_size=3, gamma=gamma 46 | ) 47 | self.repr_dropout = nn.Dropout(p=0.1) 48 | 49 | # [64] * 10 + [320] = [64, 64, 64, 64, 64, 64, 64, 64, 64 ,64, 320] = 11 items 50 | # for i in range(len(...)) -> 0, 1, ..., 10 51 | 52 | def ctrl_params(self): 53 | return self.feature_extractor.ctrl_params() 54 | 55 | def forward(self, x, mask=None): # x: B x T x input_dims 56 | nan_mask = ~x.isnan().any(axis=-1) 57 | x[~nan_mask] = 0 58 | x = self.input_fc(x) # B x T x Ch 59 | 60 | # generate & apply mask 61 | if mask is None: 62 | if self.training: 63 | mask = self.mask_mode 64 | else: 65 | mask = 'all_true' 66 | 67 | if mask == 'binomial': 68 | mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device) 69 | elif mask == 'continuous': 70 | mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device) 71 | elif mask == 'all_true': 72 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 73 | elif mask == 'all_false': 74 | mask = 
x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) 75 | elif mask == 'mask_last': 76 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 77 | mask[:, -1] = False 78 | 79 | mask &= nan_mask 80 | x[~mask] = 0 81 | 82 | # conv encoder 83 | x = x.transpose(1, 2) # B x Ch x T 84 | x = self.repr_dropout(self.feature_extractor(x)) # B x Co x T 85 | x = x.transpose(1, 2) # B x T x Co 86 | 87 | return x 88 | 89 | 90 | class BandedFourierLayer(nn.Module): 91 | 92 | def __init__(self, in_channels, out_channels, band, num_bands, freq_mixing=False, bias=True, length=201): 93 | super().__init__() 94 | 95 | self.length = length 96 | self.total_freqs = (self.length // 2) + 1 97 | 98 | self.in_channels = in_channels 99 | self.out_channels = out_channels 100 | 101 | self.freq_mixing = freq_mixing 102 | 103 | self.band = band # zero indexed 104 | self.num_bands = num_bands 105 | 106 | self.num_freqs = self.total_freqs // self.num_bands + (self.total_freqs % self.num_bands if self.band == self.num_bands - 1 else 0) 107 | 108 | self.start = self.band * (self.total_freqs // self.num_bands) 109 | self.end = self.start + self.num_freqs 110 | 111 | 112 | # case: from other frequencies 113 | if self.freq_mixing: 114 | self.weight = nn.Parameter(torch.empty((self.num_freqs, in_channels, self.total_freqs, out_channels), dtype=torch.cfloat)) 115 | else: 116 | self.weight = nn.Parameter(torch.empty((self.num_freqs, in_channels, out_channels), dtype=torch.cfloat)) 117 | if bias: 118 | self.bias = nn.Parameter(torch.empty((self.num_freqs, out_channels), dtype=torch.cfloat)) 119 | else: 120 | self.bias = None 121 | self.reset_parameters() 122 | 123 | def forward(self, input): 124 | # input - b t d 125 | b, t, _ = input.shape 126 | input_fft = fft.rfft(input, dim=1) 127 | output_fft = torch.zeros(b, t // 2 + 1, self.out_channels, device=input.device, dtype=torch.cfloat) 128 | output_fft[:, self.start:self.end] = self._forward(input_fft) 129 | return fft.irfft(output_fft, n=input.size(1), dim=1) 130 | 131 | def _forward(self, input): 132 | if self.freq_mixing: 133 | output = torch.einsum('bai,tiao->bto', input, self.weight) 134 | else: 135 | output = torch.einsum('bti,tio->bto', input[:, self.start:self.end], self.weight) 136 | if self.bias is None: 137 | return output 138 | return output + self.bias 139 | 140 | def reset_parameters(self) -> None: 141 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 142 | if self.bias is not None: 143 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 144 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 145 | nn.init.uniform_(self.bias, -bound, bound) 146 | 147 | 148 | class GlobalLocalMultiscaleTSEncoder(nn.Module): 149 | 150 | def __init__(self, input_dims, output_dims, 151 | kernels: List[int], 152 | num_bands: int, 153 | freq_mixing: bool, 154 | length: int, 155 | mode = 0, 156 | hidden_dims=64, depth=10, mask_mode='binomial', gamma=0.9): 157 | super().__init__() 158 | 159 | self.mode = mode 160 | 161 | self.input_dims = input_dims 162 | self.output_dims = output_dims 163 | self.hidden_dims = hidden_dims 164 | self.mask_mode = mask_mode 165 | self.input_fc = nn.Linear(input_dims, hidden_dims) 166 | self.feature_extractor = DilatedConvEncoder( 167 | hidden_dims, 168 | [hidden_dims] * depth + [output_dims], 169 | kernel_size=3, gamma = gamma 170 | ) 171 | 172 | self.kernels = kernels 173 | self.num_bands = num_bands 174 | 175 | self.convs = nn.ModuleList( 176 | [nn.Conv1d(output_dims, output_dims//2, k, padding=k-1) for k in kernels] 177 | 
) 178 | self.fouriers = nn.ModuleList( 179 | [BandedFourierLayer(output_dims, output_dims//2, b, num_bands, 180 | freq_mixing=freq_mixing, length=length) for b in range(num_bands)] 181 | ) 182 | 183 | def forward(self, x, tcn_output=False, mask='all_true'): # x: B x T x input_dims 184 | nan_mask = ~x.isnan().any(axis=-1) 185 | x[~nan_mask] = 0 186 | x = self.input_fc(x) # B x T x Ch 187 | 188 | # generate & apply mask 189 | if mask is None: 190 | if self.training: 191 | mask = self.mask_mode 192 | else: 193 | mask = 'all_true' 194 | 195 | if mask == 'binomial': 196 | mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device) 197 | elif mask == 'continuous': 198 | mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device) 199 | elif mask == 'all_true': 200 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 201 | elif mask == 'all_false': 202 | mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) 203 | elif mask == 'mask_last': 204 | mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) 205 | mask[:, -1] = False 206 | 207 | mask &= nan_mask 208 | x[~mask] = 0 209 | 210 | # conv encoder 211 | x = x.transpose(1, 2) # B x Ch x T 212 | x = self.feature_extractor(x) # B x Co x T 213 | 214 | if tcn_output: 215 | return x.transpose(1, 2) 216 | 217 | if len(self.kernels) == 0: 218 | local_multiscale = None 219 | else: 220 | local_multiscale = [] 221 | for idx, mod in enumerate(self.convs): 222 | out = mod(x) # b d t 223 | if self.kernels[idx] != 1: 224 | out = out[..., :-(self.kernels[idx] - 1)] 225 | local_multiscale.append(out.transpose(1, 2)) # b t d 226 | local_multiscale = reduce( 227 | rearrange(local_multiscale, 'list b t d -> list b t d'), 228 | 'list b t d -> b t d', 'mean' 229 | ) 230 | 231 | x = x.transpose(1, 2) # B x T x Co 232 | 233 | if self.num_bands == 0: 234 | global_multiscale = None 235 | else: 236 | global_multiscale = [] 237 | for mod in self.fouriers: 238 | out = mod(x) # b t d 239 | global_multiscale.append(out) 240 | 241 | global_multiscale = global_multiscale[0] 242 | 243 | return torch.cat([local_multiscale, global_multiscale], dim=-1) 244 | -------------------------------------------------------------------------------- /onenet_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzhang114/OneNet/076188ad4c1b28cbda8cb4b1e231499b233e1437/onenet_result.png -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.1 2 | numpy==1.19.4 3 | pandas==0.25.1 4 | scikit_learn==0.21.3 5 | tqdm==4.62.3 6 | einops==0.4.0 -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | ## M 2 | online_learning='full' 3 | i=1 4 | ns=(1 ) 5 | bszs=(1 ) 6 | lens=(1 24 48) 7 | methods=('onenet_fsnet') 8 | for n in ${ns[*]}; do 9 | for bsz in ${bszs[*]}; do 10 | for len in ${lens[*]}; do 11 | for m in ${methods[*]}; do 12 | CUDA_VISIBLE_DEVICES=2 nohup python -u main.py --method $m --root_path ./data/ --n_inner $n --test_bsz $bsz --data ETTh2 --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 15 --learning_rate 1e-3 --online_learning $online_learning --use_adbfgs > ETTh2$len$online_learning.out 2>&1 & 13 | CUDA_VISIBLE_DEVICES=2 nohup python -u main.py --method $m 
--root_path ./data/ --n_inner $n --test_bsz $bsz --data ETTm1 --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 15 --learning_rate 1e-3 --online_learning $online_learning > ETTm1$len$online_learning.out 2>&1 & 14 | CUDA_VISIBLE_DEVICES=3 nohup python -u main.py --method $m --root_path ./data/ --n_inner $n --test_bsz $bsz --data WTH --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 15 --learning_rate 1e-3 --online_learning $online_learning --use_adbfgs > WTH$len$online_learning.out 2>&1 & 15 | CUDA_VISIBLE_DEVICES=1 nohup python -u main.py --method $m --root_path ./data/ --n_inner $n --test_bsz $bsz --data ECL --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 15 --learning_rate 3e-3 --online_learning $online_learning --use_adbfgs > ECL$len$online_learning.out 2>&1 & 16 | done 17 | done 18 | done 19 | done 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /run_d3a.sh: -------------------------------------------------------------------------------- 1 | ## M 2 | online_learning='full' 3 | i=1 4 | ns=(1 ) 5 | bszs=(1 ) 6 | lens=(48) 7 | methods=('fsnet_d3a') 8 | sleep_interval=16 9 | sleep_epochs=20 10 | for n in ${ns[*]}; do 11 | for bsz in ${bszs[*]}; do 12 | for len in ${lens[*]}; do 13 | for m in ${methods[*]}; do 14 | online_adjust=0.5 15 | offline_adjust=0.5 16 | CUDA_VISIBLE_DEVICES=3 nohup python -u main.py --method $m --root_path ./data/ --n_inner $n --test_bsz $bsz --data ETTh2 --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 6 --learning_rate 1e-3 --online_learning $online_learning --sleep_interval $sleep_interval --sleep_epochs $sleep_epochs --online_adjust $online_adjust --offline_adjust $offline_adjust > $m-ETTh2$len$online_learning-sleep$sleep_interval-epoch$sleep_epochs'_'online_adjust$online_adjust'_'offline_adjust$offline_adjust.out 2>&1 & 17 | CUDA_VISIBLE_DEVICES=3 nohup python -u main.py --method $m --root_path ./data/ --n_inner $n --test_bsz $bsz --data ETTm1 --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 6 --learning_rate 1e-3 --online_learning $online_learning --sleep_interval $sleep_interval --sleep_epochs $sleep_epochs --online_adjust $online_adjust --offline_adjust $offline_adjust > $m-ETTm1$len$online_learning-sleep$sleep_interval-epoch$sleep_epochs'_'online_adjust$online_adjust'_'offline_adjust$offline_adjust.out 2>&1 & 18 | CUDA_VISIBLE_DEVICES=0 nohup python -u main.py --method $m --root_path ./data/ --n_inner $n --test_bsz $bsz --data WTH --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 6 --learning_rate 1e-3 --online_learning $online_learning --sleep_interval $sleep_interval --sleep_epochs $sleep_epochs --online_adjust $online_adjust --offline_adjust $offline_adjust > $m-WTH$len$online_learning-sleep$sleep_interval-epoch$sleep_epochs'_'online_adjust$online_adjust'_'offline_adjust$offline_adjust.out 2>&1 & 19 | online_adjust=2.0 20 | offline_adjust=2.0 21 | CUDA_VISIBLE_DEVICES=1 nohup python -u main.py --method $m --root_path ./data/ --n_inner $n --test_bsz $bsz --data ECL --features M --seq_len 60 --label_len 0 --pred_len $len --des 'Exp' --itr $i --train_epochs 6 --learning_rate 3e-3 --online_learning $online_learning --sleep_interval $sleep_interval --sleep_epochs $sleep_epochs --online_adjust $online_adjust 
--offline_adjust $offline_adjust > $m-ECL$len$online_learning-sleep$sleep_interval-epoch$sleep_epochs'_'online_adjust$online_adjust'_'offline_adjust$offline_adjust.out 2>&1 & 22 | done 23 | done 24 | done 25 | done 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /teaser_d3a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzhang114/OneNet/076188ad4c1b28cbda8cb4b1e231499b233e1437/teaser_d3a.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzhang114/OneNet/076188ad4c1b28cbda8cb4b1e231499b233e1437/utils/__init__.py -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | class Augmenter(object): 7 | """ 8 | It applies a series of semantically preserving augmentations to batch of sequences, and updates their mask accordingly. 9 | Available augmentations are: 10 | - History cutout 11 | - History crop 12 | - Gaussian noise 13 | - Spatial dropout 14 | """ 15 | def __init__(self, cutout_length=4, cutout_prob=0.5, crop_min_history=0.5, crop_prob=0.5, gaussian_std=0.1, dropout_prob=0.1, is_cuda=True): 16 | self.cutout_length = cutout_length 17 | self.cutout_prob = cutout_prob 18 | self.crop_min_history = crop_min_history 19 | self.crop_prob = crop_prob 20 | self.gaussian_std = gaussian_std 21 | self.dropout_prob = dropout_prob 22 | self.is_cuda = is_cuda 23 | 24 | self.augmentations = [self.history_cutout, self.history_crop, self.gaussian_noise, self.spatial_dropout] 25 | 26 | def __call__(self, sequence, sequence_mask): 27 | for f in self.augmentations: 28 | sequence, sequence_mask = f(sequence, sequence_mask) 29 | 30 | return sequence, sequence_mask 31 | 32 | def history_cutout(self, sequence, sequence_mask): 33 | 34 | """ 35 | Mask out some time-window in history (i.e. 
excluding last time step) 36 | """ 37 | n_seq, n_len, n_channel = sequence.shape 38 | 39 | #Randomly draw the beginning of cutout 40 | cutout_start_index = torch.randint(low=0, high=n_len-self.cutout_length, size=(n_seq,1)).expand(-1,n_len) 41 | cutout_end_index = cutout_start_index + self.cutout_length 42 | 43 | #Based on start and end index of cutout, defined the cutout mask 44 | indices_tensor = torch.arange(n_len).repeat(n_seq,1) 45 | mask_pre = indices_tensor < cutout_start_index 46 | mask_post = indices_tensor >= cutout_end_index 47 | 48 | mask_cutout = mask_pre + mask_post 49 | 50 | #Expand it through the dimension of channels 51 | mask_cutout = mask_cutout.unsqueeze(dim=-1).expand(-1,-1,n_channel).long() 52 | 53 | #Probabilistically apply the cutoff to each sequence 54 | cutout_selection = (torch.rand(n_seq) < self.cutout_prob).long().reshape(-1,1,1) 55 | 56 | #If cuda is enabled, we will transfer the generated tensors to cuda 57 | if self.is_cuda: 58 | cutout_selection = cutout_selection.cuda() 59 | mask_cutout = mask_cutout.cuda() 60 | 61 | #Based on mask_cutout and cutout_selection, apply mask to the sequence 62 | sequence_cutout = sequence * (1-cutout_selection) + sequence * cutout_selection * mask_cutout 63 | 64 | #Update the mask as well 65 | sequence_mask_cutout = sequence_mask * (1-cutout_selection) + sequence_mask * cutout_selection * mask_cutout 66 | 67 | return sequence_cutout, sequence_mask_cutout 68 | 69 | def history_crop(self, sequence, sequence_mask): 70 | """ 71 | Crop the certain window of history from the beginning. 72 | """ 73 | 74 | n_seq, n_len, n_channel = sequence.shape 75 | 76 | #Get number of measurements non-padded for each sequence and time step 77 | nonpadded = sequence_mask.sum(dim=-1).cpu() 78 | first_nonpadded = self.get_first_nonzero(nonpadded).reshape(-1,1)/n_len #normalized by length 79 | 80 | #Randomly draw the beginning of crop 81 | crop_start_index = torch.rand(size=(n_seq,1)) 82 | 83 | #Adjust the start_index based on first N-padded time steps 84 | # For instance: if you remove first half of history, then this code removes 85 | # the first half of the NON-PADDED history. 
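# Worked example with hypothetical values: for n_len=60, first_nonpadded=0.25 and
# crop_min_history=0.5, the two lines below draw a crop start uniformly in
# [0.25, 0.625) of the sequence (time steps 15..37), so at most half of the
# non-padded history is ever cropped away.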
86 | crop_start_index = (crop_start_index * (1 - first_nonpadded) * self.crop_min_history + first_nonpadded) 87 | crop_start_index = (crop_start_index * n_len).long().expand(-1,n_len) 88 | 89 | #Based on start index of crop, defined the crop mask 90 | indices_tensor = torch.arange(n_len).repeat(n_seq,1) 91 | mask_crop = indices_tensor >= crop_start_index 92 | 93 | #Expand it through the dimension of channels 94 | mask_crop = mask_crop.unsqueeze(dim=-1).expand(-1,-1,n_channel).long() 95 | 96 | #Probabilistically apply the crop to each sequence 97 | crop_selection = (torch.rand(n_seq) < self.crop_prob).long().reshape(-1,1,1) 98 | 99 | #If cuda is enabled, we will transfer the generated tensors to cuda 100 | if self.is_cuda: 101 | crop_selection = crop_selection.cuda() 102 | mask_crop = mask_crop.cuda() 103 | 104 | #Based on mask_crop and crop_selection, apply mask to the sequence 105 | sequence_crop = sequence * (1-crop_selection) + sequence * crop_selection * mask_crop 106 | 107 | #Update the mask as well 108 | sequence_mask_crop = sequence_mask * (1-crop_selection) + sequence_mask * crop_selection * mask_crop 109 | 110 | return sequence_crop, sequence_mask_crop 111 | 112 | def gaussian_noise(self, sequence, sequence_mask): 113 | """ 114 | Add Gaussian noise to non-padded measurments 115 | """ 116 | 117 | #Add gaussian noise to the measurements 118 | 119 | #For padded entries, we won't add noise 120 | padding_mask = (sequence_mask != 0).long() 121 | #Calculate the noise for all entries 122 | noise = nn.init.trunc_normal_(torch.empty_like(sequence),std=self.gaussian_std, a=-2*self.gaussian_std, b=2*self.gaussian_std) 123 | 124 | #Add noise only to nonpadded entries 125 | sequence_noisy = sequence + padding_mask * noise 126 | 127 | return sequence_noisy, sequence_mask 128 | 129 | def spatial_dropout(self, sequence, sequence_mask): 130 | """ 131 | Drop some channels/measurements completely at random. 132 | """ 133 | n_seq, n_len, n_channel = sequence.shape 134 | 135 | dropout_selection = (torch.rand((n_seq,1,n_channel)) > self.dropout_prob).long().expand(-1,n_len,-1) 136 | 137 | #If cuda is enabled, we will transfer the generated tensors to cuda 138 | if self.is_cuda: 139 | dropout_selection = dropout_selection.cuda() 140 | 141 | sequence_dropout = sequence * dropout_selection 142 | 143 | sequence_mask_dropout = sequence_mask * dropout_selection 144 | 145 | return sequence_dropout, sequence_mask_dropout 146 | 147 | def get_first_nonzero(self, tensor2d): 148 | """ 149 | Helper function to get the first nonzero index for the 2nd dimension 150 | """ 151 | 152 | nonzero = tensor2d != 0 153 | cumsum = nonzero.cumsum(dim=-1) 154 | 155 | nonzero_idx = ((cumsum == 1) & nonzero).max(dim=-1).indices 156 | 157 | return nonzero_idx 158 | 159 | 160 | 161 | def concat_mask(seq, seq_mask, use_mask=False): 162 | if use_mask: 163 | seq = torch.cat([seq, seq_mask], dim=2) 164 | return seq 165 | -------------------------------------------------------------------------------- /utils/buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara. 2 | # All rights reserved. 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
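# Note: although the docstrings below mention the reservoir strategy, Buffer.add_data
# always indexes the buffer with fifo(), so the buffer keeps the most recent
# `buffer_size` examples (first-in, first-out) regardless of the `mode` argument;
# reservoir() and ring() are only referenced through the unused functional_index.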
5 | 6 | import torch 7 | import numpy as np 8 | from typing import Tuple 9 | from torchvision import transforms 10 | 11 | 12 | def reservoir(num_seen_examples: int, buffer_size: int) -> int: 13 | """ 14 | Reservoir sampling algorithm. 15 | :param num_seen_examples: the number of seen examples 16 | :param buffer_size: the maximum buffer size 17 | :return: the target index if the current image is sampled, else -1 18 | """ 19 | if num_seen_examples < buffer_size: 20 | return num_seen_examples 21 | 22 | rand = np.random.randint(0, num_seen_examples + 1) 23 | if rand < buffer_size: 24 | return rand 25 | else: 26 | return -1 27 | 28 | 29 | def ring(num_seen_examples: int, buffer_portion_size: int, task: int) -> int: 30 | return num_seen_examples % buffer_portion_size + task * buffer_portion_size 31 | 32 | def fifo(num_seen_examples: int, buffer_size: int) -> int: 33 | return num_seen_examples % buffer_size 34 | 35 | class Buffer: 36 | """ 37 | The memory buffer of rehearsal method. 38 | """ 39 | def __init__(self, buffer_size, device, n_tasks=1, mode='fifo'): 40 | assert mode in ['ring', 'reservoir', 'fifo'] 41 | self.buffer_size = buffer_size 42 | self.device = device 43 | self.num_seen_examples = 0 44 | self.functional_index = eval(mode) 45 | if mode == 'ring': 46 | assert n_tasks is not None 47 | self.task_number = n_tasks 48 | self.buffer_portion_size = buffer_size // n_tasks 49 | self.attributes = ['examples', 'labels', 'logits', 'task_labels'] 50 | 51 | def init_tensors(self, examples: torch.Tensor, labels: torch.Tensor, 52 | logits: torch.Tensor, task_labels: torch.Tensor) -> None: 53 | """ 54 | Initializes just the required tensors. 55 | :param examples: tensor containing the images 56 | :param labels: tensor containing the labels 57 | :param logits: tensor containing the outputs of the network 58 | :param task_labels: tensor containing the task labels 59 | """ 60 | for attr_str in self.attributes: 61 | attr = eval(attr_str) 62 | if attr is not None and not hasattr(self, attr_str): 63 | #typ = torch.int64 if attr_str.endswith('els') else torch.float32 64 | typ = torch.float32 65 | setattr(self, attr_str, torch.zeros((self.buffer_size, 66 | *attr.shape[1:]), dtype=typ, device=self.device)) 67 | 68 | def add_data(self, examples, labels=None, logits=None, task_labels=None): 69 | """ 70 | Adds the data to the memory buffer according to the reservoir strategy. 71 | :param examples: tensor containing the images 72 | :param labels: tensor containing the labels 73 | :param logits: tensor containing the outputs of the network 74 | :param task_labels: tensor containing the task labels 75 | :return: 76 | """ 77 | if not hasattr(self, 'examples'): 78 | self.init_tensors(examples, labels, logits, task_labels) 79 | 80 | for i in range(examples.shape[0]): 81 | index = fifo(self.num_seen_examples, self.buffer_size) 82 | self.num_seen_examples += 1 83 | if index >= 0: 84 | self.examples[index] = examples[i].to(self.device) 85 | if labels is not None: 86 | self.labels[index] = labels[i].to(self.device) 87 | if logits is not None: 88 | self.logits[index] = logits[i].to(self.device) 89 | if task_labels is not None: 90 | self.task_labels[index] = task_labels[i].to(self.device) 91 | 92 | def get_data(self, size: int, transform: transforms=None) -> Tuple: 93 | """ 94 | Random samples a batch of size items. 
95 | :param size: the number of requested items 96 | :param transform: the transformation to be applied (data augmentation) 97 | :return: 98 | """ 99 | if size > min(self.num_seen_examples, self.examples.shape[0]): 100 | size = min(self.num_seen_examples, self.examples.shape[0]) 101 | 102 | choice = np.random.choice(min(self.num_seen_examples, self.examples.shape[0]), 103 | size=size, replace=False) 104 | if transform is None: transform = lambda x: x 105 | ret_tuple = (torch.stack([transform(ee.cpu()) 106 | for ee in self.examples[choice]]).to(self.device),) 107 | for attr_str in self.attributes[1:]: 108 | if hasattr(self, attr_str): 109 | attr = getattr(self, attr_str) 110 | ret_tuple += (attr[choice],) 111 | 112 | return ret_tuple 113 | 114 | def is_empty(self) -> bool: 115 | """ 116 | Returns true if the buffer is empty, false otherwise. 117 | """ 118 | if self.num_seen_examples == 0: 119 | return True 120 | else: 121 | return False 122 | 123 | def get_all_data(self, transform: transforms=None) -> Tuple: 124 | """ 125 | Return all the items in the memory buffer. 126 | :param transform: the transformation to be applied (data augmentation) 127 | :return: a tuple with all the items in the memory buffer 128 | """ 129 | if transform is None: transform = lambda x: x 130 | ret_tuple = (torch.stack([transform(ee.cpu()) 131 | for ee in self.examples]).to(self.device),) 132 | for attr_str in self.attributes[1:]: 133 | if hasattr(self, attr_str): 134 | attr = getattr(self, attr_str) 135 | ret_tuple += (attr,) 136 | return ret_tuple 137 | 138 | def empty(self) -> None: 139 | """ 140 | Set all the tensors to None. 141 | """ 142 | for attr_str in self.attributes: 143 | if hasattr(self, attr_str): 144 | delattr(self, attr_str) 145 | self.num_seen_examples = 0 146 | 147 | 148 | 149 | class BufferFIFO: 150 | """ 151 | The memory buffer of rehearsal method. 152 | """ 153 | def __init__(self, buffer_size, device, n_tasks=1, mode='reservoir'): 154 | assert mode in ['ring', 'reservoir'] 155 | self.buffer_size = buffer_size 156 | self.device = device 157 | self.num_seen_examples = 0 158 | self.functional_index = eval(mode) 159 | if mode == 'ring': 160 | assert n_tasks is not None 161 | self.task_number = n_tasks 162 | self.buffer_portion_size = buffer_size // n_tasks 163 | self.attributes = ['losses',] 164 | 165 | def init_tensors(self, losses: torch.Tensor, labels: torch.Tensor, 166 | logits: torch.Tensor, task_labels: torch.Tensor) -> None: 167 | """ 168 | Initializes just the required tensors. 169 | :param examples: tensor containing the images 170 | """ 171 | for attr_str in self.attributes: 172 | attr = eval(attr_str) 173 | if attr is not None and not hasattr(self, attr_str): 174 | #typ = torch.int64 if attr_str.endswith('els') else torch.float32 175 | typ = torch.float32 176 | setattr(self, attr_str, torch.zeros((self.buffer_size, 177 | *attr.shape[1:]), dtype=typ, device=self.device)) 178 | 179 | def add_data(self, losses, labels=None, logits=None, task_labels=None): 180 | """ 181 | Adds the data to the memory buffer according to the reservoir strategy. 182 | :param losses: tensor containing the images 183 | :return: 184 | """ 185 | if not hasattr(self, 'losses'): 186 | self.init_tensors(losses, labels, logits, task_labels) 187 | 188 | 189 | self.losses[self.num_seen_examples] = losses 190 | self.num_seen_examples += 1 191 | 192 | def get_data(self, size: int, transform: transforms=None) -> Tuple: 193 | """ 194 | Random samples a batch of size items. 
95 | :param size: the number of requested items 196 | :param transform: the transformation to be applied (data augmentation) 197 | :return: 198 | """ 199 | if size > min(self.num_seen_examples, self.losses.shape[0]): 200 | size = min(self.num_seen_examples, self.losses.shape[0]) 201 | return torch.mean(self.losses[self.num_seen_examples-size:self.num_seen_examples]) 202 | 203 | def is_empty(self) -> bool: 204 | """ 205 | Returns true if the buffer is empty, false otherwise. 206 | """ 207 | if self.num_seen_examples == 0: 208 | return True 209 | else: 210 | return False 211 | 212 | def get_all_data(self, transform: transforms=None) -> Tuple: 213 | """ 214 | Return all the items in the memory buffer. 215 | :param transform: the transformation to be applied (data augmentation) 216 | :return: a tuple with all the items in the memory buffer 217 | """ 218 | if transform is None: transform = lambda x: x 219 | ret_tuple = (torch.stack([transform(ee.cpu()) 220 | for ee in self.examples]).to(self.device),) 221 | for attr_str in self.attributes[1:]: 222 | if hasattr(self, attr_str): 223 | attr = getattr(self, attr_str) 224 | ret_tuple += (attr,) 225 | return ret_tuple 226 | 227 | def empty(self) -> None: 228 | """ 229 | Set all the tensors to None. 230 | """ 231 | for attr_str in self.attributes: 232 | if hasattr(self, attr_str): 233 | delattr(self, attr_str) 234 | self.num_seen_examples = 0 -------------------------------------------------------------------------------- /utils/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import norm 3 | import torch 4 | import os 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | palette = 'colorblind' 8 | colors = ['gold', 'grey', '#d7191c', '#2b83ba' ] 9 | colors = sns.color_palette("muted", n_colors=10) 10 | 11 | class STEPD: 12 | def __init__(self, new_window_size, alpha_w=0.05, alpha_d=0.003): 13 | self.new_window_size = new_window_size 14 | self.alpha_w = alpha_w 15 | self.alpha_d = alpha_d 16 | self.cnt, self.shift_cnt = 1, 0 17 | self.data = [] 18 | self.data_visualize = None 19 | 20 | def add_data(self, error_rate, x): 21 | self.cnt += 1 22 | if self.data_visualize is None: 23 | self.data_visualize = x.cpu().detach()[0] 24 | else: 25 | self.data_visualize = torch.cat([self.data_visualize, x.cpu().detach()[0,-1,:].unsqueeze(0)], dim=0) 26 | # if len(self.data) > self.new_window_size and self.is_outlier(error_rate) : 27 | # # If the new error is an outlier, do not add it to the data 28 | # return 29 | self.data.append(error_rate) 30 | 31 | def reset(self): 32 | self.shift_cnt += 1 33 | self.data = [] 34 | self.cnt = 0 35 | self.data_visualize = None 36 | 37 | def is_outlier(self, value, threshold=3.0): 38 | # Detect outliers via a standard-deviation (z-score) test; normalization may be unnecessary here since the absolute magnitude is already informative 39 | mean_value = np.mean(self.data[-self.new_window_size:]) 40 | std_dev_value = np.std(self.data[-self.new_window_size:]) 41 | z_score = (value - mean_value) / (std_dev_value + 1e-4) 42 | warning_threshold = norm.ppf(1 - self.alpha_w / 2) 43 | return z_score > warning_threshold 44 | 45 | def run_test(self,): 46 | if len(self.data) < self.new_window_size: 47 | # Not enough data for comparison 48 | return 0, None 49 | 50 | # Extract the most recent time window and the overall time window 51 | recent_window = self.data[-self.new_window_size:] 52 | overall_window = self.data 53 | 54 | # Calculate the test statistic 55 | mean_recent = np.mean(recent_window) 56 | mean_overall = np.mean(overall_window) 57 | std_dev_overall = np.std(overall_window) 58
| n = len(self.data) 59 | theta_stepd = (mean_recent - mean_overall) / (std_dev_overall / np.sqrt(n)) 60 | 61 | # Calculate the warning and drift thresholds 62 | warning_threshold = norm.ppf(1 - self.alpha_w / 2) 63 | drift_threshold = norm.ppf(1 - self.alpha_d / 2) 64 | 65 | # Check for warning or drift 66 | if theta_stepd > drift_threshold: 67 | # self.plt_distribution(self.data, c1 = colors[0], c2 = colors[1]) 68 | # self.plt_distribution(self.data_visualize[:,0].numpy(), name='value', c1 = colors[2], c2 = colors[3]) 69 | return 1, 3e-3 70 | else: 71 | lr = 1e-4 + (3e-3 - 1e-4) * (theta_stepd / drift_threshold) 72 | return 0, max(lr, 1e-4) 73 | 74 | def plt_distribution(self, data, name='error', c1 = colors[0], c2 = colors[1]): 75 | sns.set(style="whitegrid", font_scale=1.8) 76 | plt.rcParams["font.family"] = "Times New Roman" 77 | plt.figure(figsize=(12, 8)) 78 | plt.plot(data, label='Entire Dataset', color=c1) 79 | 80 | # Highlight the last 100 time steps 81 | plt.plot(range(len(data) - self.new_window_size, len(data)), data[-self.new_window_size:], label=f'Last Window', color=c2, linewidth=2) 82 | 83 | # Plot the mean of the entire dataset 84 | mean_entire = np.mean(data) 85 | plt.axhline(mean_entire, color=c1, linestyle='dashed', label=f'Mean (Entire): {mean_entire:.2f}') 86 | 87 | # Plot the mean of the last 100 steps 88 | mean_last_100 = np.mean(data[-self.new_window_size:]) 89 | plt.axhline(mean_last_100, color=c2, linestyle='dashed', label=f'Mean (Last Window): {mean_last_100:.2f}') 90 | 91 | if name == 'value': 92 | plt.legend([]) 93 | else: 94 | plt.legend() 95 | plt.xlabel('Time Steps') 96 | plt.ylabel(name) 97 | # Add legend 98 | 99 | # Save the plot as a PDF file 100 | if not os.path.exists('imgs/drift/'): 101 | os.mkdir('imgs/drift/') 102 | sns.despine(offset=10, trim=True) 103 | plt.tight_layout() 104 | plt.savefig(f'imgs/drift/{name}_plot_{self.shift_cnt}.pdf') 105 | plt.close() -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class TriangularCausalMask(): 4 | def __init__(self, B, L, device="cpu"): 5 | mask_shape = [B, 1, L, L] 6 | with torch.no_grad(): 7 | self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) 8 | 9 | @property 10 | def mask(self): 11 | return self._mask 12 | 13 | class ProbMask(): 14 | def __init__(self, B, H, L, index, scores, device="cpu"): 15 | _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1) 16 | _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1]) 17 | indicator = _mask_ex[torch.arange(B)[:, None, None], 18 | torch.arange(H)[None, :, None], 19 | index, :].to(device) 20 | self._mask = indicator.view(scores.shape).to(device) 21 | 22 | @property 23 | def mask(self): 24 | return self._mask -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numexpr as ne 3 | import pdb 4 | 5 | def cumavg(m): 6 | cumsum= np.cumsum(m) 7 | return cumsum / np.arange(1, cumsum.size + 1) 8 | 9 | def RSE(pred, true): 10 | return np.sqrt(np.sum((true-pred)**2)) / np.sqrt(np.sum((true-true.mean())**2)) 11 | 12 | def CORR(pred, true): 13 | u = ((true-true.mean(0))*(pred-pred.mean(0))).sum(0) 14 | d = np.sqrt(((true-true.mean(0))**2*(pred-pred.mean(0))**2).sum(0)) 15 | return (u/d).mean(-1) 
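# Note on cumavg above (hypothetical values): cumavg(np.array([1., 2., 3.])) -> array([1. , 1.5, 2. ]),
# i.e. the running mean of per-step errors, presumably used to report cumulative online forecasting error.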
16 | 17 | def MAE(pred, true): 18 | 19 | return np.mean(np.abs(pred-true)) 20 | #return ne.evaluate('sum(abs(pred-true))')/true.size 21 | def MSE(pred, true): 22 | return np.mean((pred-true)**2) 23 | 24 | def RMSE(pred, true): 25 | return np.sqrt(MSE(pred, true)) 26 | 27 | def MAPE(pred, true): 28 | return np.mean(np.abs((pred - true) / true)) 29 | #return ne.evaluate('sum(abs(pred-true)/true)')/true.shape[0] 30 | #return 0 31 | #return ne.evaluate('sum(abs(pred-true))')/(true.size) 32 | 33 | def MSPE(pred, true): 34 | return np.mean(np.square((pred - true) / true)) 35 | #return 0 36 | 37 | def metric(pred, true): 38 | 39 | mae = MAE(pred, true) 40 | mse = MSE(pred, true) 41 | rmse = RMSE(pred, true) 42 | mape = MAPE(pred, true) 43 | mspe = MSPE(pred, true) 44 | 45 | return mae,mse,rmse,mape,mspe 46 | -------------------------------------------------------------------------------- /utils/timefeatures.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from pandas.tseries import offsets 6 | from pandas.tseries.frequencies import to_offset 7 | 8 | class TimeFeature: 9 | def __init__(self): 10 | pass 11 | 12 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 13 | pass 14 | 15 | def __repr__(self): 16 | return self.__class__.__name__ + "()" 17 | 18 | class SecondOfMinute(TimeFeature): 19 | """Minute of hour encoded as value between [-0.5, 0.5]""" 20 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 21 | return index.second / 59.0 - 0.5 22 | 23 | class MinuteOfHour(TimeFeature): 24 | """Minute of hour encoded as value between [-0.5, 0.5]""" 25 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 26 | return index.minute / 59.0 - 0.5 27 | 28 | class HourOfDay(TimeFeature): 29 | """Hour of day encoded as value between [-0.5, 0.5]""" 30 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 31 | return index.hour / 23.0 - 0.5 32 | 33 | class DayOfWeek(TimeFeature): 34 | """Hour of day encoded as value between [-0.5, 0.5]""" 35 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 36 | return index.dayofweek / 6.0 - 0.5 37 | 38 | class DayOfMonth(TimeFeature): 39 | """Day of month encoded as value between [-0.5, 0.5]""" 40 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 41 | return (index.day - 1) / 30.0 - 0.5 42 | 43 | class DayOfYear(TimeFeature): 44 | """Day of year encoded as value between [-0.5, 0.5]""" 45 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 46 | return (index.dayofyear - 1) / 365.0 - 0.5 47 | 48 | class MonthOfYear(TimeFeature): 49 | """Month of year encoded as value between [-0.5, 0.5]""" 50 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 51 | return (index.month - 1) / 11.0 - 0.5 52 | 53 | class WeekOfYear(TimeFeature): 54 | """Week of year encoded as value between [-0.5, 0.5]""" 55 | def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: 56 | return (index.isocalendar().week - 1) / 52.0 - 0.5 57 | 58 | def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: 59 | """ 60 | Returns a list of time features that will be appropriate for the given frequency string. 61 | Parameters 62 | ---------- 63 | freq_str 64 | Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 
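For example, "1H" maps to [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], while "5min" additionally prepends MinuteOfHour (see features_by_offsets below).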
65 | """ 66 | 67 | features_by_offsets = { 68 | offsets.YearEnd: [], 69 | offsets.QuarterEnd: [MonthOfYear], 70 | offsets.MonthEnd: [MonthOfYear], 71 | offsets.Week: [DayOfMonth, WeekOfYear], 72 | offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], 73 | offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], 74 | offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], 75 | offsets.Minute: [ 76 | MinuteOfHour, 77 | HourOfDay, 78 | DayOfWeek, 79 | DayOfMonth, 80 | DayOfYear, 81 | ], 82 | offsets.Second: [ 83 | SecondOfMinute, 84 | MinuteOfHour, 85 | HourOfDay, 86 | DayOfWeek, 87 | DayOfMonth, 88 | DayOfYear, 89 | ], 90 | } 91 | 92 | offset = to_offset(freq_str) 93 | 94 | for offset_type, feature_classes in features_by_offsets.items(): 95 | if isinstance(offset, offset_type): 96 | return [cls() for cls in feature_classes] 97 | 98 | supported_freq_msg = f""" 99 | Unsupported frequency {freq_str} 100 | The following frequencies are supported: 101 | Y - yearly 102 | alias: A 103 | M - monthly 104 | W - weekly 105 | D - daily 106 | B - business days 107 | H - hourly 108 | T - minutely 109 | alias: min 110 | S - secondly 111 | """ 112 | raise RuntimeError(supported_freq_msg) 113 | 114 | def time_features(dates, timeenc=1, freq='h'): 115 | """ 116 | > `time_features` takes in a `dates` dataframe with a 'dates' column and extracts the date down to `freq` where freq can be any of the following if `timeenc` is 0: 117 | > * m - [month] 118 | > * w - [month] 119 | > * d - [month, day, weekday] 120 | > * b - [month, day, weekday] 121 | > * h - [month, day, weekday, hour] 122 | > * t - [month, day, weekday, hour, *minute] 123 | > 124 | > If `timeenc` is 1, a similar, but different list of `freq` values are supported (all encoded between [-0.5 and 0.5]): 125 | > * Q - [month] 126 | > * M - [month] 127 | > * W - [Day of month, week of year] 128 | > * D - [Day of week, day of month, day of year] 129 | > * B - [Day of week, day of month, day of year] 130 | > * H - [Hour of day, day of week, day of month, day of year] 131 | > * T - [Minute of hour*, hour of day, day of week, day of month, day of year] 132 | > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year] 133 | 134 | *minute returns a number from 0-3 corresponding to the 15 minute period it falls into. 
135 | """ 136 | if timeenc==0: 137 | dates['month'] = dates.date.apply(lambda row:row.month,1) 138 | dates['day'] = dates.date.apply(lambda row:row.day,1) 139 | dates['weekday'] = dates.date.apply(lambda row:row.weekday(),1) 140 | dates['hour'] = dates.date.apply(lambda row:row.hour,1) 141 | dates['minute'] = dates.date.apply(lambda row:row.minute,1) 142 | dates['minute'] = dates.minute.map(lambda x:x//15) 143 | freq_map = { 144 | 'y':[],'m':['month'],'w':['month'],'d':['month','day','weekday'], 145 | 'b':['month','day','weekday'],'h':['month','day','weekday','hour'], 146 | 't':['month','day','weekday','hour','minute'], 147 | } 148 | return dates[freq_map[freq.lower()]].values 149 | if timeenc==1: 150 | dates = pd.to_datetime(dates.date.values) 151 | return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)]).transpose(1,0) 152 | 153 | if timeenc == 2: 154 | dt = pd.to_datetime(dates.date.values) 155 | return np.stack([ 156 | dt.minute.to_numpy(), 157 | dt.hour.to_numpy(), 158 | dt.dayofweek.to_numpy(), 159 | dt.day.to_numpy(), 160 | dt.dayofyear.to_numpy(), 161 | dt.month.to_numpy(), 162 | dt.weekofyear.to_numpy(), 163 | ], axis=1).astype(np.float) -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def adjust_learning_rate(optimizer, epoch, args, decision=False): 5 | # lr = args.learning_rate * (0.2 ** (epoch // 2)) 6 | if decision: 7 | for param_group in optimizer.param_groups: 8 | param_group['lr'] = param_group['lr'] / 3 9 | print('Updating learning rate to {}'.format(param_group['lr'])) 10 | return 11 | if args.lradj=='type1': 12 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch-1) // 1))} 13 | elif args.lradj=='type2': 14 | lr_adjust = { 15 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 16 | 10: 5e-7, 15: 1e-7, 20: 5e-8 17 | } 18 | if epoch in lr_adjust.keys(): 19 | lr = lr_adjust[epoch] 20 | for param_group in optimizer.param_groups: 21 | param_group['lr'] = lr 22 | print('Updating learning rate to {}'.format(lr)) 23 | 24 | class EarlyStopping: 25 | def __init__(self, patience=7, verbose=False, delta=0): 26 | self.patience = patience 27 | self.verbose = verbose 28 | self.counter = 0 29 | self.best_score = None 30 | self.early_stop = False 31 | self.val_loss_min = np.Inf 32 | self.delta = delta 33 | 34 | def __call__(self, val_loss, model, path, name='checkpoint.pth'): 35 | score = -val_loss 36 | if self.best_score is None: 37 | self.best_score = score 38 | self.save_checkpoint(val_loss, model, path, name) 39 | elif score < self.best_score + self.delta: 40 | self.counter += 1 41 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 42 | if self.counter >= self.patience: 43 | self.early_stop = True 44 | else: 45 | self.best_score = score 46 | self.save_checkpoint(val_loss, model, path, name) 47 | self.counter = 0 48 | 49 | def save_checkpoint(self, val_loss, model, path, name='checkpoint.pth'): 50 | if self.verbose: 51 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 52 | torch.save(model.state_dict(), path+'/'+name) 53 | self.val_loss_min = val_loss 54 | 55 | class dotdict(dict): 56 | """dot.notation access to dictionary attributes""" 57 | __getattr__ = dict.get 58 | __setattr__ = dict.__setitem__ 59 | __delattr__ = dict.__delitem__ 60 | 61 | class StandardScaler(): 62 | def __init__(self): 63 | self.mean = 0. 
64 | self.std = 1. 65 | 66 | def fit(self, data): 67 | self.mean = data.mean(0) 68 | self.std = data.std(0) 69 | 70 | def transform(self, data): 71 | mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean 72 | std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std 73 | return (data - mean) / std 74 | 75 | def inverse_transform(self, data): 76 | mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean 77 | std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std 78 | return (data * std) + mean 79 | --------------------------------------------------------------------------------
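A minimal usage sketch for the STEPD drift detector in `utils/detector.py`, assuming placeholder names (`model`, `optimizer`, `criterion`, `batch_x`, `batch_y`) and an illustrative window size; the actual wiring inside the `exp/` experiment classes may differ.

```python
# Hedged sketch: drift-aware learning-rate control with STEPD (placeholder names).
from utils.detector import STEPD

detector = STEPD(new_window_size=16)  # window size chosen for illustration only

def online_step(model, optimizer, criterion, batch_x, batch_y):
    """One online update step whose learning rate reacts to detected drift."""
    pred = model(batch_x)              # placeholder forward pass
    loss = criterion(pred, batch_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record the step loss; batch_x is kept by the detector only for visualization.
    detector.add_data(loss.item(), batch_x)
    drift, suggested_lr = detector.run_test()
    if drift:
        # Drift detected: restart the detector's statistics from the new concept.
        detector.reset()
    if suggested_lr is not None:
        # run_test suggests 3e-3 on drift and a value in [1e-4, 3e-3] otherwise.
        for group in optimizer.param_groups:
            group['lr'] = suggested_lr
    return loss.item()
```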