├── .gitignore ├── README.md ├── data_provider ├── __init__.py ├── data_factory.py └── data_loader.py ├── exp ├── exp_basic.py └── exp_main.py ├── figs ├── framework.png └── multi-scale transformer.png ├── layers ├── AMS.py ├── Embedding.py ├── Layer.py └── RevIN.py ├── models └── PathFormer.py ├── requirements.txt ├── run.py ├── scripts └── multivariate │ ├── ETTh1.sh │ ├── ETTh2.sh │ ├── ETTm1.sh │ ├── ETTm2.sh │ ├── electricity.sh │ ├── traffic.sh │ └── weather.sh └── utils ├── Other.py ├── decomposition.py ├── masking.py ├── metrics.py ├── timefeatures.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | dataset/ 4 | 5 | result/ 6 | 7 | .idea/ 8 | 9 | .vscode/ 10 | 11 | venv/ 12 | 13 | .DS_Store 14 | 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## (ICLR 2024) Pathformer: Multi-scale Transformers with Adaptive Pathways for Time Series Forecasting 2 | 3 | This code is a PyTorch implementation of our ICLR'24 paper "Pathformer: Multi-scale Transformers with Adaptive Pathways for Time Series Forecasting". 
[[arXiv]](https://arxiv.org/abs/2402.05956) 4 | 5 | 🌟 Pathformer代码在阿里云仓库也进行同步更新:[阿里云Pathformer代码链接](https://github.com/alibaba/sreworks-ext/tree/main/aiops/Pathformer_ICLR2024) 6 | 7 | ## Citing Pathformer 8 | If you find this resource helpful, please consider to cite our research: 9 | 10 | ``` 11 | @inproceedings{chen2024pathformer, 12 | author = {Peng Chen and Yingying Zhang and Yunyao Cheng and Yang Shu and Yihang Wang and Qingsong Wen and Bin Yang and Chenjuan Guo}, 13 | title = {Pathformer: Multi-scale Transformers with Adaptive Pathways for Time Series Forecasting}, 14 | booktitle = {International Conference on Learning Representations (ICLR)}, 15 | year = {2024} 16 | } 17 | ``` 18 | 19 | 20 | 21 | 22 | ## Introduction 23 | Pathformer, a Multi-Scale Transformer with Adaptive Pathways for time series forecasting. It integrates multi-scale temporal resolutions and temporal distances by introducing patch division with multiple patch sizes and dual attention on the divided patches, enabling the comprehensive modeling of multi-scale characteristics. Furthermore, adaptive pathways dynamically select and aggregate scale-specific characteristics based on the different temporal dynamics. 24 | 25 | 26 | ![The architecture of Pathformer](./figs/framework.png#pic_center) 27 | 28 | The important components of Pathformer: Multi-Scale Transformer Block and Multi-Scale Router. 29 | 30 | ![The structure of the Multi-Scale Transformer Block and Multi-Scale Router](./figs/multi-scale%20transformer.png) 31 | ## Requirements 32 | To install all dependencies: 33 | ``` 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ## Datasets 38 | You can access the well pre-processed datasets from [Google Drive](https://drive.google.com/file/d/1NF7VEefXCmXuWNbnNe858WvQAkJ_7wuP/view), then place the downloaded contents under ./dataset 39 | ## Quick Demos 40 | 1. Download datasets and place them under ./dataset 41 | 2. 
from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred, Dataset_Pretrain
from torch.utils.data import DataLoader

# Registry mapping the --data argument to its Dataset implementation.
data_dict = {
    'ETTh1': Dataset_ETT_hour,
    'ETTh2': Dataset_ETT_hour,
    'ETTm1': Dataset_ETT_minute,
    'ETTm2': Dataset_ETT_minute,
    'custom': Dataset_Custom,
    'pretrain': Dataset_Pretrain,
}


def data_provider(args, flag):
    """Build the dataset / dataloader pair for one split.

    Args:
        args: parsed experiment arguments (data, embed, batch_size, freq,
            root_path, data_path, seq_len, pred_len, features, target,
            num_workers, ...).
        flag: one of 'train' / 'val' / 'test' / 'pred'.

    Returns:
        (dataset, dataloader) for the requested split.
    """
    dataset_cls = data_dict[args.data]
    # 'timeF' embedding expects continuous time features (timeenc=1);
    # anything else uses integer calendar components (timeenc=0).
    timeenc = 1 if args.embed == 'timeF' else 0

    if flag == 'test':
        shuffle_flag, drop_last, batch_size = False, True, args.batch_size
    elif flag == 'pred':
        # Prediction serves one window at a time and keeps every sample.
        shuffle_flag, drop_last, batch_size = False, False, 1
        dataset_cls = Dataset_Pred
    else:  # 'train' / 'val'
        shuffle_flag, drop_last, batch_size = True, True, args.batch_size
    freq = args.freq

    data_set = dataset_cls(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=freq,
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )

    return data_set, data_loader
class Dataset_ETT_hour(Dataset):
    """Sliding-window loader for the hourly ETT benchmark (ETTh1/ETTh2).

    Splits follow the standard 12/4/4-month protocol. Each item is
    (seq_x, seq_y, seq_x_mark, seq_y_mark): an input window of length
    seq_len, the following pred_len targets, and their time features.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size: [seq_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, select the split borders, scale, and build time marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # 12 months train / 4 val / 4 test, in hours; non-train splits start
        # seq_len early so their first window is complete.
        border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
        border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit on the train slice only to avoid leaking statistics.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # .dt accessor replaces Series.apply(lambda ..., 1): the second
            # positional argument (convert_dtype) is deprecated in pandas 2.x.
            df_stamp['month'] = df_stamp.date.dt.month
            df_stamp['day'] = df_stamp.date.dt.day
            df_stamp['weekday'] = df_stamp.date.dt.weekday
            df_stamp['hour'] = df_stamp.date.dt.hour
            # drop(columns=...) replaces drop(['date'], 1): the positional
            # axis argument was removed in pandas 2.0.
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return the input window starting at `index` and its horizon."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete (input, horizon) windows in the split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original units."""
        return self.scaler.inverse_transform(data)
class Dataset_ETT_minute(Dataset):
    """Sliding-window loader for the 15-minute ETT benchmark (ETTm1/ETTm2).

    Same 12/4/4-month protocol as the hourly variant, with borders scaled
    by 4 samples per hour. Items are (seq_x, seq_y, seq_x_mark, seq_y_mark).
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTm1.csv',
                 target='OT', scale=True, timeenc=0, freq='t'):
        # size: [seq_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, select the split borders, scale, and build time marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Hourly borders times 4 (15-minute sampling); non-train splits start
        # seq_len early so their first window is complete.
        border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
        border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit on the train slice only to avoid leaking statistics.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # .dt accessor replaces Series.apply(lambda ..., 1): the second
            # positional argument (convert_dtype) is deprecated in pandas 2.x.
            df_stamp['month'] = df_stamp.date.dt.month
            df_stamp['day'] = df_stamp.date.dt.day
            df_stamp['weekday'] = df_stamp.date.dt.weekday
            df_stamp['hour'] = df_stamp.date.dt.hour
            # Quarter-hour index 0..3, as in the original minute // 15 mapping.
            df_stamp['minute'] = df_stamp.date.dt.minute // 15
            # drop(columns=...) replaces drop(['date'], 1): the positional
            # axis argument was removed in pandas 2.0.
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return the input window starting at `index` and its horizon."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete (input, horizon) windows in the split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original units."""
        return self.scaler.inverse_transform(data)
class Dataset_Custom(Dataset):
    """Generic CSV loader with a 70/10/20 train/val/test split.

    Expects a 'date' column; the target column is moved to the last
    position so features='MS' can select it with f_dim=-1. Items are
    (seq_x, seq_y, seq_x_mark, seq_y_mark).
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size: [seq_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, reorder columns, split 70/10/20, scale, build marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Reorder to ['date', ...(other features), target].
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        num_train = int(len(df_raw) * 0.7)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        # Non-train splits start seq_len early so their first window is full.
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit on the train slice only to avoid leaking statistics.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # .dt accessor replaces Series.apply(lambda ..., 1): the second
            # positional argument (convert_dtype) is deprecated in pandas 2.x.
            df_stamp['month'] = df_stamp.date.dt.month
            df_stamp['day'] = df_stamp.date.dt.day
            df_stamp['weekday'] = df_stamp.date.dt.weekday
            df_stamp['hour'] = df_stamp.date.dt.hour
            # drop(columns=...) replaces drop(['date'], 1): the positional
            # axis argument was removed in pandas 2.0.
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return the input window starting at `index` and its horizon."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete (input, horizon) windows in the split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original units."""
        return self.scaler.inverse_transform(data)
class Dataset_Pred(Dataset):
    """Inference-time loader: serves the last seq_len rows of the CSV plus
    time marks extended pred_len steps into the future.

    No ground-truth targets exist beyond the stored history, so seq_y is
    typically an empty array; seq_y_mark carries the future time features.
    """

    def __init__(self, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None):
        # size: [seq_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        assert flag in ['pred']

        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.freq = freq
        self.cols = cols
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV tail, scale on the full series, and extend time marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))
        # Reorder to ['date', ...(other features), target].
        if self.cols:
            cols = self.cols.copy()
            cols.remove(self.target)
        else:
            cols = list(df_raw.columns)
            cols.remove(self.target)
            cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # Only the most recent seq_len rows feed the forecast.
        border1 = len(df_raw) - self.seq_len
        border2 = len(df_raw)

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # No held-out future exists here, so the scaler fits on all data.
            self.scaler.fit(df_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        tmp_stamp = df_raw[['date']][border1:border2]
        tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date)
        # Extend the timeline pred_len steps past the last observed date
        # (pred_dates[0] duplicates the last observed date and is skipped).
        pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq)

        df_stamp = pd.DataFrame(columns=['date'])
        df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:])
        # Ensure datetime64 dtype so the .dt accessor below is valid.
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # .dt accessor replaces Series.apply(lambda ..., 1): the second
            # positional argument (convert_dtype) is deprecated in pandas 2.x.
            df_stamp['month'] = df_stamp.date.dt.month
            df_stamp['day'] = df_stamp.date.dt.day
            df_stamp['weekday'] = df_stamp.date.dt.weekday
            df_stamp['hour'] = df_stamp.date.dt.hour
            # Quarter-hour index 0..3, as in the original minute // 15 mapping.
            df_stamp['minute'] = df_stamp.date.dt.minute // 15
            # drop(columns=...) replaces drop(['date'], 1): the positional
            # axis argument was removed in pandas 2.0.
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        if self.inverse:
            self.data_y = df_data.values[border1:border2]
        else:
            self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return the input window; seq_y holds whatever targets exist.

        Fix: the original sliced seq_y as [r_begin:r_begin], which is always
        empty. [r_begin:r_end] is the intended range; past the end of the
        stored history it still yields an empty array, since no future
        ground truth is available.
        """
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        if self.inverse:
            seq_y = self.data_x[r_begin:r_end]
        else:
            seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Only the input window must fit; no horizon data is required.
        return len(self.data_x) - self.seq_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original units."""
        return self.scaler.inverse_transform(data)
