├── README.md
├── data_provider
├── __pycache__
│ ├── data_factory.cpython-38.pyc
│ ├── data_factory.cpython-39.pyc
│ ├── data_loader.cpython-38.pyc
│ └── data_loader.cpython-39.pyc
├── data_factory.py
└── data_loader.py
├── exp
├── __pycache__
│ ├── exp_basic.cpython-38.pyc
│ ├── exp_basic.cpython-39.pyc
│ ├── exp_main.cpython-38.pyc
│ └── exp_main.cpython-39.pyc
├── exp_basic.py
├── exp_main.py
└── exp_stat.py
├── layers
├── AutoCorrelation.py
├── Autoformer_EncDec.py
├── Embed.py
├── MSGBlock.py
├── SelfAttention_Family.py
├── Transformer_EncDec.py
└── __pycache__
│ ├── AutoCorrelation.cpython-38.pyc
│ ├── AutoCorrelation.cpython-39.pyc
│ ├── Autoformer_EncDec.cpython-38.pyc
│ ├── Autoformer_EncDec.cpython-39.pyc
│ ├── Embed.cpython-38.pyc
│ ├── Embed.cpython-39.pyc
│ ├── MSGBlock.cpython-39.pyc
│ ├── SelfAttention_Family.cpython-38.pyc
│ ├── SelfAttention_Family.cpython-39.pyc
│ ├── Transformer_EncDec.cpython-38.pyc
│ └── Transformer_EncDec.cpython-39.pyc
├── models
├── Autoformer.py
├── DLinear.py
├── Informer.py
├── MSGNet.py
└── __pycache__
│ ├── Autoformer.cpython-39.pyc
│ ├── DLinear.cpython-39.pyc
│ ├── Informer.cpython-39.pyc
│ └── MSGNet.cpython-39.pyc
├── pic
├── main_result.jpg
├── model1.jpg
└── model2.jpg
├── run_longExp.py
├── scripts
├── ETTh1.sh
├── ETTh2.sh
├── ETTm1.sh
├── ETTm2.sh
├── Flight.sh
├── electricity.sh
├── exchange.sh
└── weather.sh
└── utils
├── __pycache__
├── masking.cpython-38.pyc
├── masking.cpython-39.pyc
├── metrics.cpython-38.pyc
├── metrics.cpython-39.pyc
├── timefeatures.cpython-38.pyc
├── timefeatures.cpython-39.pyc
├── tools.cpython-38.pyc
└── tools.cpython-39.pyc
├── masking.py
├── metrics.py
├── timefeatures.py
└── tools.py
/README.md:
--------------------------------------------------------------------------------
1 | # MSGNet (AAAI2024)
2 |
3 | Paper Link: [MSGNet: Learning Multi-Scale Inter-Series Correlations for Multivariate Time Series Forecasting](https://arxiv.org/abs/2401.00423)
4 |
5 | ## Usage
6 |
7 | - Train and evaluate MSGNet
8 | - You can use the following command:`sh ./scripts/ETTh1.sh`.
9 |
10 | - Train your model
11 | - Add model file in the folder `./models/your_model.py`.
12 | - Add model in the ***class*** Exp_Main.
13 |
14 | - Flight dataset
15 | - You can obtain the dataset from [Google Drive](https://drive.google.com/drive/folders/1JSZByfM0Ghat3g_D3a-puTZ2JsfebNWL?usp=sharing). Then please place it in the folder `./dataset`.
16 |
17 | ## Model
18 |
19 | MSGNet employs several ScaleGraph blocks, each encompassing three pivotal modules: an FFT module for multi-scale data identification, an adaptive graph convolution module for inter-series correlation learning within a time scale, and a multi-head attention module for intra-series correlation learning.
20 |
21 |
22 |

23 |
24 |
25 | ## Main Results
26 |
27 | Forecast results with a look-back window of 96 and prediction lengths {96, 192, 336, 720}. The best result is shown in bold and the second-best is underlined.
28 |
29 |
30 |

31 |
32 |
33 | ## Citation
34 |
35 | ```
36 | @article{cai2023msgnet,
37 | title={MSGNet: Learning Multi-Scale Inter-Series Correlations for Multivariate Time Series Forecasting},
38 | author={Cai, Wanlin and Liang, Yuxuan and Liu, Xianggen and Feng, Jianshuai and Wu, Yuankai},
39 | journal={arXiv preprint arXiv:2401.00423},
40 | year={2023}
41 | }
42 | ```
43 |
44 | ## Acknowledgement
45 |
46 | We appreciate the valuable contributions of the following GitHub repositories.
47 |
48 | - LTSF-Linear (https://github.com/cure-lab/LTSF-Linear)
49 | - TimesNet (https://github.com/thuml/TimesNet)
50 | - Time-Series-Library (https://github.com/thuml/Time-Series-Library)
51 | - MTGnn (https://github.com/nnzhan/MTGNN)
52 | - Autoformer (https://github.com/thuml/Autoformer)
53 | - Informer (https://github.com/zhouhaoyi/Informer2020)
54 |
--------------------------------------------------------------------------------
/data_provider/__pycache__/data_factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/data_provider/__pycache__/data_factory.cpython-38.pyc
--------------------------------------------------------------------------------
/data_provider/__pycache__/data_factory.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/data_provider/__pycache__/data_factory.cpython-39.pyc
--------------------------------------------------------------------------------
/data_provider/__pycache__/data_loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/data_provider/__pycache__/data_loader.cpython-38.pyc
--------------------------------------------------------------------------------
/data_provider/__pycache__/data_loader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/data_provider/__pycache__/data_loader.cpython-39.pyc
--------------------------------------------------------------------------------
/data_provider/data_factory.py:
--------------------------------------------------------------------------------
1 | from .data_loader import Dataset_ETT_hour, Dataset_ETT_minute, \
2 | Dataset_Custom, Dataset_Pred,Dataset_Flight
3 | from torch.utils.data import DataLoader
4 |
# Maps the --data CLI argument to the Dataset class that loads it.
data_dict = {
    'ETTh1': Dataset_ETT_hour,
    'ETTh2': Dataset_ETT_hour,
    'ETTm1': Dataset_ETT_minute,
    'ETTm2': Dataset_ETT_minute,
    'custom': Dataset_Custom,
    'Flight':Dataset_Flight,
}
13 |
14 |
15 | # flag = 'train' or 'val' or 'test'
def data_provider(args, flag):
    """Build the dataset and DataLoader for one split.

    flag selects the split: 'train', 'val', 'test', or 'pred'. The 'pred'
    split always uses the Dataset_Pred class with batch size 1.
    """
    Data = data_dict[args.data]
    # Time-features encoding: 1 for 'timeF', 0 for 'fixed'/'learned'.
    timeenc = 1 if args.embed == 'timeF' else 0
    freq = args.freq

    if flag == 'pred':
        # One sample at a time, keep every window, never shuffle.
        Data = Dataset_Pred
        shuffle_flag, drop_last, batch_size = False, False, 1
    elif flag == 'test':
        shuffle_flag, drop_last, batch_size = False, True, args.batch_size
    else:
        # 'train' / 'val'
        shuffle_flag, drop_last, batch_size = True, True, args.batch_size

    data_set = Data(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=[args.seq_len, args.label_len, args.pred_len],
        features=args.features,
        target=args.target,
        timeenc=timeenc,
        freq=freq,
        seasonal_patterns=args.seasonal_patterns,
    )
    print(flag, len(data_set))
    data_loader = DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=args.num_workers,
        drop_last=drop_last,
    )
    return data_set, data_loader
59 |
--------------------------------------------------------------------------------
/data_provider/data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import os
5 | import torch
6 | from torch.utils.data import Dataset, DataLoader
7 | from sklearn.preprocessing import StandardScaler
8 | from utils.timefeatures import time_features
9 | import warnings
10 |
11 | warnings.filterwarnings('ignore')
12 |
13 | #for Flight 4:4:2 split
class Dataset_Flight(Dataset):
    """Flight dataset loader using a 40/40/20 train/val/test split.

    Serves sliding windows (seq_x, seq_y, seq_x_mark, seq_y_mark) for
    sequence-to-sequence forecasting. Values are standardized with
    statistics fitted on the training slice only.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='Flight.csv',
                 target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
        # size: [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, split 4:4:2, scale on train stats, build time marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Reorder columns to ['date', ...other features..., target].
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # Val/test borders start seq_len early so their first window has
        # complete history.
        num_train = int(len(df_raw) * 0.4)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit the scaler on the training slice only to avoid leakage.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # Vectorized .dt accessors replace the per-row .apply(..., 1)
            # calls, and drop() uses the columns= keyword: passing axis
            # positionally was removed in pandas 2.0.
            df_stamp['month'] = df_stamp['date'].dt.month
            df_stamp['day'] = df_stamp['date'].dt.day
            df_stamp['weekday'] = df_stamp['date'].dt.weekday
            df_stamp['hour'] = df_stamp['date'].dt.hour
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return one (input, target, input marks, target marks) window."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete input+prediction windows in this split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map standardized values back to the original data scale."""
        return self.scaler.inverse_transform(data)
111 |
class Dataset_Custom(Dataset):
    """Generic CSV dataset loader using a 70/10/20 train/val/test split.

    Serves sliding windows (seq_x, seq_y, seq_x_mark, seq_y_mark) for
    sequence-to-sequence forecasting. Values are standardized with
    statistics fitted on the training slice only.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
        # size: [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, split 7:1:2, scale on train stats, build time marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Reorder columns to ['date', ...other features..., target].
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # Val/test borders start seq_len early so their first window has
        # complete history.
        num_train = int(len(df_raw) * 0.7)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit the scaler on the training slice only to avoid leakage.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # Vectorized .dt accessors replace the per-row .apply(..., 1)
            # calls, and drop() uses the columns= keyword: passing axis
            # positionally was removed in pandas 2.0.
            df_stamp['month'] = df_stamp['date'].dt.month
            df_stamp['day'] = df_stamp['date'].dt.day
            df_stamp['weekday'] = df_stamp['date'].dt.weekday
            df_stamp['hour'] = df_stamp['date'].dt.hour
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return one (input, target, input marks, target marks) window."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete input+prediction windows in this split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map standardized values back to the original data scale."""
        return self.scaler.inverse_transform(data)
211 |
class Dataset_Pred(Dataset):
    """Inference-time dataset: serves the trailing seq_len rows of the file.

    Time marks are extended pred_len steps beyond the last known date so the
    decoder has covariates for the forecast horizon.
    """

    def __init__(self, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, inverse=False, timeenc=0, freq='15min', seasonal_patterns=None, cols=None):
        # size: [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['pred']

        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.freq = freq
        self.cols = cols
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, scale, and build marks covering input + horizon."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))
        # Reorder columns to ['date', ...other features..., target].
        if self.cols:
            cols = self.cols.copy()
            cols.remove(self.target)
        else:
            cols = list(df_raw.columns)
            cols.remove(self.target)
            cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # Only the trailing seq_len rows serve as encoder input.
        border1 = len(df_raw) - self.seq_len
        border2 = len(df_raw)

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        if self.scale:
            # No held-out future exists at prediction time, so fit on all rows.
            self.scaler.fit(df_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        tmp_stamp = df_raw[['date']][border1:border2]
        tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date)
        # pred_len future timestamps following the last known date.
        pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq)

        df_stamp = pd.DataFrame(columns=['date'])
        # Column assignment via [] (attribute assignment on a DataFrame is
        # fragile for setting columns).
        df_stamp['date'] = list(tmp_stamp.date.values) + list(pred_dates[1:])
        if self.timeenc == 0:
            # Vectorized .dt accessors replace the per-row .apply(..., 1)
            # calls, and drop() uses the columns= keyword: passing axis
            # positionally was removed in pandas 2.0.
            df_stamp['month'] = df_stamp['date'].dt.month
            df_stamp['day'] = df_stamp['date'].dt.day
            df_stamp['weekday'] = df_stamp['date'].dt.weekday
            df_stamp['hour'] = df_stamp['date'].dt.hour
            # Bucket minutes into 15-minute slots (0..3).
            df_stamp['minute'] = df_stamp['date'].dt.minute // 15
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        if self.inverse:
            # Keep targets in the original (unscaled) space when requested.
            self.data_y = df_data.values[border1:border2]
        else:
            self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return one (input, label-section, input marks, target marks) window."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        if self.inverse:
            seq_y = self.data_x[r_begin:r_begin + self.label_len]
        else:
            seq_y = self.data_y[r_begin:r_begin + self.label_len]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # With exactly seq_len rows loaded this yields a single sample.
        return len(self.data_x) - self.seq_len + 1

    def inverse_transform(self, data):
        """Map standardized values back to the original data scale."""
        return self.scaler.inverse_transform(data)
315 |
class Dataset_ETT_hour(Dataset):
    """ETT hourly dataset loader with the fixed 12/4/4-month split.

    Serves sliding windows (seq_x, seq_y, seq_x_mark, seq_y_mark) for
    sequence-to-sequence forecasting. Values are standardized with
    statistics fitted on the training slice only.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
        # size: [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        # features: M multivariate->multivariate, S univariate->univariate,
        # MS multivariate->univariate
        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, apply the fixed month borders, scale, build marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))
        # Fixed borders: 12 months train, 4 val, 4 test (in hours); val/test
        # start seq_len early so their first window has complete history.
        border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
        border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]
        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        if self.scale:
            # Fit the scaler on the training slice only to avoid leakage.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values
        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # Vectorized .dt accessors replace the per-row .apply(..., 1)
            # calls, and drop() uses the columns= keyword: passing axis
            # positionally was removed in pandas 2.0.
            df_stamp['month'] = df_stamp['date'].dt.month
            df_stamp['day'] = df_stamp['date'].dt.day
            df_stamp['weekday'] = df_stamp['date'].dt.weekday
            df_stamp['hour'] = df_stamp['date'].dt.hour
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return one (input, target, input marks, target marks) window."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete input+prediction windows in this split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map standardized values back to the original data scale."""
        return self.scaler.inverse_transform(data)
399 |
class Dataset_ETT_minute(Dataset):
    """ETT 15-minute dataset loader with the fixed 12/4/4-month split.

    Serves sliding windows (seq_x, seq_y, seq_x_mark, seq_y_mark) for
    sequence-to-sequence forecasting. Values are standardized with
    statistics fitted on the training slice only.
    """

    def __init__(self, root_path, flag='train', size=None,
                 features='S', data_path='ETTm1.csv',
                 target='OT', scale=True, timeenc=0, freq='t', seasonal_patterns=None):
        # size: [seq_len, label_len, pred_len]
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features
        self.target = target
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Read the CSV, apply the fixed month borders, scale, build marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Fixed borders: 12 months train, 4 val, 4 test (in 15-min steps);
        # val/test start seq_len early so their first window is complete.
        border1s = [0,
                    12 * 30 * 24 * 4 - self.seq_len,
                    12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
        border2s = [12 * 30 * 24 * 4,
                    12 * 30 * 24 * 4 + 4 * 30 * 24 * 4,
                    12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]
        if self.scale:
            # Fit the scaler on the training slice only to avoid leakage.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # Vectorized .dt accessors replace the per-row .apply(..., 1)
            # calls, and drop() uses the columns= keyword: passing axis
            # positionally was removed in pandas 2.0.
            df_stamp['month'] = df_stamp['date'].dt.month
            df_stamp['day'] = df_stamp['date'].dt.day
            df_stamp['weekday'] = df_stamp['date'].dt.weekday
            df_stamp['hour'] = df_stamp['date'].dt.hour
            # Bucket minutes into 15-minute slots (0..3).
            df_stamp['minute'] = df_stamp['date'].dt.minute // 15
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """Return one (input, target, input marks, target marks) window."""
        s_begin = index
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of complete input+prediction windows in this split.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map standardized values back to the original data scale."""
        return self.scaler.inverse_transform(data)
--------------------------------------------------------------------------------
/exp/__pycache__/exp_basic.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/exp/__pycache__/exp_basic.cpython-38.pyc
--------------------------------------------------------------------------------
/exp/__pycache__/exp_basic.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/exp/__pycache__/exp_basic.cpython-39.pyc
--------------------------------------------------------------------------------
/exp/__pycache__/exp_main.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/exp/__pycache__/exp_main.cpython-38.pyc
--------------------------------------------------------------------------------
/exp/__pycache__/exp_main.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/exp/__pycache__/exp_main.cpython-39.pyc
--------------------------------------------------------------------------------
/exp/exp_basic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 |
class Exp_Basic(object):
    """Base class for experiments: owns the args, the device, and the model.

    Subclasses must implement _build_model(); the remaining hooks
    (_get_data, vali, train, test) are no-op placeholders to override.
    """

    def __init__(self, args):
        self.args = args
        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        # Subclasses return an nn.Module here. (The unreachable
        # `return None` after the raise was dead code and is removed.)
        raise NotImplementedError

    def _acquire_device(self):
        """Select the torch device, pinning CUDA_VISIBLE_DEVICES when on GPU."""
        if self.args.use_gpu:
            os.environ["CUDA_VISIBLE_DEVICES"] = str(
                self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _get_data(self):
        pass

    def vali(self):
        pass

    def train(self):
        pass

    def test(self):
        pass
38 |
--------------------------------------------------------------------------------
/exp/exp_main.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_factory import data_provider
2 | from .exp_basic import Exp_Basic
3 | from models import Informer, Autoformer, DLinear, MSGNet
4 | from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
5 | from utils.metrics import metric
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import optim, autograd
10 |
11 | import os
12 | import time
13 |
14 | import warnings
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 |
18 | warnings.filterwarnings('ignore')
19 |
20 | class Exp_Main(Exp_Basic):
21 | def __init__(self, args):
22 | super(Exp_Main, self).__init__(args)
23 |
24 | def _build_model(self):
25 | model_dict = {
26 | 'Informer': Informer,
27 | 'Autoformer': Autoformer,
28 | 'DLinear': DLinear,
29 | 'MSGNet': MSGNet
30 | }
31 | model = model_dict[self.args.model].Model(self.args).float()
32 |
33 | if self.args.use_multi_gpu and self.args.use_gpu:
34 | model = nn.DataParallel(model, device_ids=self.args.device_ids)
35 | return model
36 |
37 | #flag = 'train' or 'val' or 'test'
38 | def _get_data(self, flag):
39 | data_set, data_loader = data_provider(self.args, flag)
40 | return data_set, data_loader
41 |
42 | def _select_optimizer(self):
43 | model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
44 | return model_optim
45 |
46 | def _select_criterion(self):
47 | criterion = nn.MSELoss()
48 | return criterion
49 |
50 |
51 | def vali(self, vali_data, vali_loader, criterion):
52 | total_loss = []
53 | self.model.eval()
54 | with torch.no_grad():
55 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
56 | batch_x = batch_x.float().to(self.device)
57 | batch_y = batch_y.float()
58 | batch_x_mark = batch_x_mark.float().to(self.device)
59 | batch_y_mark = batch_y_mark.float().to(self.device)
60 |
61 | # decoder input
62 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
63 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
64 | # encoder - decoder
65 | if self.args.use_amp:
66 | with torch.cuda.amp.autocast():
67 | if 'Linear' in self.args.model:
68 | outputs = self.model(batch_x)
69 | else:
70 | if self.args.output_attention:
71 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
72 | else:
73 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
74 | else:
75 | if 'Linear' in self.args.model:
76 | outputs = self.model(batch_x)
77 | else:
78 | if self.args.output_attention:
79 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
80 | else:
81 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
82 | f_dim = -1 if self.args.features == 'MS' else 0
83 | outputs = outputs[:, -self.args.pred_len:, f_dim:]
84 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
85 |
86 | pred = outputs.detach().cpu()
87 | true = batch_y.detach().cpu()
88 |
89 | loss = criterion(pred, true)
90 |
91 | total_loss.append(loss)
92 | total_loss = np.average(total_loss)
93 | self.model.train()
94 | return total_loss
95 |
96 | def train(self, setting):
97 | train_data, train_loader = self._get_data(flag='train')
98 | vali_data, vali_loader = self._get_data(flag='val')
99 | test_data, test_loader = self._get_data(flag='test')
100 |
101 | path = os.path.join(self.args.checkpoints, setting)
102 | if not os.path.exists(path):
103 | os.makedirs(path)
104 |
105 | time_now = time.time()
106 | train_steps = len(train_loader)
107 | early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
108 |
109 | model_optim = self._select_optimizer()
110 | criterion = self._select_criterion()
111 | #use automatic mixed precision training
112 | if self.args.use_amp:
113 | scaler = torch.cuda.amp.GradScaler()
114 | for epoch in range(self.args.train_epochs):
115 | iter_count = 0
116 | train_loss = []
117 |
118 | self.model.train()
119 | epoch_time = time.time()
120 | for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
121 |
122 | iter_count += 1
123 | model_optim.zero_grad()
124 |
125 | batch_x = batch_x.float().to(self.device)
126 | batch_y = batch_y.float().to(self.device)
127 | batch_x_mark = batch_x_mark.float().to(self.device)
128 | batch_y_mark = batch_y_mark.float().to(self.device)
129 |
130 | dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
131 | dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
132 |
133 | # encoder - decoder
134 | if self.args.use_amp:
135 | with torch.cuda.amp.autocast():
136 | if 'Linear' in self.args.model:
137 | outputs = self.model(batch_x)
138 | else:
139 | if self.args.output_attention:
140 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
141 | else:
142 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
143 |
144 | f_dim = -1 if self.args.features == 'MS' else 0
145 | outputs = outputs[:, -self.args.pred_len:, f_dim:]
146 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
147 | loss = criterion(outputs, batch_y)
148 | train_loss.append(loss.item())
149 | else:
150 | if 'Linear' in self.args.model:
151 | # print("Linear")
152 | outputs = self.model(batch_x)
153 | else:
154 | if self.args.output_attention: #whether to output attention in ecoder
155 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
156 |
157 | else:
158 | outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
159 | # print(outputs.shape,batch_y.shape)
160 | f_dim = -1 if self.args.features == 'MS' else 0
161 | outputs = outputs[:, -self.args.pred_len:, f_dim:]
162 | batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
163 | loss = criterion(outputs, batch_y)
164 | train_loss.append(loss.item())
165 |
166 | if (i + 1) % 100 == 0:
167 | print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
168 | speed = (time.time() - time_now) / iter_count
169 | left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
170 | print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
171 | iter_count = 0
172 | time_now = time.time()
173 |
174 | if self.args.use_amp:
175 | scaler.scale(loss).backward()
176 | scaler.step(model_optim)
177 | scaler.update()
178 | else:
179 | with autograd.detect_anomaly():
180 | loss.backward()
181 | model_optim.step()
182 |
183 | print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
184 | train_loss = np.average(train_loss)
185 | vali_loss = self.vali(vali_data, vali_loader, criterion)
186 | test_loss = self.vali(test_data, test_loader, criterion)
187 |
188 | print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
189 | epoch + 1, train_steps, train_loss, vali_loss, test_loss))
190 | early_stopping(vali_loss, self.model, path)
191 | if early_stopping.early_stop:
192 | print("Early stopping")
193 | break
194 |
195 | adjust_learning_rate(model_optim, epoch + 1, self.args)
196 |
197 | best_model_path = path + '/' + 'checkpoint.pth'
198 | self.model.load_state_dict(torch.load(best_model_path))
199 |
200 | return self.model
201 |
    def test(self, setting, test=0):
        """Evaluate the trained model on the test split.

        Args:
            setting: experiment identifier (checkpoint / results directory name).
            test: when truthy, first load the checkpoint from ./checkpoints/<setting>.

        Side effects:
            Saves qualitative forecast plots under ./test_results/<setting>/,
            appends metric lines to result.txt, and prints summary metrics.
        """
        test_data, test_loader = self._get_data(flag='test')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        preds = []
        trues = []
        inputx = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: label_len known steps followed by zeros for the horizon
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder ('Linear' models consume only the raw series)
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if 'Linear' in self.args.model:
                            outputs = self.model(batch_x)
                        else:
                            if self.args.output_attention:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                            else:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    if 'Linear' in self.args.model:
                        outputs = self.model(batch_x)
                    else:
                        if self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                # 'MS' task: keep only the last (target) variate for the metrics
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                outputs = outputs.detach().cpu().numpy()
                batch_y = batch_y.detach().cpu().numpy()

                pred = outputs # outputs.detach().cpu().numpy() # .squeeze()
                true = batch_y # batch_y.detach().cpu().numpy() # .squeeze()

                preds.append(pred)
                trues.append(true)
                inputx.append(batch_x.detach().cpu().numpy())
                # every 10th batch: plot history + ground truth vs history + prediction
                if i % 10 == 0:
                    input = batch_x.detach().cpu().numpy()
                    gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0)
                    pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))
                # optional FLOPs profiling; see utils/tools for usage (exits the process)
                if self.args.test_flop:
                    test_params_flop((batch_x.shape[1],batch_x.shape[2]))
                    exit()
        # print('preds_shape:', len(preds),len(preds[0]),len(preds[1]))

        preds = np.array(preds)
        trues = np.array(trues)
        inputx = np.array(inputx)

        print('preds_shape:', preds.shape)
        print('trues_shape:', trues.shape)

        # collapse the batch dimension: (num_batches, B, pred_len, C) -> (N, pred_len, C)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        # metric returns 9 values here (including nd/nrmse)
        mae, mse, rmse, mape, mspe, rse, corr, nd, nrmse = metric(preds, trues)
        print('nd:{}, nrmse:{}, mse:{}, mae:{}, rse:{}, mape:{}'.format(nd, nrmse,mse, mae, rse, mape))
        f = open("result.txt", 'a')
        f.write(setting + " \n")
        f.write('nd:{}, nrmse:{}, mse:{}, mae:{}, rse:{}, mape:{}'.format(nd, nrmse,mse, mae, rse, mape))
        f.write('\n')
        f.write('\n')
        f.close()

        # np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe,rse, corr]))
        # np.save(folder_path + 'pred.npy', preds)
        # np.save(folder_path + 'true.npy', trues)
        # np.save(folder_path + 'x.npy', inputx)
        return
300 |
301 |
    def predict(self, setting, load=False):
        """Produce forecasts on the 'pred' split and save them to disk.

        Args:
            setting: experiment identifier (checkpoint / results directory name).
            load: when True, load the checkpoint saved under args.checkpoints.

        Side effects:
            Writes ./results/<setting>/real_prediction.npy.
        """
        pred_data, pred_loader = self._get_data(flag='pred')

        if load:
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            self.model.load_state_dict(torch.load(best_model_path))

        preds = []

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: label_len known steps followed by zeros for the horizon
                dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float().to(batch_y.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder ('Linear' models consume only the raw series)
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        if 'Linear' in self.args.model:
                            outputs = self.model(batch_x)
                        else:
                            if self.args.output_attention:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                            else:
                                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    if 'Linear' in self.args.model:
                        outputs = self.model(batch_x)
                    else:
                        if self.args.output_attention:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                        else:
                            outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                pred = outputs.detach().cpu().numpy() # .squeeze()
                preds.append(pred)

        preds = np.array(preds)
        # collapse the batch dimension: (num_batches, B, pred_len, C) -> (N, pred_len, C)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        np.save(folder_path + 'real_prediction.npy', preds)

        return
355 |
--------------------------------------------------------------------------------
/exp/exp_stat.py:
--------------------------------------------------------------------------------
1 | from data_provider.data_factory import data_provider
2 | from exp.exp_basic import Exp_Basic
3 | from utils.tools import EarlyStopping, adjust_learning_rate, visual
4 | from utils.metrics import metric
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch import optim
10 |
11 | import os
12 | import time
13 | import warnings
14 | import matplotlib.pyplot as plt
15 | from models.Stat_models import *
16 |
17 | warnings.filterwarnings('ignore')
18 |
19 |
class Exp_Main(Exp_Basic):
    """Experiment wrapper for the statistical baselines (Naive/ARIMA/SARIMA/GBRT).

    Unlike the neural experiments, these baselines need no gradient training,
    so only ``test`` is implemented.
    """

    def __init__(self, args):
        super(Exp_Main, self).__init__(args)

    def _build_model(self):
        # Map the model name from args to its statistical baseline class.
        model_dict = {
            'Naive': Naive_repeat,
            'ARIMA': Arima,
            'SARIMA': SArima,
            'GBRT': GBRT,
        }
        model = model_dict[self.args.model](self.args).float()

        return model

    def _get_data(self, flag):
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def test(self, setting, test=0):
        """Evaluate a statistical baseline on a sub-sample of the test set.

        Only ``args.sample`` of each batch is evaluated (classical models such
        as ARIMA are too slow for the full split). Saves plots, metrics, and
        prediction arrays under the test/result folders.
        """
        test_data, test_loader = self._get_data(flag='test')

        # Sample a fraction of every batch (at least one series).
        samples = max(int(self.args.sample * self.args.batch_size), 1)

        preds = []
        trues = []
        inputx = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
                # The stat models operate on numpy arrays, not tensors.
                batch_x = batch_x.float().to(self.device).cpu().numpy()
                batch_y = batch_y.float().to(self.device).cpu().numpy()

                batch_x = batch_x[:samples]
                outputs = self.model(batch_x)

                # 'MS' task: keep only the last (target) variate.
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:samples, -self.args.pred_len:, f_dim:]

                pred = outputs
                true = batch_y

                preds.append(pred)
                trues.append(true)
                inputx.append(batch_x)
                # every 20th batch: plot history + truth vs history + prediction
                if i % 20 == 0:
                    input = batch_x
                    gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0)
                    pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))

        preds = np.array(preds)
        trues = np.array(trues)
        inputx = np.array(inputx)

        # collapse the batch dimension: (num_batches, B, pred_len, C) -> (N, pred_len, C)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])

        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        # BUGFIX: utils.metrics.metric returns 9 values in this repo (see the
        # neural exp_main), so unpacking 7 raised "too many values to unpack".
        mae, mse, rmse, mape, mspe, rse, corr, nd, nrmse = metric(preds, trues)
        corr = []  # kept from the original: corr is deliberately suppressed in the logs
        print('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
        with open("result.txt", 'a') as f:
            f.write(setting + " \n")
            f.write('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
            f.write('\n')
            f.write('\n')

        # BUGFIX: embedding the empty-list corr in a float array raises a
        # ValueError on modern numpy (ragged array); save only the scalar metrics.
        np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe, rse]))
        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)
        # np.save(folder_path + 'x.npy', inputx)
        return
--------------------------------------------------------------------------------
/layers/AutoCorrelation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import math
7 | from math import sqrt
8 | import os
9 |
10 |
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    Shapes: forward takes queries/keys/values as (batch, length, heads, channels);
    the aggregation helpers receive (batch, heads, channels, length).
    """
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor  # top-k delays: k = int(factor * log(length))
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k delays, shared across the whole batch (mean over batch)
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # normalize the selected correlations into aggregation weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: roll the series by each delay and blend by its weight
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init — allocate on the input tensor's device instead of probing
        # global CUDA availability (robust in CPU-only and mixed-device setups)
        init_index = torch.arange(length, device=values.device) \
            .unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        # find top k (per-sample delays here, unlike the training variant)
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights = torch.topk(mean_value, top_k, dim=-1)[0]
        delay = torch.topk(mean_value, top_k, dim=-1)[1]
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: doubling the series emulates a circular shift via gather
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init — BUGFIX: was an unconditional `.cuda()`, which crashed on
        # CPU-only machines; use the input's device, consistent with
        # time_delay_agg_inference above
        init_index = torch.arange(length, device=values.device) \
            .unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1)
        # find top k
        top_k = int(self.factor * math.log(length))
        weights = torch.topk(corr, top_k, dim=-1)[0]
        delay = torch.topk(corr, top_k, dim=-1)[1]
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # pad keys/values with zeros up to the query length
            # NOTE(review): zeros are shaped from `queries`, so this assumes E == D
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: autocorrelation computed in the frequency
        # domain (Wiener-Khinchin): irfft(rfft(q) * conj(rfft(k)))
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)
135 |
136 |
class AutoCorrelationLayer(nn.Module):
    """Projects q/k/v per head, applies the wrapped correlation mechanism,
    and projects the merged heads back to the model dimension."""

    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        # default per-head dims split the model dimension evenly
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # split the projections into heads: (B, L, H, d_head)
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        out, attn = self.inner_correlation(q, k, v, attn_mask)

        # merge heads back and project to d_model
        merged = out.view(batch, q_len, -1)
        return self.out_projection(merged), attn
170 |
--------------------------------------------------------------------------------
/layers/Autoformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part: a standard LayerNorm
    whose output is re-centred along the time axis (per-channel temporal
    mean removed).
    """
    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # subtract the per-channel mean over time (broadcasts along dim 1)
        return normed - normed.mean(dim=1, keepdim=True)
18 |
19 |
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.
    Input/output shape: (batch, time, channels); with stride 1 the time
    length is preserved by replicate-padding both ends.
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # replicate the first/last time step on each end of the series
        pad = (self.kernel_size - 1) // 2
        head = x[:, :1, :].repeat(1, pad, 1)
        tail = x[:, -1:, :].repeat(1, pad, 1)
        padded = torch.cat([head, x, tail], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back
        return self.avg(padded.permute(0, 2, 1)).permute(0, 2, 1)
37 |
38 |
class series_decomp(nn.Module):
    """
    Series decomposition block: splits a series into a residual (seasonal)
    part and a moving-average (trend) part, returned in that order.
    """
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend
51 |
52 |
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture:
    self-attention and a 1x1-conv feed-forward, each followed by a series
    decomposition that keeps only the seasonal component.
    """
    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        # attention residual, then strip the trend
        attn_out, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x, _ = self.decomp1(x + self.dropout(attn_out))
        # position-wise feed-forward via two 1x1 convolutions over time
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        res, _ = self.decomp2(x + hidden)
        return res, attn
80 |
81 |
class Encoder(nn.Module):
    """
    Autoformer encoder: a stack of attention layers, optionally interleaved
    with convolutional (down-sampling) layers, plus an optional final norm.
    """
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is None:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            # each conv layer sits between two attention layers; the final
            # attention layer runs without a following conv (and without the
            # mask, matching the reference implementation)
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
110 |
111 |
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture:
    self-attention, cross-attention and a conv feed-forward, each followed by
    a decomposition. The trend removed at every step is accumulated and
    projected from d_model to c_out channels.
    """
    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        # kernel-3 circular conv maps the accumulated trend to the output channels
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular', bias=False)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        # self-attention residual + decomposition
        x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0])
        x, trend1 = self.decomp1(x)
        # cross-attention over the encoder output + decomposition
        x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0])
        x, trend2 = self.decomp2(x)
        # position-wise feed-forward via two 1x1 convolutions over time
        ff = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))
        x, trend3 = self.decomp3(x + ff)

        # fold the three extracted trends into the output channel space
        residual_trend = self.projection((trend1 + trend2 + trend3).permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
151 |
152 |
class Decoder(nn.Module):
    """
    Autoformer decoder: stacks decoder layers, accumulating the residual
    trend each layer emits into the `trend` accumulator.
    """
    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for blk in self.layers:
            x, delta = blk(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = trend + delta

        x = x if self.norm is None else self.norm(x)
        x = x if self.projection is None else self.projection(x)
        return x, trend
174 |
--------------------------------------------------------------------------------
/layers/Embed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm
5 | import math
6 |
7 |
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encodings (Vaswani et al.), precomputed
    once for max_len positions and sliced to the input's sequence length."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once, in log space.
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        pe = torch.zeros(max_len, d_model).float()
        # NOTE: 'require_grad' is a typo of requires_grad in the reference;
        # harmless since buffers are never trained — kept for parity.
        pe.require_grad = False
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # return encodings for the first seq_len positions: (1, seq_len, d_model)
        return self.pe[:, :x.size(1)]
25 |
26 |
class TokenEmbedding(nn.Module):
    """Embeds each time step by a kernel-3 circular Conv1d over the channel
    dimension: (B, L, c_in) -> (B, L, d_model)."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # torch < 1.5.0 needs padding=2 to preserve length with circular padding
        padding = 1 if torch.__version__ >= '1.5.0' else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        # kaiming init tailored for the leaky-relu family
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # conv expects (B, C, L); move time back to dim 1 afterwards
        return self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
40 |
41 |
class FixedEmbedding(nn.Module):
    """An embedding lookup with frozen sinusoidal weights — a fixed table,
    never trained, whose outputs carry no gradient."""

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        table = torch.zeros(c_in, d_model).float()
        # NOTE: typo'd no-op attribute kept from the reference implementation
        table.require_grad = False
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(table, requires_grad=False)

    def forward(self, x):
        # detach so the output never participates in autograd
        return self.emb(x).detach()
60 |
61 |
class TemporalEmbedding(nn.Module):
    """Sum of discrete calendar-field embeddings (month/day/weekday/hour,
    plus minute for minutely data). 'fixed' uses frozen sinusoidal tables,
    anything else uses learned nn.Embedding."""

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()

        # field cardinalities (day=32 / month=13 leave room for 1-based codes)
        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            # minute embedding only exists for minute-level sampling
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        # x columns: [month, day, weekday, hour, (minute)] as integer codes
        x = x.long()
        total = self.hour_embed(x[:, :, 3]) + self.weekday_embed(x[:, :, 2]) \
                + self.day_embed(x[:, :, 1]) + self.month_embed(x[:, :, 0])
        if hasattr(self, 'minute_embed'):
            total = total + self.minute_embed(x[:, :, 4])
        return total
90 |
class TimeFeatureEmbedding(nn.Module):
    """Linear projection of continuous time features onto the model dimension."""

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()
        # number of encoded time features produced for each sampling frequency
        freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        self.embed = nn.Linear(freq_map[freq], d_model, bias=False)

    def forward(self, x):
        return self.embed(x)
100 |
101 |
class DataEmbedding(nn.Module):
    """Sum of token (value), temporal and positional embeddings with dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()
        # value: (B, L, c_in) -> (B, L, d_model) via circular 1d conv
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        # temporal: calendar embeddings, or a linear map for 'timeF' features
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        # positional: fixed sinusoidal table, shape (1, L, d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # x_mark is None when no time-stamp features are available
        if x_mark is None:
            x = self.value_embedding(x) + self.position_embedding(x)
        else:
            x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
        return self.dropout(x)
121 |
122 |
class DataEmbedding_wo_pos(nn.Module):
    """Value + temporal embedding (no positional term), followed by dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_pos, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        # constructed (although unused in forward) so state dicts stay compatible
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        return self.dropout(self.value_embedding(x) + self.temporal_embedding(x_mark))
137 |
class DataEmbedding_wo_pos_temp(nn.Module):
    """Value embedding only; positional/temporal modules exist but are unused."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_pos_temp, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        # constructed (although unused in forward) so state dicts stay compatible
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # x_mark is accepted for interface parity but ignored
        return self.dropout(self.value_embedding(x))
152 |
class DataEmbedding_wo_temp(nn.Module):
    """Value + positional embedding (time-stamp features ignored)."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding_wo_temp, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        # constructed (although unused in forward) so state dicts stay compatible
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # x_mark is accepted for interface parity but ignored
        return self.dropout(self.value_embedding(x) + self.position_embedding(x))
--------------------------------------------------------------------------------
/layers/MSGBlock.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 | from torch import nn, Tensor
7 | from einops import rearrange
8 | from einops.layers.torch import Rearrange
9 | from utils.masking import TriangularCausalMask
10 |
class Predict(nn.Module):
    """Map a length-`seq_len` series to a length-`pred_len` forecast.

    When `individual` is True a separate Linear/Dropout pair is kept for
    each of the `c_out` channels; otherwise one shared pair handles all
    channels. Input and output layout: (batch, c_out, time).
    """

    def __init__(self, individual, c_out, seq_len, pred_len, dropout):
        super(Predict, self).__init__()
        self.individual = individual
        self.c_out = c_out

        if individual:
            self.seq2pred = nn.ModuleList(
                nn.Linear(seq_len, pred_len) for _ in range(c_out))
            self.dropout = nn.ModuleList(
                nn.Dropout(dropout) for _ in range(c_out))
        else:
            self.seq2pred = nn.Linear(seq_len, pred_len)
            self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, c_out, seq_len)
        if not self.individual:
            return self.dropout(self.seq2pred(x))
        per_channel = [
            drop(lin(x[:, idx, :]))
            for idx, (lin, drop) in enumerate(zip(self.seq2pred, self.dropout))
        ]
        return torch.stack(per_channel, dim=1)
41 |
42 |
class Attention_Block(nn.Module):
    """Transformer encoder layer: full self-attention + 1x1-conv FFN.

    Residual connections wrap both the attention step and the two-layer
    feed-forward, each followed by LayerNorm. The attention weights are
    computed but discarded.
    """

    def __init__(self, d_model, d_ff=None, n_heads=8, dropout=0.1, activation="relu"):
        super(Attention_Block, self).__init__()
        hidden = d_ff or 4 * d_model
        self.attention = self_attention(FullAttention, d_model, n_heads=n_heads)
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        attended, _ = self.attention(x, x, x, attn_mask=attn_mask)
        x = self.norm1(x + self.dropout(attended))
        # Position-wise feed-forward via 1x1 convs (channels-first).
        ffn = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ffn = self.dropout(self.conv2(ffn).transpose(-1, 1))
        return self.norm2(x + ffn)
67 |
68 |
class self_attention(nn.Module):
    """Multi-head Q/K/V projection wrapper around an attention class.

    `attention` is a class object; it is instantiated here with a fixed
    attention dropout of 0.1. The per-head dim is d_model // n_heads.
    """

    def __init__(self, attention, d_model, n_heads):
        super(self_attention, self).__init__()
        head_dim = d_model // n_heads

        self.inner_attention = attention(attention_dropout=0.1)
        self.query_projection = nn.Linear(d_model, head_dim * n_heads)
        self.key_projection = nn.Linear(d_model, head_dim * n_heads)
        self.value_projection = nn.Linear(d_model, head_dim * n_heads)
        self.out_projection = nn.Linear(head_dim * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask=None):
        B, L, _ = queries.shape
        S = keys.shape[1]
        H = self.n_heads

        # Project then split the feature dim into H heads.
        q = self.query_projection(queries).view(B, L, H, -1)
        k = self.key_projection(keys).view(B, S, H, -1)
        v = self.value_projection(values).view(B, S, H, -1)

        context, attn = self.inner_attention(q, k, v, attn_mask)
        merged = context.view(B, L, -1)
        return self.out_projection(merged), attn
100 |
101 |
class FullAttention(nn.Module):
    """Canonical scaled dot-product attention.

    With mask_flag=True a triangular causal mask is applied (built on
    the fly when none is supplied). Returns (values, attn-or-None)
    depending on output_attention. Inputs are (B, L, H, E) head-split.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()
        if self.output_attention:
            return (context, weights)
        return (context, None)
126 |
127 |
class GraphBlock(nn.Module):
    """Graph-convolution residual block mixing information across series.

    A dense adjacency is learned from two node-embedding matrices
    (relu(nodevec1 @ nodevec2) -> row softmax) and used by a mixprop
    graph convolution; the result is projected back to d_model and added
    to the input (residual + LayerNorm).

    Args:
        c_out: number of series (graph nodes).
        d_model: feature dimension of the input sequence.
        conv_channel, skip_channel: channel widths inside the GCN path.
        gcn_depth: propagation depth of mixprop.
        dropout, propalpha: mixprop dropout rate / retain ratio.
        seq_len: temporal length T of the input.
        node_dim: latent dimension of the node embeddings.
    """

    def __init__(self, c_out, d_model, conv_channel, skip_channel,
                 gcn_depth, dropout, propalpha, seq_len, node_dim):
        super(GraphBlock, self).__init__()

        self.nodevec1 = nn.Parameter(torch.randn(c_out, node_dim), requires_grad=True)
        self.nodevec2 = nn.Parameter(torch.randn(node_dim, c_out), requires_grad=True)
        self.start_conv = nn.Conv2d(1, conv_channel, (d_model - c_out + 1, 1))
        self.gconv1 = mixprop(conv_channel, skip_channel, gcn_depth, dropout, propalpha)
        self.gelu = nn.GELU()
        self.end_conv = nn.Conv2d(skip_channel, seq_len, (1, seq_len))
        self.linear = nn.Linear(c_out, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """x: (B, T, d_model) -> (B, T, d_model)."""
        # Learned adjacency, row-normalised by softmax.
        adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
        out = x.unsqueeze(1).transpose(2, 3)           # (B, 1, d_model, T)
        out = self.start_conv(out)                     # (B, conv_channel, c_out, T)
        out = self.gelu(self.gconv1(out, adp))         # (B, skip_channel, c_out, T)
        # BUGFIX: squeeze only the trailing singleton dim; a bare
        # .squeeze() also removed the batch dim whenever batch_size == 1.
        out = self.end_conv(out).squeeze(-1)           # (B, seq_len, c_out)
        out = self.linear(out)                         # (B, T, d_model)

        return self.norm(x + out)
153 |
154 |
class nconv(nn.Module):
    """Parameter-free graph aggregation along the node axis.

    For x of shape (n, c, w, l) and adjacency A of shape (v, w), output
    node v is the A[v, :]-weighted sum of input nodes.
    """

    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        mixed = torch.einsum('ncwl,vw->ncvl', (x, A))
        return mixed.contiguous()
163 |
164 |
class linear(nn.Module):
    """Channel-wise linear map implemented as a 1x1 Conv2d."""

    def __init__(self, c_in, c_out, bias=True):
        super(linear, self).__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1),
                                   padding=(0, 0), stride=(1, 1), bias=bias)

    def forward(self, x):
        # (B, c_in, H, W) -> (B, c_out, H, W)
        return self.mlp(x)
172 |
173 |
class mixprop(nn.Module):
    """Mix-hop graph propagation (MTGNN-style).

    Runs `gdep` propagation steps h <- alpha*x + (1-alpha) * A_norm(h),
    concatenates every hop (including hop 0) on the channel axis and
    projects back to c_out with a 1x1 conv.
    """

    def __init__(self, c_in, c_out, gdep, dropout, alpha):
        super(mixprop, self).__init__()
        self.nconv = nconv()
        self.mlp = linear((gdep + 1) * c_in, c_out)
        self.gdep = gdep
        self.dropout = dropout  # NOTE(review): stored but never applied in forward
        self.alpha = alpha

    def forward(self, x, adj):
        # Add self-loops, then row-normalise the adjacency.
        adj = adj + torch.eye(adj.size(0)).to(x.device)
        degree = adj.sum(1)
        norm_adj = adj / degree.view(-1, 1)

        hop = x
        hops = [hop]
        for _ in range(self.gdep):
            hop = self.alpha * x + (1 - self.alpha) * self.nconv(hop, norm_adj)
            hops.append(hop)
        return self.mlp(torch.cat(hops, dim=1))
195 |
196 |
class simpleVIT(nn.Module):
    """Minimal ViT-style block over 2-D feature maps.

    Tokens are produced by a single conv (kernel 2*patch_size+1 with
    padding patch_size, so the spatial size is preserved) followed by a
    flatten; they pass through `depth` pre-norm attention + feed-forward
    layers with residuals and are folded back to (B, emb_size, H, W).
    """

    def __init__(self, in_channels, emb_size, patch_size=2, depth=1, num_heads=4, dropout=0.1, init_weight=True):
        super(simpleVIT, self).__init__()
        self.emb_size = emb_size
        self.depth = depth
        self.to_patch = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, 2 * patch_size + 1, padding=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.layers = nn.ModuleList(
            nn.ModuleList([
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, dropout),
                FeedForward(emb_size, emb_size),
            ])
            for _ in range(self.depth)
        )

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for every conv; zero its bias when present.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        B, _, _, P = x.shape
        tokens = self.to_patch(x)
        for norm, attn, ff in self.layers:
            tokens = attn(norm(tokens)) + tokens
            tokens = ff(tokens) + tokens

        # Fold the token axis back onto the spatial grid (last dim = P).
        return tokens.transpose(1, 2).reshape(B, self.emb_size, -1, P)
234 |
class MultiHeadAttention(nn.Module):
    """Standard multi-head self-attention over (B, N, emb_size) tokens.

    NOTE(review): the original scales *after* the softmax
    (softmax(energy) / sqrt(emb_size)) rather than the conventional
    softmax(energy / sqrt(d_head)); kept as-is to preserve trained
    behavior.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        B, N, _ = x.shape
        H = self.num_heads
        # (B, N, emb) -> (B, H, N, head_dim); equivalent to
        # einops rearrange "b n (h d) -> b h n d".
        queries = self.queries(x).reshape(B, N, H, -1).permute(0, 2, 1, 3)
        keys = self.keys(x).reshape(B, N, H, -1).permute(0, 2, 1, 3)
        values = self.values(x).reshape(B, N, H, -1).permute(0, 2, 1, 3)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            # BUGFIX: Tensor has no `mask_fill`; the original raised
            # AttributeError whenever a mask was passed.
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # Weighted sum over the key axis.
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        # (B, H, N, d) -> (B, N, H*d), i.e. "b h n d -> b n (h d)".
        out = out.permute(0, 2, 1, 3).reshape(B, N, -1)
        out = self.projection(out)
        return out
263 |
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP with GELU (dim -> hidden_dim -> dim)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Built as a list first; submodule names (net.0 .. net.3) match
        # the original Sequential layout, so checkpoints stay loadable.
        stack = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        return self.net(x)
--------------------------------------------------------------------------------
/layers/SelfAttention_Family.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import matplotlib.pyplot as plt
6 |
7 | import numpy as np
8 | import math
9 | from math import sqrt
10 | from utils.masking import TriangularCausalMask, ProbMask
11 | import os
12 |
class FullAttention(nn.Module):
    """Vanilla scaled dot-product attention over head-split inputs.

    Inputs are (B, L, H, E). With mask_flag=True a triangular causal
    mask is built when none is given. Returns the attended values and,
    when output_attention is set, the post-dropout attention weights.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        scale = self.scale or 1. / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()
        return (context, weights) if self.output_attention else (context, None)
36 |
37 |
class ProbAttention(nn.Module):
    """ProbSparse attention (Informer-style).

    Only the top-u "active" queries (chosen by a sparsity measurement on
    a random sample of keys) attend over all keys; the remaining rows
    keep a cheap default context (mean of V, or a cumulative sum in the
    causal case).

    NOTE(review): key sampling uses torch.randint with no fixed seed, so
    results are not deterministic across calls.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        self.factor = factor  # c in u = c * ln(L): controls how many queries stay active
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
        """Score queries on `sample_k` random keys, keep the `n_top`
        sparsest-looking ones and compute their full Q·K scores.

        Returns (Q_K, M_top): scores of shape (B, H, n_top, L_K) and the
        indices of the selected queries.
        """
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()

        # find the Top_k query with sparisty measurement
        # M = max score - mean score: high M means a peaked (informative) row.
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None],
                     torch.arange(H)[None, :, None],
                     M_top, :] # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        """Default context for queries that were NOT selected: mean of V
        (non-causal) or running cumulative sum of V (causal)."""
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # V_sum = V.sum(dim=-2)
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else: # use mask
            assert (L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
            contex = V.cumsum(dim=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the rows of `context_in` at `index` with real
        attention output computed from the top-query `scores`."""
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)

        # Scatter the attended values back into the selected query rows.
        context_in[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            # Non-selected rows are reported as uniform (1/L_V) attention.
            attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)

    def forward(self, queries, keys, values, attn_mask):
        """queries/keys/values: (B, L, H, D); returns ((B, L, H, D) context, attn-or-None)."""
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        # Work in (B, H, L, D) layout internally.
        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)

        # Never sample/select more than the full length.
        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # add scale factor
        scale = self.scale or 1. / sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        # get the context
        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

        return context.contiguous(), attn