class Dataset_Pretrain(Dataset):
    """Pre-training loader: the entire CSV is the training split.

    Identical pipeline to Dataset_Custom, but with num_train = 100% and
    num_test = 0 (so 'val'/'test' flags yield degenerate splits).
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h'):
        # size: [seq_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.pred_len = size[1]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, reorder columns, scale on all data, build marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Reorder to ['date', ...(other features), target].
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # Pre-training uses everything for train; val/test are empty shells.
        num_train = int(len(df_raw) * 1)
        num_test = int(len(df_raw) * 0)
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Here the train slice is the full series by construction.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # .dt accessor replaces Series.apply(lambda ..., 1): the second
            # positional argument (convert_dtype) is deprecated in pandas 2.x.
            df_stamp['month'] = df_stamp.date.dt.month
            df_stamp['day'] = df_stamp.date.dt.day
            df_stamp['weekday'] = df_stamp.date.dt.weekday
            df_stamp['hour'] = df_stamp.date.dt.hour
            # drop(columns=...) replaces drop(['date'], 1): the positional
            # axis argument was removed in pandas 2.0.
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return the input window starting at `index` and its horizon."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end
        r_end = r_begin + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete (input, horizon) windows in the split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original units."""
        return self.scaler.inverse_transform(data)
class Exp_Main(Exp_Basic):
    """Experiment driver: builds the model and runs train / vali / test / predict."""

    def __init__(self, args):
        super(Exp_Main, self).__init__(args)

    def _build_model(self):
        # Look up the model class by name and instantiate it in float32.
        model_dict = {
            'PathFormer': PathFormer,
        }
        model = model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        # Delegates dataset/dataloader construction to the data factory.
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        # L1 (MAE) loss is used for both training and validation.
        criterion = nn.L1Loss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Average `criterion` over a loader in eval mode.

        PathFormer returns (outputs, balance_loss); the balance loss is
        ignored here so validation tracks the forecasting error only.
        Restores train mode before returning.
        """
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()

                # Time marks are moved to device but not consumed below —
                # the model is called with batch_x only.
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)


                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if self.args.model=='PathFormer':
                            outputs, balance_loss = self.model(batch_x)
                        else:
                            outputs = self.model(batch_x)

                else:
                    if self.args.model=='PathFormer':
                        outputs, balance_loss = self.model(batch_x)
                    else:
                        outputs = self.model(batch_x)
                # For 'MS' only the last (target) channel is scored.
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion(pred, true)

                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Full training loop with early stopping; returns the best model.

        Checkpoints go to args.checkpoints/<setting>/checkpoint.pth and the
        best (lowest-vali-loss) weights are reloaded before returning.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        total_num = sum(p.numel() for p in self.model.parameters())
        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
                                            steps_per_epoch=train_steps,
                                            pct_start=self.args.pct_start,
                                            epochs=self.args.train_epochs,
                                            max_lr=self.args.learning_rate)

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)

                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)



                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if self.args.model=='PathFormer':
                            outputs, balance_loss = self.model(batch_x)
                        else:
                            outputs = self.model(batch_x)

                        f_dim = -1 if self.args.features == 'MS' else 0
                        outputs = outputs[:, -self.args.pred_len:, f_dim:]
                        batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                        loss = criterion(outputs, batch_y)
                        # NOTE(review): unlike the non-AMP path below,
                        # balance_loss is NOT added to loss here — confirm
                        # whether this asymmetry is intentional.
                        train_loss.append(loss.item())
                else:
                    if self.args.model == 'PathFormer':
                        outputs, balance_loss = self.model(batch_x)
                    else:
                        outputs = self.model(batch_x)
                    f_dim = -1 if self.args.features == 'MS' else 0
                    outputs = outputs[:, -self.args.pred_len:, f_dim:]
                    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                    loss = criterion(outputs, batch_y)
                    if self.args.model=="PathFormer":
                        # Add the routing (load-balancing) auxiliary loss.
                        loss = loss + balance_loss
                    train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    # Periodic progress / ETA logging.
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

                if self.args.lradj == 'TST':
                    # Per-step LR schedule (OneCycle) when lradj == 'TST'.
                    adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args, printout=False)
                    scheduler.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            # Early stopping also checkpoints whenever vali_loss improves.
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            if self.args.lradj != 'TST':
                # Per-epoch LR adjustment for non-'TST' schedules.
                adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args)
            else:
                print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0]))

        # Reload the best checkpoint saved by early stopping.
        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))
        return self.model


    def test(self, setting, test=0):
        """Evaluate on the test split; writes plots and appends metrics to result.txt.

        When `test` is truthy the checkpoint under ./checkpoints/<setting>/
        is loaded first.
        """
        test_data, test_loader = self._get_data(flag='test')

        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        preds = []
        trues = []
        inputx = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)


                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if self.args.model=='PathFormer':
                            outputs, balance_loss = self.model(batch_x)
                        else:
                            outputs = self.model(batch_x)
                else:
                    if self.args.model == 'PathFormer':
                        outputs, balance_loss = self.model(batch_x)
                    else:
                        outputs = self.model(batch_x)
                # For 'MS' only the last (target) channel is scored.
                f_dim = -1 if self.args.features == 'MS' else 0

                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                outputs = outputs.detach().cpu().numpy()
                batch_y = batch_y.detach().cpu().numpy()

                pred = outputs  # outputs.detach().cpu().numpy()  # .squeeze()
                true = batch_y  # batch_y.detach().cpu().numpy()  # .squeeze()

                preds.append(pred)
                trues.append(true)
                inputx.append(batch_x.detach().cpu().numpy())

                if i % 20 == 0:
                    # Plot every 20th batch: history + ground truth vs. prediction
                    # for the first sample's last channel. (`pd` here is a local
                    # array, not the pandas alias.)
                    input = batch_x.detach().cpu().numpy()
                    gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0)
                    pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))

        if self.args.test_flop:
            # FLOP counting mode exits the process after profiling.
            test_params_flop((batch_x.shape[1], batch_x.shape[2]))
            exit()
        preds = np.array(preds)
        trues = np.array(trues)
        inputx = np.array(inputx)

        # Collapse (num_batches, batch, ...) into one sample axis.
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])

        mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
        print('mse:{}, mae:{}, rse:{}'.format(mse, mae, rse))
        f = open("result.txt", 'a')
        f.write(setting + " \n")
        f.write('mse:{}, mae:{}, rse:{}'.format(mse, mae, rse))
        f.write('\n')
        f.write('\n')
        f.close()

        return

    def predict(self, setting, load=False):
        """Run inference on the 'pred' split; returns None.

        The result-saving block is intentionally commented out in the
        original; predictions are computed but currently discarded.
        """
        pred_data, pred_loader = self._get_data(flag='pred')

        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            self.model.load_state_dict(torch.load(best_model_path))


        preds = []

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)


                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if self.args.model=='PathFormer':
                            outputs, a_loss = self.model(batch_x)
                        else:
                            outputs = self.model(batch_x)

                else:
                    if self.args.model == 'PathFormer':
                        outputs, a_loss = self.model(batch_x)
                    else:
                        outputs = self.model(batch_x)
                pred = outputs.detach().cpu().numpy()  # .squeeze()
                preds.append(pred)

        preds = np.array(preds)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])

        # result save
        # folder_path = './results/' + setting + '/'
        # if not os.path.exists(folder_path):
        #     os.makedirs(folder_path)
        #
        # np.save(folder_path + 'real_prediction.npy', preds)

        return
class AMS(nn.Module):
    """Adaptive Multi-Scale block: a noisy top-k mixture-of-experts over
    multi-scale Transformer experts (one expert per patch size).

    A router decomposes the input into seasonality + trend components,
    computes noisy top-k gates, dispatches the input to the k selected
    experts, and combines their outputs weighted by the gates.  A load
    balance loss discouraging expert collapse is returned alongside the
    output.
    """

    def __init__(self, input_size, output_size, num_experts, device, num_nodes=1, d_model=32, d_ff=64, dynamic=False,
                 patch_size=[8, 6, 4, 2], noisy_gating=True, k=4, layer_number=1, residual_connection=1, batch_norm=False):
        # NOTE(review): `patch_size=[...]` is a mutable default argument; it
        # is only iterated here, so it appears harmless, but worth confirming.
        super(AMS, self).__init__()
        self.num_experts = num_experts
        self.output_size = output_size
        self.input_size = input_size
        self.k = k  # number of experts selected per example

        # Router front-end: collapse the node axis, then decompose the series.
        self.start_linear = nn.Linear(in_features=num_nodes, out_features=1)
        self.seasonality_model = FourierLayer(pred_len=0, k=3)
        self.trend_model = series_decomp_multi(kernel_size=[4, 8, 12])

        # One multi-scale Transformer expert per patch size.
        self.experts = nn.ModuleList()
        self.MLPs = nn.ModuleList()
        for patch in patch_size:
            patch_nums = int(input_size / patch)
            self.experts.append(Transformer_Layer(device=device, d_model=d_model, d_ff=d_ff,
                                dynamic=dynamic, num_nodes=num_nodes, patch_nums=patch_nums,
                                patch_size=patch, factorized=True, layer_number=layer_number, batch_norm=batch_norm))

        # self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
        # self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
        # Learnable gate / noise projections (linear layers instead of the
        # commented-out raw parameter matrices above).
        self.w_noise = nn.Linear(input_size, num_experts)
        self.w_gate = nn.Linear(input_size, num_experts)

        self.residual_connection = residual_connection
        self.end_MLP = MLP(input_size=input_size, output_size=output_size)

        self.noisy_gating = noisy_gating
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        # Standard-normal buffers backing the CDF used in _prob_in_top_k.
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        assert (self.k <= self.num_experts)

    def cv_squared(self, x):
        """Squared coefficient of variation of *x* (var / mean^2).

        Used as the load-balance penalty.  Returns 0 for a single element,
        where the variance is degenerate.
        """
        eps = 1e-10
        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """Load of each expert = number of examples with a nonzero gate."""
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Probability that each (example, expert) logit lands in the top-k
        after Gaussian noise — a differentiable load estimate in the style of
        Shazeer et al. (2017), "Outrageously Large Neural Networks".

        `noisy_top_values` holds the top (k+1) noisy logits per example.
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        # Threshold to beat if this expert is already in the top-k, and the
        # (one position earlier) threshold if it is not.
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def seasonality_and_trend_decompose(self, x):
        """Add extracted seasonality and trend components back onto the raw
        series; only the first feature channel is used for routing."""
        x = x[:, :, :, 0]
        _, trend = self.trend_model(x)
        seasonality, _ = self.seasonality_model(x)
        return x + seasonality + trend

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Compute sparse gates over experts via noisy top-k gating.

        Returns (gates, load): `gates` is (batch, num_experts) with at most
        k nonzeros per row; `load` estimates how many examples each expert
        receives (differentiable during noisy training).
        """
        x = self.start_linear(x).squeeze(-1)

        # clean_logits = x @ self.w_gate
        clean_logits = self.w_gate(x)
        if self.noisy_gating and train:
            # raw_noise_stddev = x @ self.w_noise
            raw_noise_stddev = self.w_noise(x)
            noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon))
            noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits
        # calculate topk + 1 that will be needed for the noisy gates
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)

        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = self.softmax(top_k_logits)

        # Scatter the k softmaxed gates back into a dense (batch, experts) map.
        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if self.noisy_gating and self.k < self.num_experts and train:
            load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(self, x, loss_coef=1e-2):
        """Route *x* through the top-k experts.

        Returns (output, balance_loss) where balance_loss penalizes uneven
        expert importance and load, scaled by `loss_coef`.
        """
        new_x = self.seasonality_and_trend_decompose(x)

        # multi-scale router
        gates, load = self.noisy_top_k_gating(new_x, self.training)
        # calculate balance loss
        importance = gates.sum(0)
        balance_loss = self.cv_squared(importance) + self.cv_squared(load)
        balance_loss *= loss_coef
        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)
        expert_outputs = [self.experts[i](expert_inputs[i])[0] for i in range(self.num_experts)]
        output = dispatcher.combine(expert_outputs)
        if self.residual_connection:
            output = output + x
        return output, balance_loss
class TokenEmbedding(nn.Module):
    """Embed each time step's channels into d_model dims with a circular Conv1d."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # torch >= 1.5 changed the padding required for 'circular' mode.
        kernel_padding = 2
        if torch.__version__ >= '1.5.0':
            kernel_padding = 1
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=kernel_padding,
            padding_mode='circular',
            bias=False,
        )
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        """x: (batch, seq_len, c_in) -> (batch, seq_len, d_model)."""
        return self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
class TimeFeatureEmbedding(nn.Module):
    """Linearly project continuous time features onto d_model dimensions."""

    # Number of time features produced for each sampling frequency code.
    _FREQ_TO_DIM = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()
        self.embed = nn.Linear(self._FREQ_TO_DIM[freq], d_model, bias=False)

    def forward(self, x):
        return self.embed(x)
def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False):
    """2-D coordinate positional encoding of shape (q_len, d_model).

    Iteratively tunes the exponent `x` (up to 100 steps) so the encoding's
    mean approaches zero within `eps`, then optionally standardizes it.
    """
    x = .5 if exponential else 1
    for i in range(100):
        cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) \
            * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1
        # BUG FIX: the original called an undefined helper `pv(...)`
        # unconditionally, raising NameError on every invocation; replaced
        # with a plain conditional print.
        if verbose:
            print(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}')
        if abs(cpe.mean()) <= eps:
            break
        elif cpe.mean() > eps:
            x += .001
        else:
            x -= .001
    if normalize:
        cpe = cpe - cpe.mean()
        cpe = cpe / (cpe.std() * 10)
    return cpe
def positional_encoding(pe, learn_pe, q_len, d_model):
    """Build a positional-encoding parameter.

    Args:
        pe: encoding type; one of None, 'zero', 'zeros', 'normal'/'gauss',
            'uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos'.
        learn_pe: whether the returned Parameter is trainable (forced off
            for pe=None, which measures the impact of having no encoding).
        q_len: sequence length (number of positions).
        d_model: embedding width (the 1-column variants broadcast over it).

    Returns:
        nn.Parameter of shape (q_len, d_model) or (q_len, 1).

    Raises:
        ValueError: for an unrecognized `pe` type.
    """
    if pe is None:  # BUG FIX: identity check, not `== None`
        W_pos = torch.empty((q_len, d_model))  # pe = None and learn_pe = False can be used to measure impact of pe
        nn.init.uniform_(W_pos, -0.02, 0.02)
        learn_pe = False
    elif pe == 'zero':
        W_pos = torch.empty((q_len, 1))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'zeros':
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
    elif pe == 'normal' or pe == 'gauss':
        W_pos = torch.zeros((q_len, 1))
        torch.nn.init.normal_(W_pos, mean=0.0, std=0.1)
    elif pe == 'uniform':
        W_pos = torch.zeros((q_len, 1))
        nn.init.uniform_(W_pos, a=0.0, b=0.1)
    elif pe == 'lin1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True)
    elif pe == 'exp1d':
        W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True)
    elif pe == 'lin2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True)
    elif pe == 'exp2d':
        W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True)
    elif pe == 'sincos':
        W_pos = PositionalEncoding(q_len, d_model, normalize=True)
    else:
        # BUG FIX: the original message had an unbalanced parenthesis and a
        # missing opening quote before uniform'.
        raise ValueError(
            f"{pe} is not a valid pe (positional encoder). Available types: "
            "'gauss'=='normal', 'zeros', 'zero', 'uniform', 'lin1d', 'exp1d', "
            "'lin2d', 'exp2d', 'sincos', None.")
    return nn.Parameter(W_pos, requires_grad=learn_pe)
int((context_window - cut_size) / self.stride + 1) 42 | 43 | self.inter_d_model = self.d_model * self.patch_size 44 | ##inter_embedding 45 | self.emb_linear = nn.Linear(self.inter_d_model, self.inter_d_model) 46 | # Positional encoding 47 | self.W_pos = positional_encoding(pe='zeros', learn_pe=True, q_len=self.patch_nums, d_model=self.inter_d_model) 48 | n_heads = self.d_model 49 | d_k = self.inter_d_model // n_heads 50 | d_v = self.inter_d_model // n_heads 51 | self.inter_patch_attention = Inter_Patch_Attention(self.inter_d_model, self.inter_d_model, n_heads, d_k, d_v, attn_dropout=0, 52 | proj_dropout=0.1, res_attention=False) 53 | 54 | 55 | ##Normalization 56 | self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(self.d_model), Transpose(1,2)) 57 | self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(self.d_model), Transpose(1,2)) 58 | 59 | ##FFN 60 | self.d_ff = d_ff 61 | self.dropout = nn.Dropout(0.1) 62 | self.ff = nn.Sequential(nn.Linear(self.d_model, self.d_ff, bias=True), 63 | nn.GELU(), 64 | nn.Dropout(0.2), 65 | nn.Linear(self.d_ff, self.d_model, bias=True)) 66 | 67 | def forward(self, x): 68 | 69 | new_x = x 70 | batch_size = x.size(0) 71 | intra_out_concat = None 72 | 73 | weights_shared, biases_shared = self.weights_generator_shared() 74 | weights_distinct, biases_distinct = self.weights_generator_distinct() 75 | 76 | ####intra Attention##### 77 | for i in range(self.patch_nums): 78 | t = x[:, i * self.patch_size:(i + 1) * self.patch_size, :, :] 79 | 80 | intra_emb = self.embeddings_generator[i](self.intra_embeddings[i]).expand(batch_size, -1, -1, -1) 81 | t = torch.cat([intra_emb, t], dim=1) 82 | out, attention = self.intra_patch_attention(intra_emb, t, t, weights_distinct, biases_distinct, weights_shared, 83 | biases_shared) 84 | 85 | if intra_out_concat == None: 86 | intra_out_concat = out 87 | 88 | else: 89 | intra_out_concat = torch.cat([intra_out_concat, out], dim=1) 90 | 91 | intra_out_concat = intra_out_concat.permute(0,3,2,1) 
class CustomLinear(nn.Module):
    """Linear transform using externally supplied (possibly factorized) weights."""

    def __init__(self, factorized):
        super(CustomLinear, self).__init__()
        # When True, `weights` carries an extra per-node axis and the input is
        # matched against it through an inserted singleton dimension.
        self.factorized = factorized

    def forward(self, input, weights, biases):
        if not self.factorized:
            return torch.matmul(input, weights) + biases
        projected = torch.matmul(input.unsqueeze(3), weights)
        return projected.squeeze(3) + biases
self).__init__() 140 | self.head = 2 141 | 142 | if d_model % self.head != 0: 143 | raise Exception('Hidden size is not divisible by the number of attention heads') 144 | 145 | self.head_size = int(d_model // self.head) 146 | self.custom_linear = CustomLinear(factorized) 147 | 148 | def forward(self, query, key, value, weights_distinct, biases_distinct, weights_shared, biases_shared): 149 | batch_size = query.shape[0] 150 | 151 | key = self.custom_linear(key, weights_distinct[0], biases_distinct[0]) 152 | value = self.custom_linear(value, weights_distinct[1], biases_distinct[1]) 153 | query = torch.cat(torch.split(query, self.head_size, dim=-1), dim=0) 154 | key = torch.cat(torch.split(key, self.head_size, dim=-1), dim=0) 155 | value = torch.cat(torch.split(value, self.head_size, dim=-1), dim=0) 156 | 157 | query = query.permute((0, 2, 1, 3)) 158 | key = key.permute((0, 2, 3, 1)) 159 | value = value.permute((0, 2, 1, 3)) 160 | 161 | 162 | 163 | attention = torch.matmul(query, key) 164 | attention /= (self.head_size ** 0.5) 165 | 166 | attention = torch.softmax(attention, dim=-1) 167 | 168 | x = torch.matmul(attention, value) 169 | x = x.permute((0, 2, 1, 3)) 170 | x = torch.cat(torch.split(x, batch_size, dim=0), dim=-1) 171 | 172 | if x.shape[0] == 0: 173 | x = x.repeat(1, 1, 1, int(weights_shared[0].shape[-1] / x.shape[-1])) 174 | 175 | x = self.custom_linear(x, weights_shared[0], biases_shared[0]) 176 | x = torch.relu(x) 177 | x = self.custom_linear(x, weights_shared[1], biases_shared[1]) 178 | return x, attention 179 | 180 | 181 | class Inter_Patch_Attention(nn.Module): 182 | def __init__(self, d_model, out_dim, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., 183 | proj_dropout=0., qkv_bias=True, lsa=False): 184 | super().__init__() 185 | d_k = d_model // n_heads if d_k is None else d_k 186 | d_v = d_model // n_heads if d_v is None else d_v 187 | 188 | self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v 189 | 190 | self.W_Q = 
class ScaledDotProductAttention(nn.Module):
    r"""Scaled dot-product attention (Vaswani et al., 2017).

    Optionally adds pre-softmax residual attention scores from the previous
    layer (Realformer, He et al. 2020) via `prev`, and supports locality
    self-attention (Lee et al. 2021) by making the scale learnable (`lsa`).
    """

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        # Trainable only under LSA; otherwise a fixed 1/sqrt(head_dim) scale.
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa

    def forward(self, q, k, v, prev=None, key_padding_mask=None, attn_mask=None):
        """q: [bs x n_heads x q_len x d_k], k: [bs x n_heads x d_k x k_len],
        v: [bs x n_heads x k_len x d_v]; returns (output, attn_weights)."""
        scores = torch.matmul(q, k) * self.scale

        # Residual pre-softmax scores from the previous layer (optional).
        if prev is not None:
            scores = scores + prev

        # Boolean or additive attention mask (optional).
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores.masked_fill_(attn_mask, -np.inf)
            else:
                scores += attn_mask

        # Mask out padded key positions (optional).
        if key_padding_mask is not None:
            scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # Normalize, apply dropout, then mix the values.
        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)
        return context, weights
num_nodes, factorized, number_of_weights=4): 274 | super(WeightGenerator, self).__init__() 275 | #print('FACTORIZED {}'.format(factorized)) 276 | self.number_of_weights = number_of_weights 277 | self.mem_dim = mem_dim 278 | self.num_nodes = num_nodes 279 | self.factorized = factorized 280 | self.out_dim = out_dim 281 | if self.factorized: 282 | self.memory = nn.Parameter(torch.randn(num_nodes, mem_dim), requires_grad=True).to('cpu') 283 | # self.memory = nn.Parameter(torch.randn(num_nodes, mem_dim), requires_grad=True).to('cuda:0') 284 | self.generator = self.generator = nn.Sequential(*[ 285 | nn.Linear(mem_dim, 64), 286 | nn.Tanh(), 287 | nn.Linear(64, 64), 288 | nn.Tanh(), 289 | nn.Linear(64, 100) 290 | ]) 291 | 292 | self.mem_dim = 10 293 | self.P = nn.ParameterList( 294 | [nn.Parameter(torch.Tensor(in_dim, self.mem_dim), requires_grad=True) for _ in 295 | range(number_of_weights)]) 296 | self.Q = nn.ParameterList( 297 | [nn.Parameter(torch.Tensor(self.mem_dim, out_dim), requires_grad=True) for _ in 298 | range(number_of_weights)]) 299 | self.B = nn.ParameterList( 300 | [nn.Parameter(torch.Tensor(self.mem_dim ** 2, out_dim), requires_grad=True) for _ in 301 | range(number_of_weights)]) 302 | else: 303 | self.P = nn.ParameterList( 304 | [nn.Parameter(torch.Tensor(in_dim, out_dim), requires_grad=True) for _ in range(number_of_weights)]) 305 | self.B = nn.ParameterList( 306 | [nn.Parameter(torch.Tensor(1, out_dim), requires_grad=True) for _ in range(number_of_weights)]) 307 | self.reset_parameters() 308 | 309 | def reset_parameters(self): 310 | list_params = [self.P, self.Q, self.B] if self.factorized else [self.P] 311 | for weight_list in list_params: 312 | for weight in weight_list: 313 | init.kaiming_uniform_(weight, a=math.sqrt(5)) 314 | 315 | if not self.factorized: 316 | for i in range(self.number_of_weights): 317 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.P[i]) 318 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 319 | init.uniform_(self.B[i], 
class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False):
        """Reversible Instance Normalization for time series.

        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, center on the last time step instead
            of the per-instance mean
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """mode='norm' records statistics and normalizes x;
        'denorm' inverts the most recent normalization."""
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over all dims except batch (first) and channel (last).
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        # BUG FIX: stdev is used by _normalize/_denormalize regardless of the
        # centering mode, but the original only computed it in the else
        # branch — crashing (or reusing stale statistics) when
        # subtract_last=True. Compute it unconditionally.
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps^2 guards against division by a zero affine weight.
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x
RevIN(num_features=configs.num_nodes, affine=False, subtract_last=False) 30 | 31 | self.start_fc = nn.Linear(in_features=1, out_features=self.d_model) 32 | self.AMS_lists = nn.ModuleList() 33 | self.device = torch.device('cuda:{}'.format(configs.gpu)) 34 | self.batch_norm = configs.batch_norm 35 | 36 | for num in range(self.layer_nums): 37 | self.AMS_lists.append( 38 | AMS(self.seq_len, self.seq_len, self.num_experts_list[num], self.device, k=self.k, 39 | num_nodes=self.num_nodes, patch_size=self.patch_size_list[num], noisy_gating=True, 40 | d_model=self.d_model, d_ff=self.d_ff, layer_number=num + 1, residual_connection=self.residual_connection, batch_norm=self.batch_norm)) 41 | self.projections = nn.Sequential( 42 | nn.Linear(self.seq_len * self.d_model, self.pre_len) 43 | ) 44 | 45 | def forward(self, x): 46 | 47 | balance_loss = 0 48 | # norm 49 | if self.revin: 50 | x = self.revin_layer(x, 'norm') 51 | out = self.start_fc(x.unsqueeze(-1)) 52 | 53 | 54 | batch_size = x.shape[0] 55 | 56 | for layer in self.AMS_lists: 57 | out, aux_loss = layer(out) 58 | balance_loss += aux_loss 59 | 60 | out = out.permute(0,2,1,3).reshape(batch_size, self.num_nodes, -1) 61 | out = self.projections(out).transpose(2, 1) 62 | 63 | # denorm 64 | if self.revin: 65 | out = self.revin_layer(out, 'denorm') 66 | 67 | return out, balance_loss 68 | 69 | 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.7.0 2 | numpy==1.24.4 3 | pandas==2.0.3 4 | Pillow==10.1.0 5 | scikit-learn==1.3.2 6 | scipy==1.10.1 7 | torch==1.10.1+cu111 8 | torchaudio==0.10.1+cu111 9 | torchvision==0.11.2+cu111 10 | tornado==6.4 11 | tqdm==4.66.1 12 | 13 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 
import random
from exp.exp_main import Exp_Main
import argparse
import time