125 |
126 |
class AttentionLayer(nn.Module):
    """Multi-head Q/K/V projection wrapper around an attention instance.

    The inner attention object is passed in already constructed (unlike
    MSGBlock's `self_attention`, which instantiates a class itself).
    """

    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        S = keys.shape[1]
        H = self.n_heads

        # Project and split into H heads: (B, len, H, head_dim).
        q = self.query_projection(queries).view(B, L, H, -1)
        k = self.key_projection(keys).view(B, S, H, -1)
        v = self.value_projection(values).view(B, S, H, -1)

        context, attn = self.inner_attention(q, k, v, attn_mask)
        merged = context.view(B, L, -1)
        return self.out_projection(merged), attn
158 |
--------------------------------------------------------------------------------
/layers/Transformer_EncDec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class ConvLayer(nn.Module):
    """Distilling layer: circular conv + BatchNorm + ELU + stride-2 max-pool.

    Layout in/out is (B, L, c_in); time length roughly halves (the conv
    with kernel 3 and padding 2 first grows L by 2, then the pool
    downsamples by 2).
    """

    def __init__(self, c_in):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=3,
                                  padding=2,
                                  padding_mode='circular')
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Conv works channels-first: (B, L, C) -> (B, C, L).
        y = self.downConv(x.permute(0, 2, 1))
        y = self.activation(self.norm(y))
        y = self.maxPool(y)
        return y.transpose(1, 2)