# Fix all RNG seeds for reproducibility.
fix_seed = 1024
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Multivariate Time Series Forecasting')

    # basic config
    parser.add_argument('--is_training', type=int, default=1, help='status')
    parser.add_argument('--model', type=str, default='PathFormer',
                        help='model name, options: [PathFormer]')
    parser.add_argument('--model_id', type=str, default="ETT.sh")

    # data loader
    parser.add_argument('--data', type=str, default='custom', help='dataset type')
    parser.add_argument('--root_path', type=str, default='./dataset/weather', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='weather.csv', help='data file')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S]; M:multivariate predict multivariate, S:univariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
    parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
    parser.add_argument('--individual', action='store_true', default=False,
                        help='DLinear: a linear layer for each variate(channel) individually')

    # model
    parser.add_argument('--d_model', type=int, default=16)
    parser.add_argument('--d_ff', type=int, default=64)
    parser.add_argument('--num_nodes', type=int, default=21)
    parser.add_argument('--layer_nums', type=int, default=3)
    parser.add_argument('--k', type=int, default=2, help='choose the Top K patch size at the every layer ')
    # Bug fix: was `type=list`, which maps a CLI string onto a list of its characters
    # (e.g. "4,4" -> ['4', ',', '4']); nargs='+' with type=int parses real integers and
    # keeps the same default, consistent with --patch_size_list below.
    parser.add_argument('--num_experts_list', nargs='+', type=int, default=[4, 4, 4])
    parser.add_argument('--patch_size_list', nargs='+', type=int, default=[16, 12, 8, 32, 12, 8, 6, 4, 8, 6, 4, 2])
    parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
    parser.add_argument('--revin', type=int, default=1, help='whether to apply RevIN')
    parser.add_argument('--drop', type=float, default=0.1, help='dropout ratio')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--residual_connection', type=int, default=0)
    parser.add_argument('--metric', type=str, default='mae')
    parser.add_argument('--batch_norm', type=int, default=0)

    # optimization
    parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
    parser.add_argument('--itr', type=int, default=1, help='experiments times')
    parser.add_argument('--train_epochs', type=int, default=20, help='train epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size of train input data')
    parser.add_argument('--patience', type=int, default=5, help='early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='optimizer learning rate')
    parser.add_argument('--lradj', type=str, default='TST', help='adjust learning rate')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
    parser.add_argument('--pct_start', type=float, default=0.4, help='pct_start')

    # GPU
70 | parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu') 71 | parser.add_argument('--gpu', type=int, default=0, help='gpu') 72 | parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False) 73 | parser.add_argument('--devices', type=str, default='2', help='device ids of multile gpus') 74 | parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage') 75 | 76 | args = parser.parse_args() 77 | args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False 78 | 79 | if args.use_gpu and args.use_multi_gpu: 80 | args.dvices = args.devices.replace(' ', '') 81 | device_ids = args.devices.split(',') 82 | args.device_ids = [int(id_) for id_ in device_ids] 83 | args.gpu = args.device_ids[0] 84 | 85 | args.patch_size_list = np.array(args.patch_size_list).reshape(args.layer_nums, -1).tolist() 86 | 87 | print('Args in experiment:') 88 | print(args) 89 | 90 | Exp = Exp_Main 91 | 92 | if args.is_training: 93 | for ii in range(args.itr): 94 | # setting record of experiments 95 | setting = '{}_{}_ft{}_sl{}_pl{}_{}'.format( 96 | args.model_id, 97 | args.model, 98 | args.data_path[:-4], 99 | args.features, 100 | args.seq_len, 101 | args.pred_len, ii) 102 | 103 | exp = Exp(args) # set experiments 104 | 105 | 106 | 107 | 108 | print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) 109 | exp.train(setting) 110 | 111 | time_now = time.time() 112 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 113 | exp.test(setting) 114 | print('Inference time: ', time.time() - time_now) 115 | 116 | if args.do_predict: 117 | print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 118 | exp.predict(setting, True) 119 | 120 | torch.cuda.empty_cache() 121 | else: 122 | ii = 0 123 | setting = '{}_{}_ft{}_sl{}_pl{}_{}'.format( 124 | args.model_id, 125 | args.model, 126 | args.data_path[:-4], 127 | args.features, 
128 | args.seq_len, 129 | args.pred_len, ii) 130 | 131 | exp = Exp(args) # set experiments 132 | print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) 133 | exp.test(setting, test=1) 134 | torch.cuda.empty_cache() 135 | -------------------------------------------------------------------------------- /scripts/multivariate/ETTh1.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "./logs" ]; then 2 | mkdir ./logs 3 | fi 4 | 5 | if [ ! -d "./logs/LongForecasting" ]; then 6 | mkdir ./logs/LongForecasting 7 | fi 8 | seq_len=96 9 | model_name=PathFormer 10 | 11 | root_path_name=./dataset/ETT/ 12 | data_path_name=ETTh1.csv 13 | model_id_name=ETTh1 14 | data_name=ETTh1 15 | 16 | for pred_len in 96 17 | do 18 | python -u run.py \ 19 | --is_training 1 \ 20 | --root_path $root_path_name \ 21 | --data_path $data_path_name \ 22 | --model_id $model_id_name_$seq_len'_'$pred_len \ 23 | --model $model_name \ 24 | --data $data_name \ 25 | --features M \ 26 | --seq_len $seq_len \ 27 | --pred_len $pred_len \ 28 | --patch_size_list 16 12 8 32 12 8 6 4 8 6 4 2 \ 29 | --num_nodes 7 \ 30 | --layer_nums 3 \ 31 | --batch_norm 0 \ 32 | --residual_connection 1\ 33 | --k 3\ 34 | --d_model 4 \ 35 | --d_ff 64 \ 36 | --train_epochs 30\ 37 | --patience 10\ 38 | --lradj 'TST'\ 39 | --itr 1 \ 40 | --batch_size 128 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log 41 | done 42 | 43 | 44 | for pred_len in 192 45 | do 46 | python -u run.py \ 47 | --is_training 1 \ 48 | --root_path $root_path_name \ 49 | --data_path $data_path_name \ 50 | --model_id $model_id_name_$seq_len'_'$pred_len \ 51 | --model $model_name \ 52 | --data $data_name \ 53 | --features M \ 54 | --seq_len $seq_len \ 55 | --pred_len $pred_len \ 56 | --patch_size_list 16 12 8 32 12 8 6 4 8 6 4 2 \ 57 | --num_nodes 7 \ 58 | --layer_nums 3 \ 59 | --batch_norm 0 \ 60 | --residual_connection 1\ 61 | --k 3\ 62 | --d_model 4 
\ 63 | --d_ff 64 \ 64 | --train_epochs 30\ 65 | --patience 10\ 66 | --lradj 'TST'\ 67 | --itr 1 \ 68 | --batch_size 128 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log 69 | done 70 | 71 | for pred_len in 336 72 | do 73 | python -u run.py \ 74 | --is_training 1 \ 75 | --root_path $root_path_name \ 76 | --data_path $data_path_name \ 77 | --model_id $model_id_name_$seq_len'_'$pred_len \ 78 | --model $model_name \ 79 | --data $data_name \ 80 | --features M \ 81 | --seq_len $seq_len \ 82 | --pred_len $pred_len \ 83 | --patch_size_list 16 12 8 32 12 8 6 16 8 6 4 16 \ 84 | --num_nodes 7 \ 85 | --layer_nums 3 \ 86 | --batch_norm 0 \ 87 | --residual_connection 0\ 88 | --k 3\ 89 | --d_model 4 \ 90 | --d_ff 64 \ 91 | --train_epochs 30\ 92 | --patience 10\ 93 | --lradj 'TST'\ 94 | --itr 1 \ 95 | --batch_size 512 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log 96 | done 97 | 98 | 99 | for pred_len in 720 100 | do 101 | python -u run.py \ 102 | --is_training 1 \ 103 | --root_path $root_path_name \ 104 | --data_path $data_path_name \ 105 | --model_id $model_id_name_$seq_len'_'$pred_len \ 106 | --model $model_name \ 107 | --data $data_name \ 108 | --features M \ 109 | --seq_len $seq_len \ 110 | --pred_len $pred_len \ 111 | --patch_size_list 16 12 8 32 12 8 6 4 8 6 4 2 \ 112 | --num_nodes 7 \ 113 | --layer_nums 3 \ 114 | --batch_norm 0 \ 115 | --residual_connection 0\ 116 | --k 2\ 117 | --d_model 4 \ 118 | --d_ff 64 \ 119 | --train_epochs 30\ 120 | --patience 10\ 121 | --lradj 'TST'\ 122 | --itr 1 \ 123 | --batch_size 512 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log 124 | done 125 | 126 | -------------------------------------------------------------------------------- /scripts/multivariate/ETTh2.sh: -------------------------------------------------------------------------------- 1 | if [ ! 
-d "./logs" ]; then
    mkdir ./logs
fi

if [ ! -d "./logs/LongForecasting" ]; then
    mkdir ./logs/LongForecasting
fi
seq_len=96
model_name=PathFormer

root_path_name=./dataset/ETT/
data_path_name=ETTh2.csv
model_id_name=ETTh2
data_name=ETTh2


for pred_len in 96
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --patch_size_list 16 12 8 32 12 8 6 4 8 6 4 2 \
      --num_nodes 7 \
      --layer_nums 3 \
      --batch_norm 0 \
      --residual_connection 0 \
      --k 2 \
      --d_model 4 \
      --train_epochs 30 \
      --patience 10 \
      --lradj 'TST' \
      --itr 1 \
      --batch_size 512 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done



for pred_len in 192
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --patch_size_list 16 12 8 32 12 8 6 4 8 6 4 2 \
      --num_nodes 7 \
      --layer_nums 3 \
      --batch_norm 0 \
      --residual_connection 0 \
      --k 2 \
      --d_model 8 \
      --train_epochs 30 \
      --patience 10 \
      --lradj 'TST' \
      --itr 1 \
      --batch_size 512 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done


for pred_len in 336
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --patch_size_list 16 12 8 32 12 8 6 4 8 6 4 2 \
      --num_nodes 7 \
      --layer_nums 3 \
      --batch_norm 0 \
      --residual_connection 0 \
      --k 2 \
      --d_model 4 \
      --train_epochs 30 \
      --patience 10 \
      --lradj 'TST' \
      --itr 1 \
      --batch_size 512 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done