25 |
class EncoderLayer(nn.Module):
    """Transformer encoder layer with an injected attention module.

    Residual + LayerNorm around attention, then a 1x1-conv feed-forward
    with another residual + LayerNorm. Returns (output, attn-weights).
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x = self.norm1(x + self.dropout(attended))
        # Position-wise FFN via 1x1 convs (channels-first).
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + hidden), attn
49 |
50 |
class Encoder(nn.Module):
    """Stack of encoder layers, optionally interleaved with conv distilling.

    When conv layers are given there is one fewer conv than attention
    layer: the final attention layer runs without a following conv and
    without a mask.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is None:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            # Last attention layer: no distillation, no mask.
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)
        return x, attns
76 |
77 |
class DecoderLayer(nn.Module):
    """Transformer decoder layer: self-attn, cross-attn, conv feed-forward.

    Both attention modules are injected; each sub-block is wrapped in a
    residual connection with its own LayerNorm.
    """

    def __init__(self, self_attention, cross_attention, d_model, d_ff=None,
                 dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        # Self-attention over the decoder input.
        x = self.norm1(x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]))
        # Cross-attention: decoder queries attend over the encoder output.
        x = x + self.dropout(self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0])

        y = x = self.norm2(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        return self.norm3(x + y)
110 |
111 |
class Decoder(nn.Module):
    """Stack of decoder layers plus optional final norm and projection.

    `external` is accepted for interface compatibility but is unused.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, external=None):
        for decoder_layer in self.layers:
            x = decoder_layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)

        if self.norm is not None:
            x = self.norm(x)
        if self.projection is not None:
            x = self.projection(x)
        return x
130 |
--------------------------------------------------------------------------------
/layers/__pycache__/AutoCorrelation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/AutoCorrelation.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/AutoCorrelation.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/AutoCorrelation.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/Autoformer_EncDec.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/Autoformer_EncDec.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/Autoformer_EncDec.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/Autoformer_EncDec.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/Embed.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/Embed.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/Embed.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/Embed.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/MSGBlock.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/MSGBlock.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/SelfAttention_Family.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/SelfAttention_Family.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/SelfAttention_Family.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/SelfAttention_Family.cpython-39.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/Transformer_EncDec.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/Transformer_EncDec.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/Transformer_EncDec.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/layers/__pycache__/Transformer_EncDec.cpython-39.pyc
--------------------------------------------------------------------------------
/models/Autoformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layers.Embed import DataEmbedding, DataEmbedding_wo_pos,DataEmbedding_wo_pos_temp,DataEmbedding_wo_temp
5 | from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer
6 | from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp
7 | import math
8 | import numpy as np
9 |
10 |
class Model(nn.Module):
    """
    Autoformer: decomposition transformer with auto-correlation attention.
    It is the first method to achieve the series-wise connection, with
    inherent O(L log L) complexity.
    """

    # configs.embed_type -> embedding class, used for encoder and decoder
    # alike (types 0 and 2 intentionally map to the same class).
    _EMBEDDING_BY_TYPE = {
        0: DataEmbedding_wo_pos,
        1: DataEmbedding,
        2: DataEmbedding_wo_pos,
        3: DataEmbedding_wo_temp,
        4: DataEmbedding_wo_pos_temp,
    }

    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention

        # Decomp: moving-average-based trend/seasonal split.
        kernel_size = configs.moving_avg
        self.decomp = series_decomp(kernel_size)

        # Embedding. The series-wise connection inherently contains the
        # sequential information, so positional embeddings can be dropped.
        # BUGFIX/robustness: an unknown embed_type previously left the
        # embeddings undefined and only crashed later with AttributeError.
        try:
            embedding_cls = self._EMBEDDING_BY_TYPE[configs.embed_type]
        except KeyError:
            raise ValueError(
                'unsupported embed_type: {}'.format(configs.embed_type)) from None
        self.enc_embedding = embedding_cls(configs.enc_in, configs.d_model, configs.embed,
                                           configs.freq, configs.dropout)
        self.dec_embedding = embedding_cls(configs.dec_in, configs.d_model, configs.embed,
                                           configs.freq, configs.dropout)

        # Encoder: e_layers of (non-masked) auto-correlation layers.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model)
        )
        # Decoder: d_layers with masked self auto-correlation + cross attention.
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                        configs.d_model, configs.n_heads),
                    AutoCorrelationLayer(
                        AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.c_out,
                    configs.d_ff,
                    moving_avg=configs.moving_avg,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(configs.d_layers)
            ],
            norm_layer=my_Layernorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
        )

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Forecast `pred_len` steps; returns [B, pred_len, c_out], plus the
        encoder attentions when output_attention is set."""
        # Decomp init: trend placeholder seeded with the input mean,
        # seasonal placeholder seeded with zeros.
        mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1)
        zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device)
        seasonal_init, trend_init = self.decomp(x_enc)
        # Decoder input: last label_len history steps + future placeholders.
        trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
        seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)
        # enc
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
        # dec: refines the seasonal part while accumulating the trend.
        dec_out = self.dec_embedding(seasonal_init, x_mark_dec)
        seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask,
                                                 trend=trend_init)
        # final: forecast = trend + seasonal, last pred_len steps only.
        dec_out = trend_part + seasonal_part

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
122 |
--------------------------------------------------------------------------------
/models/DLinear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Replicate the first/last time step on both ends so the pooled
        # series keeps the original length (for stride=1).
        pad = (self.kernel_size - 1) // 2
        front = x[:, 0:1, :].repeat(1, pad, 1)
        end = x[:, -1:, :].repeat(1, pad, 1)
        padded = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time to the end.
        smoothed = self.avg(padded.permute(0, 2, 1))
        return smoothed.permute(0, 2, 1)
24 |
25 |
class series_decomp(nn.Module):
    """
    Series decomposition block: split x into (seasonal residual, trend),
    where trend is a stride-1 moving average of x.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend
38 |
class Model(nn.Module):
    """
    Decomposition-Linear (DLinear): decompose the input into trend and
    seasonal parts and forecast each with a single linear map along time
    (one map per channel when `individual` is set).
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len

        # Decomposition kernel size is fixed, as in the original DLinear.
        kernel_size = 25
        self.decompsition = series_decomp(kernel_size)  # -> (seasonal, trend)
        self.individual = configs.individual
        self.channels = configs.enc_in

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList(
                nn.Linear(self.seq_len, self.pred_len) for _ in range(self.channels))
            self.Linear_Trend = nn.ModuleList(
                nn.Linear(self.seq_len, self.pred_len) for _ in range(self.channels))
        else:
            self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
            self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)

    def forward(self, x):
        # x: [Batch, Input length, Channel]
        seasonal, trend = self.decompsition(x)
        seasonal = seasonal.permute(0, 2, 1)  # [B, C, L]
        trend = trend.permute(0, 2, 1)

        if self.individual:
            seasonal_out = torch.stack(
                [self.Linear_Seasonal[c](seasonal[:, c, :]) for c in range(self.channels)],
                dim=1)
            trend_out = torch.stack(
                [self.Linear_Trend[c](trend[:, c, :]) for c in range(self.channels)],
                dim=1)
        else:
            seasonal_out = self.Linear_Seasonal(seasonal)
            trend_out = self.Linear_Trend(trend)

        forecast = seasonal_out + trend_out
        return forecast.permute(0, 2, 1)  # [Batch, Output length, Channel]