for pred_len in 720
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --patch_size_list 16 12 8 32 12 8 6 32 8 6 16 12 \
      --num_nodes 7 \
      --layer_nums 3 \
      --batch_norm 0 \
      --residual_connection 0 \
      --k 3 \
      --d_model 16 \
      --train_epochs 30 \
      --patience 10 \
      --lradj 'TST' \
      --itr 1 \
      --batch_size 512 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/multivariate/ETTm1.sh:
--------------------------------------------------------------------------------
if [ ! -d "./logs" ]; then
    mkdir ./logs
fi

if [ !
-d "./logs/LongForecasting/" ]; then
    mkdir ./logs/LongForecasting/
fi
seq_len=96
model_name=PathFormer

root_path_name=./dataset/ETT/
data_path_name=ETTm1.csv
model_id_name=ETTm1
data_name=ETTm1


for pred_len in 96 192 336 720
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --patch_size_list 16 12 8 4 12 8 6 4 8 6 2 12 \
      --batch_norm 0 \
      --residual_connection 1 \
      --num_nodes 7 \
      --layer_nums 3 \
      --k 3 \
      --d_model 8 \
      --d_ff 64 \
      --train_epochs 30 \
      --patience 10 \
      --lradj 'TST' \
      --itr 1 \
      --batch_size 512 --learning_rate 0.0005 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/multivariate/ETTm2.sh:
--------------------------------------------------------------------------------
if [ ! -d "./logs" ]; then
    mkdir ./logs
fi

if [ ! -d "./logs/LongForecasting" ]; then
    mkdir ./logs/LongForecasting
fi
seq_len=96
model_name=PathFormer

root_path_name=./dataset/ETT/
data_path_name=ETTm2.csv
model_id_name=ETTm2
data_name=ETTm2


for pred_len in 96 192 336 720
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --patch_size_list 16 12 8 32 12 8 6 32 8 6 16 12 \
      --num_nodes 7 \
      --layer_nums 3 \
      --k 2 \
      --d_model 16 \
      --d_ff 64 \
      --train_epochs 30 \
      --patience 10 \
      --lradj 'TST' \
      --itr 1 \
      --batch_size 512 --learning_rate 0.001 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/multivariate/electricity.sh:
--------------------------------------------------------------------------------
if [ ! -d "./logs" ]; then
    mkdir ./logs
fi

if [ ! -d "./logs/LongForecasting" ]; then
    mkdir ./logs/LongForecasting
fi
seq_len=96
model_name=PathFormer

root_path_name=./dataset/electricity
data_path_name=electricity.csv
model_id_name=electricity
data_name=custom

for pred_len in 96 192 336 720
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --num_nodes 321 \
      --layer_nums 3 \
      --residual_connection 1 \
      --k 2 \
      --d_model 16 \
      --d_ff 128 \
      --patch_size_list 16 12 8 32 12 8 6 4 8 6 4 2 \
      --train_epochs 50 \
      --patience 10 \
      --lradj 'TST' \
      --pct_start 0.2 \
      --itr 1 \
      --batch_size 16 --learning_rate 0.001 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/multivariate/traffic.sh:
--------------------------------------------------------------------------------
if [ ! -d "./logs" ]; then
    mkdir ./logs
fi

if [ ! -d "./logs/LongForecasting" ]; then
    mkdir ./logs/LongForecasting
fi
seq_len=96
model_name=PathFormer

root_path_name=./dataset/
data_path_name=traffic.csv
model_id_name=traffic
data_name=custom

for pred_len in 96 192 336 720
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --patch_size_list 16 12 8 32 12 8 6 32 8 6 16 12 \
      --num_nodes 862 \
      --layer_nums 3 \
      --k 2 \
      --d_model 16 \
      --d_ff 128 \
      --train_epochs 50 \
      --residual_connection 1 \
      --patience 10 \
      --lradj 'TST' \
      --pct_start 0.2 \
      --itr 1 \
      --batch_size 24 --learning_rate 0.0002 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/multivariate/weather.sh:
--------------------------------------------------------------------------------
if [ ! -d "./logs" ]; then
    mkdir ./logs
fi

if [ ! -d "./logs/LongForecasting" ]; then
    mkdir ./logs/LongForecasting
fi
seq_len=96
model_name=PathFormer

root_path_name=./dataset/weather
data_path_name=weather.csv
model_id_name=weather
data_name=custom


for pred_len in 96 192 336 720
do
    python -u run.py \
      --is_training 1 \
      --root_path $root_path_name \
      --data_path $data_path_name \
      --model_id ${model_id_name}_$seq_len'_'$pred_len \
      --model $model_name \
      --data $data_name \
      --features M \
      --seq_len $seq_len \
      --pred_len $pred_len \
      --num_nodes 21 \
      --layer_nums 3 \
      --patch_size_list 16 12 8 4 12 8 6 4 8 6 2 12 \
      --residual_connection 1 \
      --k 2 \
      --d_model 8 \
      --d_ff 64 \
      --train_epochs 30 \
      --patience 10 \
      --lradj 'TST' \
      --itr 1 \
      --batch_size 256 --learning_rate 0.001 >logs/LongForecasting/$model_name'_'$model_id_name'_'$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/utils/Other.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math