89 |
--------------------------------------------------------------------------------
/models/Informer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.masking import TriangularCausalMask, ProbMask
5 | from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer
6 | from layers.SelfAttention_Family import FullAttention, ProbAttention, AttentionLayer
7 | from layers.Embed import DataEmbedding,DataEmbedding_wo_pos,DataEmbedding_wo_temp,DataEmbedding_wo_pos_temp
8 | import numpy as np
9 |
10 |
class Model(nn.Module):
    """
    Informer with ProbSparse attention in O(L log L) complexity.

    Encoder-decoder forecaster: the encoder stacks ProbAttention layers
    (optionally interleaved with ConvLayer distilling when
    ``configs.distil``), the decoder combines masked self ProbAttention
    with cross attention over the encoder output, and the last
    ``pred_len`` decoder steps are returned as the forecast.
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention

        # Embedding: variant chosen by `embed_type`.
        # NOTE(review): types 0 and 1 build the identical full DataEmbedding here.
        if configs.embed_type == 0:
            self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
            self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)
        elif configs.embed_type == 1:
            self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
            self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        elif configs.embed_type == 2:
            # Value + temporal embedding, no positional encoding.
            self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
            self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)

        elif configs.embed_type == 3:
            # Value + positional embedding, no temporal features.
            self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
            self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        elif configs.embed_type == 4:
            # Value embedding only.
            self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
            self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                                    configs.dropout)
        # Encoder: non-masked ProbAttention; ConvLayer distilling between
        # layers only when `distil` is enabled.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
                                      output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for l in range(configs.e_layers)
            ],
            [
                ConvLayer(
                    configs.d_model
                ) for l in range(configs.e_layers - 1)
            ] if configs.distil else None,
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )
        # Decoder: causal (masked) self ProbAttention + cross ProbAttention,
        # followed by a linear projection to `c_out` channels.
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False),
                        configs.d_model, configs.n_heads),
                    AttentionLayer(
                        ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),
                        configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation,
                )
                for l in range(configs.d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model),
            projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
        )

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # Encode the history window.
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Decode: self-attend over the decoder input, cross-attend to enc_out.
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)

        # Keep only the forecast horizon at the tail of the decoder output.
        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
102 |
--------------------------------------------------------------------------------
/models/MSGNet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # import pywt
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.fft
7 | from layers.Embed import DataEmbedding
8 | from layers.MSGBlock import GraphBlock, simpleVIT, Attention_Block, Predict
9 |
10 |
def FFT_for_Period(x, k=2):
    """
    Find the k dominant periods of x via its amplitude spectrum.

    x: [B, T, C]. Amplitudes are averaged over batch and channels, the DC
    bin is discarded, and the top-k frequency bins are converted to
    periods as T // bin_index.

    Returns (periods, weights): a length-k numpy array of periods and a
    [B, k] tensor of per-sample amplitudes at the selected bins.
    """
    spectrum = torch.fft.rfft(x, dim=1)
    amplitude = abs(spectrum).mean(0).mean(-1)
    amplitude[0] = 0  # ignore the DC component
    _, top_bins = torch.topk(amplitude, k)
    top_bins = top_bins.detach().cpu().numpy()
    periods = x.shape[1] // top_bins
    return periods, abs(spectrum).mean(-1)[:, top_bins]
20 |
21 |
class ScaleGraphBlock(nn.Module):
    """
    Core MSGNet block: per-scale graph convolution + intra-scale attention,
    aggregated across the top-k FFT-detected periods.

    For each dominant period the input is passed through a dedicated
    GraphBlock, folded into (period-count, period-length) segments,
    self-attended within each segment, and the per-scale results are
    combined with softmax weights derived from the FFT amplitudes.
    """
    def __init__(self, configs):
        super(ScaleGraphBlock, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k  # number of scales (dominant periods) to use

        # Shared attention/norm applied within every scale.
        self.att0 = Attention_Block(configs.d_model, configs.d_ff,
                                   n_heads=configs.n_heads, dropout=configs.dropout, activation="gelu")
        self.norm = nn.LayerNorm(configs.d_model)
        self.gelu = nn.GELU()
        # One GraphBlock per scale.
        self.gconv = nn.ModuleList()
        for i in range(self.k):
            self.gconv.append(
                GraphBlock(configs.c_out , configs.d_model , configs.conv_channel, configs.skip_channel,
                        configs.gcn_depth , configs.dropout, configs.propalpha ,configs.seq_len,
                           configs.node_dim))


    def forward(self, x):
        # x: [B, T, N] embedded sequence.
        B, T, N = x.size()
        scale_list, scale_weight = FFT_for_Period(x, self.k)
        res = []
        for i in range(self.k):
            scale = scale_list[i]
            # Graph convolution for this scale.
            # NOTE(review): `x` is overwritten here, so scale i+1 operates on
            # the gconv output of scale i, and the residual below adds the
            # LAST gconv output — confirm this chaining is intended.
            x = self.gconv[i](x)
            # Pad the time axis so that `scale` divides the length evenly.
            if (self.seq_len) % scale != 0:
                length = (((self.seq_len) // scale) + 1) * scale
                padding = torch.zeros([x.shape[0], (length - (self.seq_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = self.seq_len
                out = x
            # Fold into [B, num_segments, scale, N] segments of one period each.
            out = out.reshape(B, length // scale, scale, N)

            # Multi-head attention within each period segment.
            out = out.reshape(-1 , scale , N)
            out = self.norm(self.att0(out))
            out = self.gelu(out)
            out = out.reshape(B, -1 , scale , N).reshape(B ,-1 ,N)
            # #for simpleVIT
            # out = self.att(out.permute(0, 3, 1, 2).contiguous()) #return
            # out = out.permute(0, 2, 3, 1).reshape(B, -1 ,N)

            # Drop the padding tail.
            out = out[:, :self.seq_len, :]
            res.append(out)

        res = torch.stack(res, dim=-1)
        # Adaptive aggregation: weight each scale by its (softmaxed) FFT amplitude.
        scale_weight = F.softmax(scale_weight, dim=1)
        scale_weight = scale_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * scale_weight, -1)
        # Residual connection (to the last gconv output — see note above).
        res = res + x
        return res
79 |
80 |
class Model(nn.Module):
    """
    MSGNet forecaster: instance-normalized input → embedding → stacked
    ScaleGraphBlocks → linear projection → seq-to-pred mapping → de-norm.
    """
    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # for graph
        # self.num_nodes = configs.c_out
        # self.subgraph_size = configs.subgraph_size
        # self.node_dim = configs.node_dim
        # to return adj (node , node)
        # self.graph = constructor_graph()

        # e_layers stacked multi-scale graph blocks.
        self.model = nn.ModuleList([ScaleGraphBlock(configs) for _ in range(configs.e_layers)])
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model,
                                           configs.embed, configs.freq, configs.dropout)
        self.layer = configs.e_layers
        self.layer_norm = nn.LayerNorm(configs.d_model)
        # NOTE(review): `predict_linear` is never used in forward() below —
        # it only adds (unused) parameters to the state dict.
        self.predict_linear = nn.Linear(
            self.seq_len, self.pred_len + self.seq_len)
        # Map d_model back to the output channel count.
        self.projection = nn.Linear(
            configs.d_model, configs.c_out, bias=True)
        # Map the time axis from seq_len to pred_len.
        self.seq2pred = Predict(configs.individual ,configs.c_out,
                                configs.seq_len, configs.pred_len, configs.dropout)


    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        # Per-instance normalization (Non-stationary Transformer style):
        # subtract the per-series mean and divide by the per-series std.
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        # adp = self.graph(torch.arange(self.num_nodes).to(self.device))
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))

        # Project back to c_out channels, then seq_len -> pred_len on the time axis.
        dec_out = self.projection(enc_out)
        dec_out = self.seq2pred(dec_out.transpose(1, 2)).transpose(1, 2)

        # De-normalization: re-apply the stored std and mean.
        dec_out = dec_out * \
                  (stdev[:, 0, :].unsqueeze(1).repeat(
                      1, self.pred_len, 1))
        dec_out = dec_out + \
                  (means[:, 0, :].unsqueeze(1).repeat(
                      1, self.pred_len, 1))

        return dec_out[:, -self.pred_len:, :]
138 |
139 |
140 |
--------------------------------------------------------------------------------
/models/__pycache__/Autoformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/models/__pycache__/Autoformer.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/DLinear.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/models/__pycache__/DLinear.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/Informer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/models/__pycache__/Informer.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/MSGNet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/models/__pycache__/MSGNet.cpython-39.pyc
--------------------------------------------------------------------------------
/pic/main_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/pic/main_result.jpg
--------------------------------------------------------------------------------
/pic/model1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/pic/model1.jpg
--------------------------------------------------------------------------------
/pic/model2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/pic/model2.jpg
--------------------------------------------------------------------------------
/run_longExp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | from multiprocessing import freeze_support
5 | import torch
6 | from exp.exp_main import Exp_Main
7 | import random
8 | import numpy as np
9 |
# Reproducibility: fix every RNG the run touches (python, torch, numpy).
fix_seed = 2021
random.seed(fix_seed)
torch.manual_seed(fix_seed)
np.random.seed(fix_seed)

parser = argparse.ArgumentParser(description='MSGNet for Time Series Forecasting')

# basic config
parser.add_argument('--task_name', type=str, required=False, default='long_term_forecast',
                    help='task name, options:[long_term_forecast, mask, short_term_forecast, imputation, classification, anomaly_detection]')
parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
parser.add_argument('--model_id', type=str, required=True, default='test', help='model id')
parser.add_argument('--model', type=str, required=True, default='Autoformer',
                    help='model name, options: [Autoformer, Informer, Transformer]')

# data loader
parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type')
parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
parser.add_argument('--features', type=str, default='M',
                    help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate,'
                         ' S:univariate predict univariate, MS:multivariate predict univariate')
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
                    help='freq for time features encoding, '
                         'options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], '
                         'you can also use more detailed freq like 15min or 3h')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

# forecasting task
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
parser.add_argument('--label_len', type=int, default=48, help='start token length')
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')


parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock/ScaleGraphBlock')
parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')

# graph construction
parser.add_argument('--num_nodes', type=int, default=7, help='to create Graph')
parser.add_argument('--subgraph_size', type=int, default=3, help='neighbors number')
parser.add_argument('--tanhalpha', type=float, default=3, help='')

#GCN
parser.add_argument('--node_dim', type=int, default=10, help='each node embbed to dim dimentions')
parser.add_argument('--gcn_depth', type=int, default=2, help='')
parser.add_argument('--gcn_dropout', type=float, default=0.3, help='')
parser.add_argument('--propalpha', type=float, default=0.3, help='')
parser.add_argument('--conv_channel', type=int, default=32, help='')
parser.add_argument('--skip_channel', type=int, default=32, help='')


# DLinear
parser.add_argument('--individual', action='store_true', default=False, help='DLinear: a linear layer for each variate(channel) individually')
# Formers
parser.add_argument('--embed_type', type=int, default=0, help='0: default '
                    '1: value embedding + temporal embedding + positional embedding '
                    '2: value embedding + temporal embedding '
                    '3: value embedding + positional embedding '
                    '4: value embedding')
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
parser.add_argument('--c_out', type=int, default=7, help='output size')
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
parser.add_argument('--factor', type=int, default=1, help='attn factor')
parser.add_argument('--distil', action='store_false',
                    help='whether to use distilling in encoder, using this argument means not using distilling',
                    default=True)
parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
parser.add_argument('--embed', type=str, default='timeF',
                    help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')

# optimization
parser.add_argument('--num_workers', type=int, default=8, help='data loader num workers')
parser.add_argument('--itr', type=int, default=2, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=10, help='train epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--loss', type=str, default='MSE', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

# GPU
# NOTE(review): argparse `type=bool` converts any non-empty string to True,
# so `--use_gpu False` still enables the GPU; only `--use_gpu ''` disables it.
parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')
args = parser.parse_args()
# Only use the GPU when both requested and actually available.
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False

if args.use_gpu and args.use_multi_gpu:
    # Fix: the original assigned the stripped string to `args.dvices` (typo),
    # so whitespace in `--devices '0, 1'` was never actually removed.
    args.devices = args.devices.replace(' ', '')
    device_ids = args.devices.split(',')
    args.device_ids = [int(id_) for id_ in device_ids]
    # The first listed device becomes the primary GPU.
    args.gpu = args.device_ids[0]

print('Args in experiment:')
print(args)
120 |
Exp = Exp_Main

if args.is_training:
    start = time.time()
    for ii in range(args.itr):
        # Setting string uniquely identifies this run (checkpoints, logs).
        setting = '{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(
            args.model_id,
            args.model,
            args.data,
            args.features,
            args.seq_len,
            args.label_len,
            args.pred_len,
            args.d_model,
            args.n_heads,
            args.e_layers,
            args.d_layers,
            args.d_ff,
            args.factor,
            args.embed,
            args.distil,
            args.des, ii)

        exp = Exp(args)  # set experiments
        print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
        exp.train(setting)

        print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        exp.test(setting)

        # if args.do_predict:
        #     print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
        #     exp.predict(setting, True)

        torch.cuda.empty_cache()
    end = time.time()
    used_time = end - start
    print("time:", used_time)
    # Append the wall-clock time to the shared result log. Fix: use a context
    # manager so the handle is closed even if a write fails (the original
    # opened/closed the file manually).
    with open("result.txt", 'a') as f:
        f.write('time:{}'.format(used_time))
        f.write('\n')
        f.write('\n')
else:
    # Evaluation-only mode: rebuild the setting string and load the checkpoint.
    ii = 0
    setting = '{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(args.model_id,
                                                                                                  args.model,
                                                                                                  args.data,
                                                                                                  args.features,
                                                                                                  args.seq_len,
                                                                                                  args.label_len,
                                                                                                  args.pred_len,
                                                                                                  args.d_model,
                                                                                                  args.n_heads,
                                                                                                  args.e_layers,
                                                                                                  args.d_layers,
                                                                                                  args.d_ff,
                                                                                                  args.factor,
                                                                                                  args.embed,
                                                                                                  args.distil,
                                                                                                  args.des, ii)

    exp = Exp(args)  # set experiments
    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
    exp.test(setting, test=1)
    torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/scripts/ETTh1.sh:
--------------------------------------------------------------------------------
# Train/evaluate MSGNet on ETTh1 at four forecast horizons.
# -p creates ./logs and ./logs/ETTh1 as needed (idempotent).
mkdir -p ./logs/ETTh1

export CUDA_VISIBLE_DEVICES=0

seq_len=96
label_len=48
model_name=MSGNet

# Per-horizon hyper-parameters (everything else is shared):
#   96, 192: e_layers=1, d_model=32, d_ff=64
#   336:     e_layers=2, d_model=32, d_ff=64
#   720:     e_layers=1, d_model=16, d_ff=32
for pred_len in 96 192 336 720
do
  e_layers=1
  d_model=32
  d_ff=64
  case $pred_len in
    336) e_layers=2 ;;
    720) d_model=16; d_ff=32 ;;
  esac

  python -u run_longExp.py \
    --is_training 1 \
    --root_path ./dataset/ \
    --data_path ETTh1.csv \
    --model_id ETTh1'_'$seq_len'_'$pred_len \
    --model $model_name \
    --data ETTh1 \
    --features M \
    --freq h \
    --target 'OT' \
    --seq_len $seq_len \
    --label_len $label_len \
    --pred_len $pred_len \
    --e_layers $e_layers \
    --d_layers 1 \
    --factor 3 \
    --enc_in 7 \
    --dec_in 7 \
    --c_out 7 \
    --des 'Exp' \
    --d_model $d_model \
    --d_ff $d_ff \
    --top_k 3 \
    --conv_channel 32 \
    --skip_channel 32 \
    --dropout 0.1 \
    --batch_size 32 \
    --itr 1 #>logs/ETTh1/$model_name'_'ETTh1_$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/ETTh2.sh:
--------------------------------------------------------------------------------
# Train/evaluate MSGNet on ETTh2 at four forecast horizons.
# -p creates ./logs and ./logs/ETTh2 as needed (idempotent).
mkdir -p ./logs/ETTh2

export CUDA_VISIBLE_DEVICES=1

seq_len=96
label_len=48
model_name=MSGNet

# All four horizons share identical hyper-parameters, so a single loop
# replaces the four duplicated invocations.
for pred_len in 96 192 336 720
do
  python -u run_longExp.py \
    --is_training 1 \
    --root_path ./dataset/ \
    --data_path ETTh2.csv \
    --model_id ETTh2'_'$seq_len'_'$pred_len \
    --model $model_name \
    --data ETTh2 \
    --features M \
    --freq h \
    --target 'OT' \
    --seq_len $seq_len \
    --label_len $label_len \
    --pred_len $pred_len \
    --e_layers 2 \
    --d_layers 1 \
    --factor 3 \
    --enc_in 7 \
    --dec_in 7 \
    --c_out 7 \
    --des 'Exp' \
    --d_model 16 \
    --d_ff 32 \
    --conv_channel 32 \
    --skip_channel 32 \
    --top_k 5 \
    --batch_size 32 \
    --itr 1 #>logs/ETTh2/$model_name'_'ETTh2_$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/ETTm1.sh:
--------------------------------------------------------------------------------
# Train/evaluate MSGNet on ETTm1 at four forecast horizons.
# -p creates ./logs and ./logs/ETTm1 as needed (idempotent).
mkdir -p ./logs/ETTm1

export CUDA_VISIBLE_DEVICES=2

seq_len=96
label_len=48
model_name=MSGNet

for pred_len in 96 192 336 720
do
  # Only the 96-step horizon uses the wider graph-conv channel (32);
  # the longer horizons use 16. Every other flag is shared.
  if [ "$pred_len" -eq 96 ]; then
    conv_channel=32
  else
    conv_channel=16
  fi

  python -u run_longExp.py \
    --is_training 1 \
    --root_path ./dataset/ \
    --data_path ETTm1.csv \
    --model_id ETTm1'_'$seq_len'_'$pred_len \
    --model $model_name \
    --data ETTm1 \
    --features M \
    --target 'OT' \
    --seq_len $seq_len \
    --label_len $label_len \
    --pred_len $pred_len \
    --e_layers 1 \
    --d_layers 1 \
    --factor 3 \
    --enc_in 7 \
    --dec_in 7 \
    --c_out 7 \
    --des 'Exp' \
    --d_model 32 \
    --d_ff 32 \
    --top_k 3 \
    --conv_channel $conv_channel \
    --skip_channel 32 \
    --batch_size 32 \
    --itr 1 #>logs/ETTm1/$model_name'_'ETTm1_$seq_len'_'$pred_len.log
done
--------------------------------------------------------------------------------
/scripts/ETTm2.sh:
--------------------------------------------------------------------------------
# Idempotent log directory creation (replaces the two guarded mkdirs).
mkdir -p ./logs/ETTm2

export CUDA_VISIBLE_DEVICES=3

seq_len=96
label_len=48
model_name=MSGNet

# All four horizons share every hyper-parameter except --d_ff
# (64 for horizons 192 and 720, 32 otherwise).
for pred_len in 96 192 336 720; do
    case "$pred_len" in
        192|720) d_ff=64 ;;
        *)       d_ff=32 ;;
    esac
    python -u run_longExp.py \
        --is_training 1 \
        --root_path ./dataset/ \
        --data_path ETTm2.csv \
        --model_id "ETTm2_${seq_len}_${pred_len}" \
        --model "$model_name" \
        --data ETTm2 \
        --features M \
        --target OT \
        --seq_len "$seq_len" \
        --label_len "$label_len" \
        --pred_len "$pred_len" \
        --e_layers 2 \
        --d_layers 1 \
        --factor 3 \
        --enc_in 7 \
        --dec_in 7 \
        --c_out 7 \
        --des Exp \
        --d_model 32 \
        --d_ff "$d_ff" \
        --top_k 3 \
        --conv_channel 32 \
        --skip_channel 32 \
        --dropout 0.3 \
        --batch_size 32 \
        --itr 1 #>logs/ETTm2/${model_name}_ETTm2_${seq_len}_${pred_len}.log
done
--------------------------------------------------------------------------------
/scripts/Flight.sh:
--------------------------------------------------------------------------------
# Idempotent log directory creation.
mkdir -p ./logs/Flight

export CUDA_VISIBLE_DEVICES=2

seq_len=96
label_len=48
model_name=MSGNet

# Identical hyper-parameters for every horizon; only --pred_len varies.
for pred_len in 96 192 336 720; do
    python -u run_longExp.py \
        --is_training 1 \
        --root_path ./dataset/ \
        --data_path Flight.csv \
        --model_id "Flight_${seq_len}_${pred_len}" \
        --model "$model_name" \
        --data custom \
        --features M \
        --freq h \
        --target UUEE \
        --seq_len "$seq_len" \
        --label_len "$label_len" \
        --pred_len "$pred_len" \
        --e_layers 2 \
        --d_layers 1 \
        --factor 3 \
        --enc_in 7 \
        --dec_in 7 \
        --c_out 7 \
        --des Exp \
        --itr 1 \
        --d_model 16 \
        --d_ff 32 \
        --top_k 5 \
        --conv_channel 32 \
        --skip_channel 32 \
        --node_dim 100 \
        --batch_size 32 #>logs/Flight/${model_name}_Flight_${seq_len}_${pred_len}.log
done
--------------------------------------------------------------------------------
/scripts/electricity.sh:
--------------------------------------------------------------------------------
# Idempotent log directory creation.
mkdir -p ./logs/electricity

export CUDA_VISIBLE_DEVICES=3

seq_len=96
label_len=48
model_name=MSGNet

# Long horizons (336, 720) use a deeper encoder (3 layers) than 96/192 (2 layers);
# everything else is shared across the four runs.
for pred_len in 96 192 336 720; do
    if [ "$pred_len" -ge 336 ]; then
        e_layers=3
    else
        e_layers=2
    fi
    python -u run_longExp.py \
        --is_training 1 \
        --root_path ./dataset/ \
        --data_path electricity.csv \
        --model_id "electricity_${seq_len}_${pred_len}" \
        --model "$model_name" \
        --data custom \
        --features M \
        --freq h \
        --target OT \
        --seq_len "$seq_len" \
        --label_len "$label_len" \
        --pred_len "$pred_len" \
        --e_layers "$e_layers" \
        --d_layers 1 \
        --factor 3 \
        --enc_in 321 \
        --dec_in 321 \
        --c_out 321 \
        --des Exp \
        --d_model 1024 \
        --d_ff 512 \
        --top_k 5 \
        --conv_channel 16 \
        --skip_channel 32 \
        --node_dim 100 \
        --batch_size 32 \
        --itr 1 #>logs/electricity/${model_name}_electricity_${seq_len}_${pred_len}.log
done
--------------------------------------------------------------------------------
/scripts/exchange.sh:
--------------------------------------------------------------------------------
# Idempotent log directory creation.
mkdir -p ./logs/exchange

export CUDA_VISIBLE_DEVICES=2

seq_len=96
label_len=48
model_name=MSGNet

# Per-horizon hyper-parameter differences are confined to $extra_args:
#   96      -> top_k 3 + dropout 0.2
#   192/336 -> top_k 5 + node_dim 30
#   720     -> top_k 5 only
for pred_len in 96 192 336 720; do
    case "$pred_len" in
        96)      extra_args="--top_k 3 --dropout 0.2" ;;
        192|336) extra_args="--top_k 5 --node_dim 30" ;;
        720)     extra_args="--top_k 5" ;;
    esac
    python -u run_longExp.py \
        --is_training 1 \
        --root_path ./dataset/ \
        --data_path exchange_rate.csv \
        --model_id "exchange_${seq_len}_${pred_len}" \
        --model "$model_name" \
        --data custom \
        --features M \
        --freq h \
        --target OT \
        --seq_len "$seq_len" \
        --label_len "$label_len" \
        --pred_len "$pred_len" \
        --e_layers 2 \
        --d_layers 1 \
        --factor 3 \
        --enc_in 8 \
        --dec_in 8 \
        --c_out 8 \
        --des Exp \
        --d_model 64 \
        --d_ff 128 \
        $extra_args \
        --conv_channel 16 \
        --skip_channel 32 \
        --batch_size 32 \
        --itr 1 #>logs/exchange/${model_name}_exchange_${seq_len}_${pred_len}.log
done
--------------------------------------------------------------------------------
/scripts/weather.sh:
--------------------------------------------------------------------------------
# Idempotent log directory creation.
mkdir -p ./logs/weather

export CUDA_VISIBLE_DEVICES=2

seq_len=96
label_len=48
model_name=MSGNet

# Per-horizon deviations from the shared configuration:
#   96  -> additionally trains for only 3 epochs
#   336 -> single encoder layer instead of two
for pred_len in 96 192 336 720; do
    e_layers=2
    extra_args=""
    case "$pred_len" in
        96)  extra_args="--train_epochs 3" ;;
        336) e_layers=1 ;;
    esac
    python -u run_longExp.py \
        --is_training 1 \
        --root_path ./dataset/ \
        --data_path weather.csv \
        --model_id "weather_${seq_len}_${pred_len}" \
        --model "$model_name" \
        --data custom \
        --features M \
        --freq h \
        --target OT \
        --seq_len "$seq_len" \
        --label_len "$label_len" \
        --pred_len "$pred_len" \
        --e_layers "$e_layers" \
        --d_layers 1 \
        --factor 3 \
        --enc_in 21 \
        --dec_in 21 \
        --c_out 21 \
        --des Exp \
        --d_model 64 \
        --d_ff 128 \
        --top_k 5 \
        --conv_channel 32 \
        --skip_channel 32 \
        --batch_size 32 \
        $extra_args \
        --itr 1 #>logs/weather/${model_name}_weather_${seq_len}_${pred_len}.log
done
--------------------------------------------------------------------------------
/utils/__pycache__/masking.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/masking.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/masking.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/masking.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/metrics.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/timefeatures.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/timefeatures.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/timefeatures.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/timefeatures.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tools.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/tools.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/tools.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoZhibo/MSGNet/953b8330a2ca469dab4955e804b46a61eb08a9c2/utils/__pycache__/tools.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class TriangularCausalMask():
    """Boolean causal mask for self-attention.

    Shape is [B, 1, L, L]; entry (i, j) is True when j > i, i.e. query
    position i must not attend to the later key position j.
    """

    def __init__(self, B, L, device="cpu"):
        shape = [B, 1, L, L]
        with torch.no_grad():
            template = torch.ones(shape, dtype=torch.bool)
            # Keep only the strict upper triangle: True marks masked entries.
            self._mask = torch.triu(template, diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
13 |
14 |
class ProbMask():
    """Causal mask for ProbSparse attention, gathered at the sampled query rows.

    `index` holds, per (batch, head), the positions of the queries that were
    kept; the resulting mask has the same shape as `scores`.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        n_keys = scores.shape[-1]
        # Full causal template for one head: True strictly above the diagonal.
        causal = torch.ones(L, n_keys, dtype=torch.bool).to(device).triu(1)
        expanded = causal[None, None, :].expand(B, H, L, n_keys)
        # Select only the rows of the sampled queries for every (batch, head).
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        gathered = expanded[batch_idx, head_idx, index, :].to(device)
        self._mask = gathered.view(scores.shape).to(device)

    @property
    def mask(self):
        return self._mask
27 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def MAE(pred, true):
    """Mean absolute error between predictions and ground truth."""
    return np.abs(pred - true).mean()
5 |
def MAPE(pred, true):
    """Mean absolute percentage error, as a fraction (undefined where true == 0)."""
    relative_err = (pred - true) / true
    return np.abs(relative_err).mean()
8 |
def ND(pred, true):
    """Normalized deviation: mean absolute error relative to the mean |true|."""
    return np.abs(true - pred).mean() / np.abs(true).mean()
11 |
def MSE(pred, true):
    """Mean squared error between predictions and ground truth."""
    return np.mean(np.square(pred - true))
14 |
def RMSE(pred, true):
    """Root mean squared error (MSE computation inlined)."""
    return np.sqrt(np.mean((pred - true) ** 2))
17 |
def NRMSE(pred, true):
    """RMSE normalized by the mean absolute value of the ground truth."""
    rmse = np.sqrt(np.mean((pred - true) ** 2))
    return rmse / np.mean(np.abs(true))
20 |
def RSE(pred, true):
    """Root relative squared error versus the constant-mean baseline predictor."""
    residual = np.sqrt(np.sum((true - pred) ** 2))
    baseline = np.sqrt(np.sum((true - true.mean()) ** 2))
    return residual / baseline
23 |
24 |
def CORR(pred, true):
    """Per-series correlation along axis 0, averaged over series and scaled by 0.01.

    NOTE(review): the 0.01 factor and the product-of-squares denominator are
    preserved from the original implementation; this is not textbook Pearson r.
    """
    dev_true = true - true.mean(0)
    dev_pred = pred - pred.mean(0)
    numerator = (dev_true * dev_pred).sum(0)
    # 1e-12 keeps the division finite for constant series.
    denominator = np.sqrt((dev_true ** 2 * dev_pred ** 2).sum(0)) + 1e-12
    return 0.01 * (numerator / denominator).mean(-1)