import torch.fft as fft
from einops import rearrange, reduce, repeat


class SparseDispatcher(object):
    """Routes a batch to the experts selected by a sparse gating matrix."""

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher."""
        self._gates = gates
        self._num_experts = num_experts

        # sort experts
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # get according batch index for each expert
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        self._part_sizes = (gates > 0).sum(0).tolist()
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def
dispatch(self, inp): 27 | # assigns samples to experts whose gate is nonzero 28 | # expand according to batch index so we can just split by _part_sizes 29 | inp_exp = inp[self._batch_index].squeeze(1) 30 | return torch.split(inp_exp, self._part_sizes, dim=0) 31 | 32 | def combine(self, expert_out, multiply_by_gates=True): 33 | # apply exp to expert outputs, so we are not longer in log space 34 | stitched = torch.cat(expert_out, 0).exp() 35 | if multiply_by_gates: 36 | stitched = torch.einsum("ijkh,ik -> ijkh", stitched, self._nonzero_gates) 37 | zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), expert_out[-1].size(2), expert_out[-1].size(3), 38 | requires_grad=True, device=stitched.device) 39 | # combine samples that have been processed by the same k experts 40 | combined = zeros.index_add(0, self._batch_index, stitched.float()) 41 | # add eps to all zero values in order to avoid nans when going back to log space 42 | combined[combined == 0] = np.finfo(float).eps 43 | # back to log space 44 | return combined.log() 45 | def expert_to_gates(self): 46 | # split nonzero gates for each expert 47 | return torch.split(self._nonzero_gates, self._part_sizes, dim=0) 48 | 49 | 50 | class MLP(nn.Module): 51 | def __init__(self, input_size, output_size): 52 | super(MLP, self).__init__() 53 | self.fc = nn.Conv2d(in_channels=input_size, 54 | out_channels=output_size, 55 | kernel_size=(1, 1), 56 | bias=True) 57 | 58 | def forward(self, x): 59 | out = self.fc(x) 60 | return out 61 | 62 | 63 | 64 | class moving_avg(nn.Module): 65 | """ 66 | Moving average block to highlight the trend of time series 67 | """ 68 | 69 | def __init__(self, kernel_size, stride): 70 | super(moving_avg, self).__init__() 71 | self.kernel_size = kernel_size 72 | self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) 73 | 74 | def forward(self, x): 75 | # padding on the both ends of time series 76 | front = x[:, 0:1, :].repeat(1, self.kernel_size - 1 - 
math.floor((self.kernel_size - 1) // 2), 1) 77 | end = x[:, -1:, :].repeat(1, math.floor((self.kernel_size - 1) // 2), 1) 78 | x = torch.cat([front, x, end], dim=1) 79 | x = self.avg(x.permute(0, 2, 1)) 80 | x = x.permute(0, 2, 1) 81 | return x 82 | 83 | 84 | class series_decomp(nn.Module): 85 | """ 86 | Series decomposition block 87 | """ 88 | 89 | def __init__(self, kernel_size): 90 | super(series_decomp, self).__init__() 91 | self.moving_avg = moving_avg(kernel_size, stride=1) 92 | 93 | def forward(self, x): 94 | moving_mean = self.moving_avg(x) 95 | res = x - moving_mean 96 | return res, moving_mean 97 | 98 | 99 | class series_decomp_multi(nn.Module): 100 | """ 101 | Series decomposition block 102 | """ 103 | 104 | def __init__(self, kernel_size): 105 | super(series_decomp_multi, self).__init__() 106 | self.moving_avg = [moving_avg(kernel, stride=1) for kernel in kernel_size] 107 | self.layer = torch.nn.Linear(1, len(kernel_size)) 108 | 109 | def forward(self, x): 110 | moving_mean = [] 111 | for func in self.moving_avg: 112 | moving_avg = func(x) 113 | moving_mean.append(moving_avg.unsqueeze(-1)) 114 | moving_mean = torch.cat(moving_mean, dim=-1) 115 | moving_mean = torch.sum(moving_mean * nn.Softmax(-1)(self.layer(x.unsqueeze(-1))), dim=-1) 116 | res = x - moving_mean 117 | return res, moving_mean 118 | 119 | 120 | class FourierLayer(nn.Module): 121 | 122 | def __init__(self, pred_len, k=None, low_freq=1, output_attention=False): 123 | super().__init__() 124 | # self.d_model = d_model 125 | self.pred_len = pred_len 126 | self.k = k 127 | self.low_freq = low_freq 128 | self.output_attention = output_attention 129 | 130 | def forward(self, x): 131 | """x: (b, t, d)""" 132 | 133 | if self.output_attention: 134 | return self.dft_forward(x) 135 | 136 | b, t, d = x.shape 137 | x_freq = fft.rfft(x, dim=1) 138 | 139 | if t % 2 == 0: 140 | x_freq = x_freq[:, self.low_freq:-1] 141 | f = fft.rfftfreq(t)[self.low_freq:-1] 142 | else: 143 | x_freq = x_freq[:, 
self.low_freq:] 144 | f = fft.rfftfreq(t)[self.low_freq:] 145 | 146 | x_freq, index_tuple = self.topk_freq(x_freq) 147 | f = repeat(f, 'f -> b f d', b=x_freq.size(0), d=x_freq.size(2)) 148 | f = f.to(x_freq.device) 149 | f = rearrange(f[index_tuple], 'b f d -> b f () d').to(x_freq.device) 150 | 151 | return self.extrapolate(x_freq, f, t), None 152 | 153 | def extrapolate(self, x_freq, f, t): 154 | x_freq = torch.cat([x_freq, x_freq.conj()], dim=1) 155 | f = torch.cat([f, -f], dim=1) 156 | t_val = rearrange(torch.arange(t + self.pred_len, dtype=torch.float), 157 | 't -> () () t ()').to(x_freq.device) 158 | 159 | amp = rearrange(x_freq.abs() / t, 'b f d -> b f () d') 160 | phase = rearrange(x_freq.angle(), 'b f d -> b f () d') 161 | 162 | x_time = amp * torch.cos(2 * math.pi * f * t_val + phase) 163 | 164 | return reduce(x_time, 'b f t d -> b t d', 'sum') 165 | 166 | def topk_freq(self, x_freq): 167 | values, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) 168 | mesh_a, mesh_b = torch.meshgrid(torch.arange(x_freq.size(0)), torch.arange(x_freq.size(2))) 169 | index_tuple = (mesh_a.unsqueeze(1), indices, mesh_b.unsqueeze(1)) 170 | x_freq = x_freq[index_tuple] 171 | 172 | return x_freq, index_tuple 173 | 174 | def dft_forward(self, x): 175 | T = x.size(1) 176 | 177 | dft_mat = fft.fft(torch.eye(T)) 178 | i, j = torch.meshgrid(torch.arange(self.pred_len + T), torch.arange(T)) 179 | omega = np.exp(2 * math.pi * 1j / T) 180 | idft_mat = (np.power(omega, i * j) / T).cfloat() 181 | 182 | x_freq = torch.einsum('ft,btd->bfd', [dft_mat, x.cfloat()]) 183 | 184 | if T % 2 == 0: 185 | x_freq = x_freq[:, self.low_freq:T // 2] 186 | else: 187 | x_freq = x_freq[:, self.low_freq:T // 2 + 1] 188 | 189 | _, indices = torch.topk(x_freq.abs(), self.k, dim=1, largest=True, sorted=True) 190 | indices = indices + self.low_freq 191 | indices = torch.cat([indices, -indices], dim=1) 192 | 193 | dft_mat = repeat(dft_mat, 'f t -> b f t d', b=x.shape[0], d=x.shape[-1]) 
def svd_denoise(x, cut):
    """Denoise a batch of matrices by truncating their SVD spectra.

    Keeps only the ``cut`` largest singular values of each matrix and
    reconstructs from the truncated factorization.

    Args:
        x: tensor of shape (batch, m, n). A leading batch dimension is
           required because the spectrum is zeroed per batch row.
        cut: number of leading singular values to keep.

    Returns:
        Tensor with the same shape as ``x``; each batch entry has rank
        at most ``cut``.
    """
    x_ = x.clone().detach()
    # torch.linalg.svd returns V already transposed (i.e. Vh), so the
    # reconstruction below is U @ diag(S) @ Vh.
    U, S, V = torch.linalg.svd(x_, full_matrices=False)
    S[:, cut:] = 0

    # Bug fix: the original used torch.diag(S[0, :]), which applied the
    # FIRST batch entry's singular values to every matrix in the batch.
    # diag_embed builds one diagonal matrix per batch row (identical to
    # the old behavior when batch size is 1).
    return U @ torch.diag_embed(S) @ V


@contextmanager
def null_context():
    """A no-op context manager (gradient-enabled stand-in for torch.no_grad)."""
    yield


def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def default(val, d):
    """Return ``val`` if it is not None, otherwise the fallback ``d``."""
    return val if exists(val) else d


class NMF(nn.Module):
    """Non-negative matrix factorization via multiplicative updates.

    Learns a dictionary D (dim x r) and codes C (r x n) such that
    x ~= D @ C, refining both with K multiplicative-update steps per
    forward pass. Gradients flow only through the final step
    ('One-step Gradient', per the paper referenced in the original code).
    """

    def __init__(self, dim, n, ratio=8, K=6, eps=2e-8):
        super().__init__()
        r = dim // ratio

        # Non-negative random init keeps the multiplicative updates valid.
        D = torch.zeros(dim, r).uniform_(0, 1)
        C = torch.zeros(r, n).uniform_(0, 1)

        self.K = K
        self.D = nn.Parameter(D)
        self.C = nn.Parameter(C)

        # Small constant guarding the divisions in the update rules.
        self.eps = eps

    def forward(self, x):
        b, D, C, eps = x.shape[0], self.D, self.C, self.eps

        # x is made non-negative with relu as proposed in the paper.
        x = F.relu(x)

        # Broadcast the shared factors across the batch (replaces the
        # einops `repeat` calls with native, behavior-identical torch ops).
        D = D.unsqueeze(0).expand(b, -1, -1)
        C = C.unsqueeze(0).expand(b, -1, -1)

        # Batched transpose of the trailing two dimensions.
        t = lambda tensor: tensor.transpose(-2, -1)

        for k in reversed(range(self.K)):
            # Only calculate gradients on the last step, per 'One-step Gradient'.
            context = null_context if k == 0 else torch.no_grad
            with context():
                C_new = C * ((t(D) @ x) / ((t(D) @ D @ C) + eps))
                D_new = D * ((x @ t(C)) / ((D @ C @ t(C)) + eps))
                C, D = C_new, D_new

        return D @ C


class TriangularCausalMask():
    """Strictly-upper-triangular boolean attention mask of shape
    (B, 1, L, L); True marks future positions to be masked out."""

    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask


class ProbMask():
    """Mask for ProbSparse attention: gathers rows of a causal mask at the
    sampled query positions ``index`` and reshapes to ``scores``' shape."""

    def __init__(self, B, H, L, index, scores, device="cpu"):
        # Strict upper triangle over the full score width.
        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
        # Select, per (batch, head), only the rows of the sampled queries.
        indicator = _mask_ex[torch.arange(B)[:, None, None],
                             torch.arange(H)[None, :, None],
                             index, :].to(device)
        self._mask = indicator.view(scores.shape).to(device)

    @property
    def mask(self):
        return self._mask