30 |
31 |
def MSPE(pred, true):
    """Mean squared percentage error, as a fraction (undefined where true == 0)."""
    relative_err = (pred - true) / true
    return np.mean(relative_err ** 2)
34 |
35 |
def metric(pred, true):
    """Compute the full evaluation suite.

    Returns a tuple (mae, mse, rmse, mape, mspe, rse, corr, nd, nrmse).
    """
    return (
        MAE(pred, true),
        MSE(pred, true),
        RMSE(pred, true),
        MAPE(pred, true),
        MSPE(pred, true),
        RSE(pred, true),
        CORR(pred, true),
        ND(pred, true),
        NRMSE(pred, true),
    )
48 |
def metric2(pred, true):
    """Evaluation suite without the correlation term.

    Returns a tuple (mae, mse, rmse, mape, mspe, rse, nd, nrmse).
    """
    return (
        MAE(pred, true),
        MSE(pred, true),
        RMSE(pred, true),
        MAPE(pred, true),
        MSPE(pred, true),
        RSE(pred, true),
        ND(pred, true),
        NRMSE(pred, true),
    )
--------------------------------------------------------------------------------
/utils/timefeatures.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from pandas.tseries import offsets
6 | from pandas.tseries.frequencies import to_offset
7 |
8 |
class TimeFeature:
    """Base class for calendar-derived features.

    Subclasses map a ``pd.DatetimeIndex`` to values scaled into [-0.5, 0.5].
    """

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # The base class computes nothing; subclasses override this.
        pass

    def __repr__(self):
        return f"{self.__class__.__name__}()"