def RSE(pred, true):
    """Root relative squared error (normalized by the target's variance)."""
    return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))


def CORR(pred, true):
    """Mean per-series correlation between pred and true.
    NOTE(review): the 0.01 scaling looks intentional but undocumented
    upstream — confirm before comparing against other codebases."""
    u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
    d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
    d += 1e-12  # avoid division by zero for constant series
    return 0.01 * (u / d).mean(-1)


def MAE(pred, true):
    """Mean absolute error."""
    return np.mean(np.abs(pred - true))
def MSE(pred, true):
    """Mean squared error."""
    return np.mean((pred - true) ** 2)


def RMSE(pred, true):
    """Root mean squared error."""
    return np.sqrt(MSE(pred, true))


def MAPE(pred, true):
    """Mean absolute percentage error.

    NOTE(review): divides by ``true`` directly, so zeros in the target
    yield inf/nan — callers are expected to pass non-zero targets.
    """
    return np.mean(np.abs((pred - true) / true))


def MSPE(pred, true):
    """Mean squared percentage error (same zero-target caveat as MAPE)."""
    return np.mean(np.square((pred - true) / true))


def metric(pred, true):
    """Compute the full evaluation suite.

    Returns:
        Tuple (mae, mse, rmse, mape, mspe, rse, corr).
    """
    mae = MAE(pred, true)
    mse = MSE(pred, true)
    rmse = RMSE(pred, true)
    mape = MAPE(pred, true)
    mspe = MSPE(pred, true)
    rse = RSE(pred, true)
    corr = CORR(pred, true)

    return mae, mse, rmse, mape, mspe, rse, corr


class TimeFeature:
    """Base class for calendar features mapping a DatetimeIndex to a
    numeric encoding in [-0.5, 0.5]."""

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        pass

    def __repr__(self):
        return self.__class__.__name__ + "()"


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""
    # Fixed docstring: the original said "Minute of hour" (copy-paste error).

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.second / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.minute / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""
    # Fixed docstring: the original said "Hour of day" (copy-paste error).

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.dayofweek / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.day - 1) / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.dayofyear - 1) / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.month - 1) / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.isocalendar().week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.
    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
    """

    # Coarser frequencies get fewer features: e.g. a yearly series has no
    # meaningful sub-year calendar signal, so YearEnd maps to [].
    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)


def time_features(dates, freq='h'):
    """Stack all frequency-appropriate calendar features for ``dates``
    into an array of shape (num_features, len(dates))."""
    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
elif args.lradj == 'type3': 14 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 3))} 15 | elif args.lradj == 'type4': 16 | lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 4))} 17 | elif args.lradj == 'type5': 18 | lr_adjust = { 19 | 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 20 | 10: 5e-7, 15: 1e-7, 20: 5e-8 21 | } 22 | elif args.lradj == 'constant': 23 | lr_adjust = {epoch: args.learning_rate} 24 | elif args.lradj == '3': 25 | lr_adjust = {epoch: args.learning_rate if epoch < 10 else args.learning_rate*0.1} 26 | elif args.lradj == '4': 27 | lr_adjust = {epoch: args.learning_rate if epoch < 15 else args.learning_rate*0.1} 28 | elif args.lradj == '5': 29 | lr_adjust = {epoch: args.learning_rate if epoch < 25 else args.learning_rate*0.1} 30 | elif args.lradj == '6': 31 | lr_adjust = {epoch: args.learning_rate if epoch < 5 else args.learning_rate*0.1} 32 | elif args.lradj == 'TST': 33 | lr_adjust = {epoch: scheduler.get_last_lr()[0]} 34 | if epoch in lr_adjust.keys(): 35 | lr = lr_adjust[epoch] 36 | for param_group in optimizer.param_groups: 37 | param_group['lr'] = lr 38 | if printout: print('Updating learning rate to {}'.format(lr)) 39 | 40 | class EarlyStopping: 41 | def __init__(self, patience=7, verbose=False, delta=0): 42 | self.patience = patience 43 | self.verbose = verbose 44 | self.counter = 0 45 | self.best_score = None 46 | self.early_stop = False 47 | self.val_loss_min = np.Inf 48 | self.delta = delta 49 | 50 | def __call__(self, val_loss, model, path): 51 | score = -val_loss 52 | if self.best_score is None: 53 | self.best_score = score 54 | self.save_checkpoint(val_loss, model, path) 55 | elif score < self.best_score + self.delta: 56 | self.counter += 1 57 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 58 | if self.counter >= self.patience: 59 | self.early_stop = True 60 | else: 61 | self.best_score = score 62 | self.save_checkpoint(val_loss, model, path) 63 | self.counter = 0 64 | 65 | def 
save_checkpoint(self, val_loss, model, path): 66 | if self.verbose: 67 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 68 | torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') 69 | self.val_loss_min = val_loss 70 | 71 | 72 | class dotdict(dict): 73 | """dot.notation access to dictionary attributes""" 74 | __getattr__ = dict.get 75 | __setattr__ = dict.__setitem__ 76 | __delattr__ = dict.__delitem__ 77 | 78 | 79 | class StandardScaler(): 80 | def __init__(self, mean, std): 81 | self.mean = mean 82 | self.std = std 83 | 84 | def transform(self, data): 85 | return (data - self.mean) / self.std 86 | 87 | def inverse_transform(self, data): 88 | return (data * self.std) + self.mean 89 | 90 | 91 | def visual(true, preds=None, name='./pic/test.pdf'): 92 | """ 93 | Results visualization 94 | """ 95 | plt.style.use('ggplot') 96 | plt.figure() 97 | plt.plot(true, label='GroundTruth', linewidth=2) 98 | if preds is not None: 99 | plt.plot(preds, label='Prediction', linewidth=2) 100 | plt.legend(loc="upper right") 101 | plt.savefig(name,bbox_inches='tight') 102 | 103 | def test_params_flop(model,x_shape): 104 | """ 105 | If you want to thest former's flop, you need to give default value to inputs in model.forward(), the following code can only pass one argument to forward() 106 | """ 107 | model_params = 0 108 | for parameter in model.parameters(): 109 | model_params += parameter.numel() 110 | print('INFO: Trainable parameter count: {:.2f}M'.format(model_params / 1000000.0)) 111 | from ptflops import get_model_complexity_info 112 | with torch.cuda.device(0): 113 | macs, params = get_model_complexity_info(model.cuda(), x_shape, as_strings=True, print_per_layer_stat=True) 114 | # print('Flops:' + flops) 115 | # print('Params:' + params) 116 | print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 117 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 118 | 119 | 120 | 
--------------------------------------------------------------------------------