18 |
19 |
class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # seconds run 0..59, so dividing by 59 spans exactly [-0.5, 0.5]
        return index.second / 59.0 - 0.5
25 |
26 |
class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # minutes run 0..59, so dividing by 59 spans exactly [-0.5, 0.5]
        return index.minute / 59.0 - 0.5
32 |
33 |
class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # hours run 0..23, so dividing by 23 spans exactly [-0.5, 0.5]
        return index.hour / 23.0 - 0.5
39 |
40 |
class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # dayofweek runs 0 (Monday)..6 (Sunday), so /6 spans [-0.5, 0.5]
        return index.dayofweek / 6.0 - 0.5
46 |
47 |
class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # day is 1-based; shift to 0-based before scaling by 30
        return (index.day - 1) / 30.0 - 0.5
53 |
54 |
class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # dayofyear is 1-based; shift to 0-based before scaling by 365
        return (index.dayofyear - 1) / 365.0 - 0.5
60 |
61 |
class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # month is 1-based; shift to 0-based before scaling by 11
        return (index.month - 1) / 11.0 - 0.5
67 |
68 |
class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # ISO week is 1-based; shift to 0-based before scaling by 52
        return (index.isocalendar().week - 1) / 52.0 - 0.5
74 |
75 |
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.

    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.

    Raises
    ------
    RuntimeError
        If the offset parsed from ``freq_str`` matches none of the supported granularities.
    """
    # Sub-daily granularities share the same base feature stack.
    minutely_features = [MinuteOfHour, HourOfDay, DayOfWeek, DayOfMonth, DayOfYear]
    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: minutely_features,
        offsets.Second: [SecondOfMinute] + minutely_features,
    }

    offset = to_offset(freq_str)
    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)
131 |
132 |
def time_features(dates, freq='h'):
    """Encode *dates* with every feature for *freq*; result shape is (n_features, len(dates))."""
    encoders = time_features_from_frequency_str(freq)
    return np.vstack([encode(dates) for encode in encoders])
135 |
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib.pyplot as plt
4 | import time
5 |
6 | plt.switch_backend('agg')
7 |
8 |
def adjust_learning_rate(optimizer, epoch, args):
    """Update the optimizer's learning rate in place according to ``args.lradj``.

    Schedules:
      'type1'   : halve the base LR every epoch
      'type2'   : fixed milestone table (LR changes only at listed epochs)
      '3'..'6'  : keep the base LR until a cutoff epoch, then decay by 10x

    Bug fix: an unknown ``args.lradj`` previously left ``lr_adjust`` unbound
    and raised UnboundLocalError; it is now a no-op.
    """
    if args.lradj == 'type1':
        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
    elif args.lradj == 'type2':
        lr_adjust = {
            2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
            10: 5e-7, 15: 1e-7, 20: 5e-8
        }
    elif args.lradj in ('3', '4', '5', '6'):
        # Cutoff epoch after which the LR is decayed by a factor of 10.
        cutoff = {'3': 10, '4': 15, '5': 25, '6': 5}[args.lradj]
        lr_adjust = {epoch: args.learning_rate if epoch < cutoff else args.learning_rate * 0.1}
    else:
        lr_adjust = {}  # unknown policy: leave the learning rate untouched
    if epoch in lr_adjust:
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Updating learning rate to {}'.format(lr))
31 |
32 |
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Tracks the best (lowest) validation loss seen so far, saves a model
    checkpoint on every improvement, and raises the ``early_stop`` flag after
    ``patience`` consecutive non-improving epochs.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience      # epochs to tolerate without improvement
        self.verbose = verbose        # print a message on every checkpoint save
        self.counter = 0              # consecutive non-improving epochs
        self.best_score = None        # best -val_loss seen so far
        self.early_stop = False
        # Fix: np.Inf was removed in NumPy 2.0; np.inf is the portable spelling.
        self.val_loss_min = np.inf
        self.delta = delta            # minimum change that counts as improvement

    def __call__(self, val_loss, model, path):
        # Negate so that "higher score is better" throughout.
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Persist the model weights under ``path`` and record the new best loss."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
        self.val_loss_min = val_loss
63 |
64 |
class dotdict(dict):
    """Dictionary whose keys are also readable/writable as attributes.

    Reading a missing attribute returns None (dict.get semantics) rather
    than raising AttributeError; deleting a missing one raises KeyError.
    """

    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
70 |
71 |
class StandardScaler():
    """Affine normalizer with fixed statistics.

    ``transform`` z-scores data with the stored mean/std and
    ``inverse_transform`` undoes it.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        return data * self.std + self.mean
82 |
83 |
def visual(true, preds=None, name='./pic/test.pdf'):
    """
    Results visualization: plot the ground-truth curve (and, if given, the
    prediction curve) and save the figure to *name*.
    """
    series = [(true, 'GroundTruth')]
    if preds is not None:
        series.append((preds, 'Prediction'))
    plt.figure()
    for data, label in series:
        plt.plot(data, label=label, linewidth=2)
    plt.legend()
    plt.show()  # no-op under the 'agg' backend selected at import time
    plt.savefig(name, bbox_inches='tight')
95 |
def test_params_flop(model,x_shape):
    """
    Print the trainable parameter count of *model* and, via ptflops, its MACs
    and parameter totals for an input of shape *x_shape*.

    If you want to test a former's FLOPs you need to give default values to the
    inputs in model.forward(); the following code can only pass one argument to
    forward(). NOTE(review): requires CUDA device 0 and the third-party
    `ptflops` package; the model is moved to the GPU as a side effect.
    """
    model_params = 0
    for parameter in model.parameters():
        model_params += parameter.numel()
    print('INFO: Trainable parameter count: {:.2f}M'.format(model_params / 1000000.0))
    from ptflops import get_model_complexity_info
    with torch.cuda.device(0):
        macs, params = get_model_complexity_info(model.cuda(), x_shape, as_strings=True, print_per_layer_stat=True)
        # print('Flops:' + flops)
        # print('Params:' + params)
        print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
        print('{:<30}  {:<8}'.format('Number of parameters: ', params))
--------------------------------------------------------------------------------