├── .idea
│   ├── .gitignore
│   ├── Meta-Learning4FSTSF.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── configs.py
├── core
│   ├── base_nets.py
│   ├── meta_nets.py
│   ├── options.py
│   ├── task_split.py
│   └── train.py
├── data
│   └── few_shot_data
│       ├── test_data_embedding_10.pkl
│       ├── test_data_embedding_20.pkl
│       ├── test_data_embedding_30.pkl
│       ├── test_data_embedding_40.pkl
│       ├── train_data_embedding_10.pkl
│       ├── train_data_embedding_20.pkl
│       ├── train_data_embedding_30.pkl
│       └── train_data_embedding_40.pkl
├── embedding
│   ├── data_preprocessing.py
│   └── embedding.py
├── main.py
└── tools
    └── tools.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 xf-git
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Learning4FSTSF
2 | Meta-Learning for Few-Shot Time Series Forecasting
3 |
4 | # Usage
5 |
6 | This section of the README walks through how to train the models.
7 |
8 | ## data preparation
9 | > data_preprocessing.py + embedding.py
10 |
11 | **notes**:
12 | The time-series data provided in '/data/few_shot_data/...' have already been processed by this step. For new raw time-series data, these two scripts can be used to perform it (a minimal sketch follows below).
13 |
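A minimal sketch of the preparation step, assuming raw UCR-style dataset folders (each containing `<name>_TRAIN.tsv` and `<name>_TEST.tsv`) live under `./data/UCR`; that path is illustrative, not a fixed project layout:

```python
# Run from the repository root.
from embedding.data_preprocessing import few_shot_data

# Reads <name>_TRAIN.tsv / <name>_TEST.tsv from every dataset folder under the
# given path and serializes train_data.pkl / test_data.pkl into the data directory.
few_shot_data(path='./data/UCR')
```

embedding.py is then used to turn these pickles into the `*_data_embedding_<ppn>.pkl` files under `data/few_shot_data/` (see embedding.py for its interface).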
14 |
15 | ## training of Base_{model}
16 | ### In this phase, each dataset is a time-series task, and each task is trained separately.
17 |
18 | >**main.py**
19 | >>**Arguments help:**
20 |
21 | --baseNet: [mlp/cnn/lstm/cnnConlstm]
22 | --dataset: the directory where the pre-processed time-series data are saved
23 | --update_step_target: number of update steps of the network
24 | --fine_lr: learning rate in this phase
25 | --ppn: the number of points to predict [10/20/30/40]
26 | --device: [cpu/cuda]
27 | --user_id: the name of the task to be trained; it can be found in TRAINING_TASK_SET in ./configs.py
28 |
29 | >**training a single task:**
30 |
31 | ```
32 | python main.py --baseNet [mlp/cnn/lstm/cnnConlstm] --dataset [few_shot_data/your defined data dir] --update_step_target 10 --fine_lr 0.001 --ppn [10/20/30/40] --device [cpu/cuda] --user_id 0001
33 | ```
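For example, a fully specified run (the chosen values are only illustrative): `python main.py --baseNet lstm --dataset few_shot_data --update_step_target 10 --fine_lr 0.001 --ppn 10 --device cpu --user_id Wine`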
34 |
35 | >**training all tasks:**
36 |
37 | ```
38 | python main.py --baseNet [mlp/cnn/lstm/cnnConlstm] --dataset few_shot_data --update_step_target 10 --fine_lr 0.001 --ppn [10/20/30/40] --device [cpu/cuda]
39 | ```
40 |
41 |
42 | ## training Meta_{model}
43 | ### In this phase, one task is selected as the target task and the remaining tasks form the training-task set. The baseNet is first trained on the support sets of the training tasks, the MetaNet is then trained on the query sets of the training tasks, and finally the support set of the target task is used to fine-tune the MetaNet (a minimal programmatic sketch follows the command examples below).
44 |
45 | >**main.py**
46 | >>**Argument help:**
47 |
48 | --maml: use 'maml mode' to train the model
49 | --update_step_train: the number of update steps of the baseNet on the training-task set
50 | --update_step_target: the number of update steps of the MetaNet on the target task
51 | --epoch: number of training iterations
52 | --base_lr: the learning rate of baseNet
53 | --meta_lr: the learning rate of MetaNet
54 | --fine_lr: the learning rate of the MetaNet during fine-tuning
55 |
56 | >**training a single task:**
57 |
58 | ```
59 | python main.py --baseNet [cnn/lstm/cnnConlstm] --maml --dataset few_shot_data --epoch 10 --update_step_train 10 --update_step_target 10 --base_lr 0.01 --meta_lr 0.01 --fine_lr 0.01 --ppn 10 --device [cpu/cuda] --user_id Wine
60 | ```
61 |
62 | >**training all tasks:**
63 |
64 | ```
65 | python main.py --baseNet [cnn/lstm/cnnConlstm] --maml --dataset few_shot_data --epoch 10 --update_step_train 10 --update_step_target 10 --base_lr 0.01 --meta_lr 0.01 --fine_lr 0.01 --ppn 10 --device [cpu/cuda]
66 | ```
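A minimal programmatic sketch of one meta-training step (hedged: the synthetic task below only illustrates tensor shapes and the calling convention; main.py builds real tasks via core.task_split.LoadData):

```python
import numpy as np
import torch
from core.base_nets import BaseCNN
from core.meta_nets import MetaNet

torch.set_default_tensor_type(torch.DoubleTensor)

# One synthetic training task: support/query inputs of shape (set size, 1, 200)
# and targets of shape (set size, ppn); the lists hold one array per task.
ppn = 10
spt_x, qry_x = [np.random.randn(5, 1, 200)], [np.random.randn(5, 1, 200)]
spt_y, qry_y = [np.random.randn(5, ppn)], [np.random.randn(5, ppn)]

meta = MetaNet(baseNet=BaseCNN(output=ppn), update_step_train=10,
               update_step_target=10, meta_lr=0.01, base_lr=0.01, fine_lr=0.01)

# Inner loop: adapt fast weights on the support set; outer loop: update the
# MetaNet parameters from the accumulated query-set loss.
metrics = meta(spt_x, spt_y, qry_x, qry_y, device='cpu')
print(metrics['loss'], metrics['rmse'], metrics['smape'])
```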
67 |
68 | ## results
69 | ### All trained models and evaluation metrics are saved in the directory ./results/
70 |
71 | ## log
72 | ### Useful log information is saved in the directory ./log/
--------------------------------------------------------------------------------
/configs.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022/03/10'
4 |
5 | 'default configuration for this project'
6 |
7 | import os.path as osp
8 |
9 |
10 | BASE_DIR = osp.dirname(osp.abspath(__file__))
11 |
12 | dataName = 'data'
13 | model_save_dir = osp.join(BASE_DIR, 'model')
14 | loss_save_dir = osp.join(BASE_DIR, 'results/loss')
15 | log_save_path = osp.join(BASE_DIR, 'log/logs.txt')
16 | log_save_dir = osp.join(BASE_DIR, 'log')
17 | DATA_DIR = osp.join(BASE_DIR, 'data')
18 | MODEL_PATH = osp.join(BASE_DIR, 'results/model')
19 | console_file = osp.join(BASE_DIR, 'log/logs.txt')
20 | exp_result_dir = osp.join(BASE_DIR, 'results')
21 |
22 | MODEL_NAME = ['mlp', 'lstm', 'cnn', 'cnnConlstm', 'lstm+maml', 'cnn+maml', 'cnnConlstm+maml']
23 | MODE_NAME = ['training', 'testing', 'together']
24 |
25 | few_shot_dataset_name = [
26 | 'Beef',
27 | 'BeetleFly',
28 | 'BirdChicken',
29 | 'Car',
30 | 'Coffee',
31 | 'FaceFour',
32 | 'Herring',
33 | 'Lightning2',
34 | 'Lightning7',
35 | 'Meat',
36 | 'OliveOil',
37 | 'Rock',
38 | 'Wine'
39 | ]
40 |
41 | TRAINING_TASK_SET = [
42 | '0001',
43 | '0002',
44 | '0003',
45 | '0004',
46 | '0005',
47 | '0006',
48 | '0007',
49 | '0008',
50 | '0009',
51 | '0010',
52 | '0011',
53 | '0012',
54 | '0013',
55 | '0014',
56 | '0015',
57 | '0016',
58 | '0022',
59 | '0023',
60 | '0024',
61 | '0025',
62 | '0026',
63 | '0029',
64 | '0030',
65 | '0031',
66 | '0032',
67 | '0037',
68 | '0046',
69 | '0047',
70 | '0048',
71 | '0049',
72 | '0050',
73 | '0051',
74 | '0054',
75 | '0055',
76 | '0056',
77 | '0066',
78 | '0069',
79 | '0070',
80 | '0071',
81 | '0082',
82 | '0085',
83 | '0088',
84 | '0089',
85 | '0090',
86 | '0091',
87 | '0092',
88 | '0093',
89 | '0094',
90 | '0095',
91 | '0096',
92 | '0097',
93 | '0098',
94 | '0099',
95 | '0100',
96 | '0102',
97 | '0103',
98 | '0104',
99 | '0106',
100 | '0107',
101 | '0108',
102 | '0110',
103 | '0111',
104 | '0112',
105 | '0113',
106 | '0114',
107 | '0115',
108 | '0116',
109 | '0118',
110 | '0119',
111 | '0120',
112 | '0121',
113 | '0122',
114 | '0123',
115 | '0124',
116 | '0125',
117 | '0126',
118 | '0127',
119 | '0128',
120 | '0129',
121 | '0130',
122 | '0131',
123 | '0132',
124 | '0133',
125 | '0134',
126 | '0135',
127 | '0136',
128 | '0137',
129 | '0138',
130 | '0139',
131 | '0140',
132 | '0141',
133 | '0142',
134 | '0143',
135 | '0144',
136 | '0145',
137 | '0146',
138 | '0147',
139 | '0148',
140 | '0149',
141 | '0150',
142 | '0151',
143 | '0152',
144 | '0153',
145 | '0154',
146 | '0155',
147 | '0156',
148 | 'Beef',
149 | 'BeetleFly',
150 | 'BirdChicken',
151 | 'Car',
152 | 'Coffee',
153 | 'FaceFour',
154 | 'Herring',
155 | 'Lightning2',
156 | 'Lightning7',
157 | 'Meat',
158 | 'OliveOil',
159 | 'Rock',
160 | 'Wine'
161 | ]
162 |
163 | if __name__ == '__main__':
164 |
165 | print(BASE_DIR)
166 | pass
167 |
--------------------------------------------------------------------------------
/core/base_nets.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022/03/10'
4 |
5 | 'base networks'
6 |
7 | # built-in library
8 | import os
9 | import math
10 | import sys
11 | from copy import deepcopy
12 |
13 | # third-party library
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import functional as F
17 |
18 | # self-defined library
19 | from tools.tools import metrics as Metrics
20 |
21 | torch.set_default_tensor_type(torch.DoubleTensor)
22 |
23 |
24 | class MLP(nn.Module):
25 |
26 | def __init__(self, n_input, n_hidden, n_output):
27 | super(MLP, self).__init__()
28 | self.name = 'MLP'
29 |
30 | self.hidden_size = n_hidden
31 | # this list contains all tensor needed to be optimized
32 | self.params = nn.ParameterList()
33 |
34 | # linear input layer
35 | weight = nn.Parameter(torch.ones(n_hidden, n_input))
36 | bias = nn.Parameter(torch.zeros(n_hidden))
37 |
38 | self.params.extend([weight, bias])
39 |
40 | # linear output layer
41 | weight = nn.Parameter(torch.ones(n_output, n_hidden))
42 | bias = nn.Parameter(torch.zeros(n_output))
43 |
44 | self.params.extend([weight, bias])
45 |
46 | self.init()
47 |
48 | def parameters(self):
49 | return self.params
50 |
51 | def init(self):
52 | stdv = 1.0 / math.sqrt(self.hidden_size)
53 | for weight in self.parameters():
54 | weight.data.uniform_(-stdv, stdv)
55 |
56 | def forward(self, x, vars=None):
57 |
58 | if vars is None:
59 | params = self.params
60 | else:
61 | params = vars
62 |
63 | # input layer
64 | (weight_input, bias_input) = (params[0].to(x.device), params[1].to(x.device))
65 | x = F.linear(x, weight_input, bias_input)
66 |
67 | # output layer
68 | (weight_output, bias_output) = (params[2].to(x.device), params[3].to(x.device))
69 | out = F.linear(x, weight_output, bias_output)
70 |
71 | return out
72 |
73 |
74 | class BaseCNN(nn.Module):
75 |
76 | def __init__(self, output=10):
77 | super(BaseCNN, self).__init__()
78 | self.name = 'BASECNN'
79 | self.output = output
80 |
81 | # this list contains all tensor needed to be optimized
82 | self.vars = nn.ParameterList()
83 |
84 | # running_mean and running var
85 | self.vars_bn = nn.ParameterList()
86 |
87 | # populate the parameters of the network that need to be trained
88 |
89 | # Conv1d layer
90 | # [channel_out, channel_in, kernel-size]
91 | weight = nn.Parameter(torch.ones(64, 1, 3))
92 |
93 | nn.init.kaiming_normal_(weight)
94 |
95 | bias = nn.Parameter(torch.zeros(64))
96 |
97 | self.vars.extend([weight, bias])
98 |
99 | # linear layer
100 | weight = nn.Parameter(torch.ones(self.output, 64*100))
101 | bias = nn.Parameter(torch.zeros(self.output))
102 |
103 | self.vars.extend([weight, bias])
104 |
105 | def forward(self, x, vars=None, bn_training=True):
106 |
107 | '''
108 |
109 | :param x: [batch size, 1, sequence length], e.g. (batch size, 1, 200)
110 | :param vars:
111 | :param bn_training: set false to not update
112 | :return:
113 | '''
114 |
115 | if vars is None:
116 | vars = self.vars
117 |
118 | # x = x.squeeze(dim=2)
119 | # x = x.unsqueeze(dim=1)
120 | # Conv1d layer
121 | weight, bias = vars[0].to(x.device), vars[1].to(x.device)
122 | # x ==> (batch size, 1, 200)
123 |
124 | x = F.conv1d(x, weight, bias, stride=1, padding=1) # ==>(batch size, 64, 200)
125 | x = F.relu(x, inplace=True) # ==> (batch_size, 64, 200)
126 | x = F.max_pool1d(x, kernel_size=2) # ==> (batch_size, 64, 100)
127 |
128 | # linear layer
129 | x = x.view(x.size(0), -1) # flatten ==> (batch_size, 64*100)
130 | weight, bias = vars[-2].to(x.device), vars[-1].to(x.device)
131 | x = F.linear(x, weight, bias)
132 |
133 | return x
134 |
135 | def parameters(self):
136 | return self.vars
137 |
138 | def zero_grad(self):
139 | pass
140 |
141 | pass
142 |
143 |
144 | class BaseLSTM(nn.Module):
145 |
146 | def __init__(self, n_features, n_hidden, n_output, n_layer=1):
147 | super().__init__()
148 | self.name = 'BaseLSTM'
149 |
150 | # this list contains all tensor needed to be optimized
151 | self.params = nn.ParameterList()
152 |
153 | self.input_size = n_features
154 | # print(n_features)
155 | self.hidden_size = n_hidden
156 | self.output_size = n_output
157 | self.layer_size = n_layer
158 |
159 | # input layer (input-to-hidden weights for the 4 gates)
160 | W_i = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.input_size))
161 | bias_i = nn.Parameter(torch.Tensor(self.hidden_size * 4))
162 | self.params.extend([W_i, bias_i])
163 |
164 | # hidden layer (hidden-to-hidden weights for the 4 gates)
165 | W_h = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.hidden_size))
166 | bias_h = nn.Parameter(torch.Tensor(self.hidden_size * 4))
167 | self.params.extend([W_h, bias_h])
168 |
169 | if self.layer_size > 1:
170 | for _ in range(self.layer_size - 1):
171 |
172 | # i-th lstm layer
173 | # input layer
174 | W_i = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.hidden_size))
175 | bias_i = nn.Parameter(torch.Tensor(self.hidden_size * 4))
176 | self.params.extend([W_i, bias_i])
177 | # hidden layer
178 | W_h = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.hidden_size))
179 | bias_h = nn.Parameter(torch.Tensor(self.hidden_size * 4))
180 | self.params.extend([W_h, bias_h])
181 |
182 |
183 | # output layer (linear projection)
184 | W_linear = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size))
185 | bias_linear = nn.Parameter(torch.Tensor(self.output_size))
186 | self.params.extend([W_linear, bias_linear])
187 |
188 | self.init()
189 | pass
190 |
191 | def parameters(self):
192 | return self.params
193 |
194 | def init(self):
195 | stdv = 1.0 / math.sqrt(self.hidden_size)
196 | for weight in self.parameters():
197 | weight.data.uniform_(-stdv, stdv)
198 |
199 | def forward(self, x, vars=None, init_state=None):
200 |
201 | if vars is None:
202 | params = self.params
203 | else:
204 | params = vars
205 |
206 | # assume the shape of x is (batch_size, time_size, feature_size)
207 |
208 | batch_size, time_size, _ = x.size()
209 | hidden_seq = []
210 | # with torch.autograd.set_detect_anomaly(True):
211 | if init_state is None:
212 | h_t, c_t = (
213 | torch.zeros(batch_size, self.hidden_size).to(x.device),
214 | torch.zeros(batch_size, self.hidden_size).to(x.device)
215 | )
216 | else:
217 | h_t, c_t = init_state
218 |
219 | HS = self.hidden_size
220 |
221 | for t in range(time_size):
222 | x_t = x[:, t, :]
223 | W_i, bias_i = (params[0].to(x.device), params[1].to(x.device))
224 | W_h, bias_h = (params[2].to(x.device), params[3].to(x.device))
225 |
226 | # gates = x_t @ W_i + h_t @ W_h + bias_h + bias_i
227 | gates = F.linear(x_t, W_i, bias_i) + F.linear(h_t, W_h, bias_h)
228 |
229 | i_t, f_t, g_t, o_t = (
230 | torch.sigmoid(gates[:, :HS]), # input
231 | torch.sigmoid(gates[:, HS:HS * 2]), # forget
232 | torch.tanh(gates[:, HS * 2:HS * 3]),
233 | torch.sigmoid(gates[:, HS * 3:]) # output
234 | )
235 | c_t = f_t * c_t + i_t * g_t
236 | h_t = o_t * torch.tanh(c_t)
237 | hidden_seq.append(h_t)
238 |
239 | W_linear, bias_linear = (params[-2].to(x.device), params[-1].to(x.device))
240 | out = F.linear(hidden_seq[-1], W_linear, bias_linear)
241 | # out = hidden_seq[-1] @ W_linear + bias_linear
242 | return out
243 |
244 |
245 | class BaseCNNConLSTM(nn.Module):
246 |
247 | def __init__(self, n_features, n_hidden, n_output, n_layer=1, time_size=1, cnn_feature=200):
248 | super(BaseCNNConLSTM, self).__init__()
249 | self.name = 'BaseCNNConLSTM'
250 | self.time_size = time_size
251 |
252 | # this list contain all tensor needed to be optimized
253 | self.params = nn.ParameterList()
254 | self.cnn = BaseCNN(output=cnn_feature)
255 | self.lstm = BaseLSTM(n_features=n_features, n_hidden=n_hidden, n_output=n_output, n_layer=n_layer)
256 | self.cnn_tensor_num = 0
257 | self.lstm_tensor_num = 0
258 | self.init()
259 |
260 | def init(self):
261 |
262 | self.cnn_tensor_num = len(self.cnn.parameters())
263 | self.lstm_tensor_num = len(self.lstm.parameters())
264 | for param in self.cnn.parameters():
265 | self.params.append(param)
266 | for param in self.lstm.parameters():
267 | self.params.append(param)
268 |
269 | def sequence(self, data):
270 |
271 | dim_1, dim_2 = data.shape
272 | new_dim_1 = dim_1 - self.time_size + 1
273 |
274 | x = torch.zeros((new_dim_1, self.time_size, dim_2))
275 |
276 | for i in range(dim_1 - self.time_size + 1):
277 | x[i] = data[i: i + self.time_size]
278 | return x.to(data.device)
279 |
280 | def forward(self, x, vars=None, init_states=None):
281 |
282 | if vars is None:
283 | params = self.params
284 | else:
285 | params = vars
286 |
287 | x = self.cnn(x, params[: self.cnn_tensor_num])
288 | x = x.unsqueeze(dim=1)
289 | output = self.lstm(x, params[self.cnn_tensor_num:], init_states)
290 | return output
291 | pass
292 | pass
293 |
--------------------------------------------------------------------------------
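A quick sanity check of the expected input and output shapes for the base networks above (an illustrative sketch, not part of the repository; run it from the repository root):

```python
import torch
from core.base_nets import BaseCNN, BaseLSTM

# base_nets sets the default tensor type to DoubleTensor on import,
# so double-precision random inputs match the parameter dtype.
cnn = BaseCNN(output=10)
x = torch.randn(4, 1, 200)   # the linear layer assumes 64 * 100 features, so the length must be 200
print(cnn(x).shape)          # torch.Size([4, 10])

lstm = BaseLSTM(n_features=200, n_hidden=100, n_output=10)
x = torch.randn(4, 1, 200)   # (batch_size, time_size, feature_size)
print(lstm(x).shape)         # torch.Size([4, 10])
```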
/core/meta_nets.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022/03/10'
4 |
5 | 'meta networks'
6 |
7 | # built-in library
8 | import os.path as osp
9 | from copy import deepcopy
10 |
11 | # third-party library
12 | import torch
13 | import torch.nn as nn
14 | from torch.nn import functional as F
15 |
16 | # self-defined library
17 | from tools.tools import metrics as Metrics, generate_filename
18 | from configs import MODEL_PATH
19 |
20 | torch.set_default_tensor_type(torch.DoubleTensor)
21 |
22 |
23 | class MetaNet(nn.Module):
24 |
25 | def __init__(self, baseNet=None, update_step_train=10, update_step_target=20, meta_lr=0.001, base_lr=0.01, fine_lr=0.01):
26 |
27 | super(MetaNet, self).__init__()
28 | self.update_step_train = update_step_train
29 | self.update_step_target = update_step_target
30 | self.meta_lr = meta_lr
31 | self.base_lr = base_lr
32 | self.fine_tune_lr = fine_lr
33 |
34 | if baseNet is not None:
35 | self.net = baseNet
36 | else:
37 | raise Exception('baseNet is None')
38 | self.meta_optim = torch.optim.Adam(self.net.parameters(), lr=self.meta_lr)
39 | # self.meta_optim = torch.optim.SGD(self.net.parameters(), lr=self.meta_lr)
40 | pass
41 |
42 | def save_model(self, model_name='model'):
43 | torch.save(self.net, osp.join(MODEL_PATH, generate_filename('pth',*[model_name,])))
44 |
45 | def forward(self, spt_x, spt_y, qry_x, qry_y, device='cpu'):
46 | '''
47 |
48 | :param spt_x: if baseNet is cnn: [ spt size, in_channel, height, width], lstm [spt_size, time_size, feature_size]
49 | :param spt_y: [ spt size]
50 | :param qry_x: if baseNet is cnn: [ qry size, in_channel, height, width], lstm [qry size, time_size, feature_size]
51 | :param qry_y: [ qry size]
52 | :param min_max_data_path: path where the min/max values used to de-normalize the data are stored
53 | :return:
54 | the batch size is set to 1 in this task, i.e., one task is sampled for training at a time
55 | '''
56 |
57 | # spt_size, channel, height, width = spt_x.size()
58 | # qry_size = spt_y.size(0)
59 | task_num = len(spt_x)
60 | loss_list_qry = []
61 | mape_list = []
62 | rmse_list = []
63 | smape_list = []
64 | qry_loss_sum = 0
65 | # print('updating task networks ===============================================')
66 | # step-0 update
67 | for i in range(task_num):
68 | x_spt = torch.from_numpy(spt_x[i]).to(device)
69 | y_spt = torch.from_numpy(spt_y[i]).to(device)
70 | x_qry = torch.from_numpy(qry_x[i]).to(device)
71 | y_qry = torch.from_numpy(qry_y[i]).to(device)
72 |
73 | y_hat = self.net(x_spt, vars=None)
74 | loss = F.mse_loss(y_hat, y_spt)
75 | grad = torch.autograd.grad(loss, self.net.parameters())
76 | grads_params = zip(grad, self.net.parameters()) # pair each gradient with its corresponding parameter
77 |
78 | # fast_weights: this step computes theta - alpha * grad(L)
79 | fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], grads_params))
80 |
81 | # evaluate on the query set and compute metrics
82 | # test on the query set with the updated parameters
83 | with torch.no_grad():
84 | y_hat = self.net(x_qry, fast_weights)
85 | loss_qry = F.mse_loss(y_hat, y_qry)
86 | loss_list_qry.append(loss_qry)
87 |
88 | # compute evaluation metrics
89 | rmse, mape, smape = Metrics(y_qry, y_hat)
90 |
91 | rmse_list.append(rmse)
92 | mape_list.append(mape)
93 | smape_list.append(smape)
94 |
95 | for step in range(1, self.update_step_train):
96 | y_hat = self.net(x_spt, fast_weights)
97 | loss = F.mse_loss(y_hat, y_spt)
98 | grad = torch.autograd.grad(loss, fast_weights)
99 | grads_params = zip(grad, fast_weights)
100 | fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], grads_params))
101 |
102 | if step < self.update_step_train - 1:
103 | with torch.no_grad():
104 | y_hat = self.net(x_qry, fast_weights)
105 | loss_qry = F.mse_loss(y_hat, y_qry)
106 | loss_list_qry.append(loss_qry)
107 | else:
108 | y_hat = self.net(x_qry, fast_weights)
109 | loss_qry = F.mse_loss(y_hat, y_qry)
110 | loss_list_qry.append(loss_qry)
111 | qry_loss_sum += loss_qry
112 |
113 | with torch.no_grad():
114 | rmse, mape, smape = Metrics(y_qry, y_hat)
115 |
116 | rmse_list.append(rmse)
117 | mape_list.append(mape)
118 | smape_list.append(smape)
119 | pass
120 |
121 | # update the meta network
122 | loss_qry = qry_loss_sum / task_num # the learner's loss on the current tasks' query sets after update_step steps
123 | self.meta_optim.zero_grad() # zero the gradients
124 | loss_qry.backward()
125 | self.meta_optim.step()
126 |
127 |
128 | return {
129 | 'loss': loss_list_qry[-1].item(),
130 | 'rmse': rmse_list[-1],
131 | 'mape': mape_list[-1],
132 | 'smape': smape_list[-1]
133 | }
134 |
135 | def fine_tuning(self, spt_x, spt_y, qry_x, qry_y, naive=False):
136 |
137 | '''
138 |
139 | :param spt_x: if baseNet is cnn:[set size, channel, height, width] if baseNet is lstm: [batch_size, seq_size, feature_size]
140 | :param spt_y:
141 | :param qry_x:
142 | :param qry_y:
143 | :return:
144 | '''
145 |
146 | # evaluation metrics
147 | loss_qry_list = []
148 | rmse_list = []
149 | mape_list = []
150 | smape_list = []
151 | min_loss = 0
152 | best_epoch = 0
153 | min_train_loss = 1000000
154 | loss_set = {
155 | 'train_loss': [],
156 | 'validation_loss': []
157 | }
158 |
159 | # new_net = deepcopy(self.net)
160 | # new_net = self.net
161 | y_hat = self.net(spt_x)
162 | # with torch.autograd.set_detect_anomaly(True):
163 | loss = F.mse_loss(y_hat, spt_y)
164 | loss_set['train_loss'].append(loss.item())
165 | if loss.item() < min_train_loss:
166 | min_train_loss = loss.item()
167 | grad = torch.autograd.grad(loss, self.net.parameters())
168 | grads_params = zip(grad, self.net.parameters())
169 | fast_weights = list(map(lambda p: p[1] - self.fine_tune_lr * p[0], grads_params))
170 |
171 | # evaluate on the query set and compute metrics
172 | # test with the updated parameters
173 | with torch.no_grad():
174 | y_hat = self.net(qry_x, fast_weights)
175 | loss_qry = F.mse_loss(y_hat, qry_y)
176 | loss_set['validation_loss'].append(loss_qry.item())
177 | loss_qry_list.append(loss_qry)
178 | # compute the evaluation metrics (rmse, mape, smape)
179 | rmse, mape, smape = Metrics(qry_y, y_hat)
180 |
181 | rmse_list.append(rmse)
182 | mape_list.append(mape)
183 | smape_list.append(smape)
184 | min_rmse = rmse
185 | min_mape = mape
186 | min_smape = smape
187 | min_loss = loss_qry.item()
188 | rmse_best_epoch = 1
189 | mape_best_epoch = 1
190 | smape_best_epoch = 1
191 |
192 | if naive:
193 | print(' Epoch [1] | train_loss: %.4f | test_loss: %.4f | rmse: %.4f | mape: %.4f | smape: %.4f |'
194 | % (loss.item(), loss_qry.item(), rmse, mape, smape))
195 |
196 | for step in range(1, self.update_step_target):
197 | y_hat = self.net(spt_x, fast_weights)
198 | loss = F.mse_loss(y_hat, spt_y)
199 | loss_set['train_loss'].append(loss.item())
200 | if loss.item() < min_train_loss:
201 | min_train_loss = loss.item()
202 | grad = torch.autograd.grad(loss, fast_weights)
203 | grads_params = zip(grad, fast_weights)
204 | fast_weights = list(map(lambda p: p[1] - self.fine_tune_lr * p[0], grads_params))
205 |
206 | # evaluate on the query set
207 | with torch.no_grad():
208 | # 计算评价指标
209 | y_hat = self.net(qry_x, fast_weights)
210 | loss_qry = F.mse_loss(y_hat, qry_y)
211 | loss_set['validation_loss'].append(loss_qry.item())
212 | loss_qry_list.append(loss_qry)
213 |
214 | rmse, mape, smape = Metrics(qry_y, y_hat)
215 |
216 | rmse_list.append(rmse)
217 | mape_list.append(mape)
218 | smape_list.append(smape)
219 | if min_rmse > rmse:
220 | min_rmse = rmse
221 | rmse_best_epoch = step + 1
222 | if min_smape > smape:
223 | min_smape = smape
224 | smape_best_epoch = step + 1
225 | min_rmse = rmse
226 | self.save_model(model_name=self.net.name)
227 | print(' Epoch [%d] | train_loss: %.4f | test_loss: %.4f | rmse: %.4f | smape: %.4f |'
228 | % (step + 1, loss.item(), loss_qry.item(), rmse, smape))
229 |
230 | return {
231 | 'test_loss': min_loss,
232 | 'train_loss': min_train_loss,
233 | 'rmse': min_rmse,
234 | 'mape': min_mape,
235 | 'smape': min_smape,
236 | 'rmse_best_epoch': rmse_best_epoch,
237 | 'mape_best_epoch': mape_best_epoch,
238 | 'smape_best_epoch': smape_best_epoch,
239 | 'loss_set': loss_set
240 | }
241 | pass
--------------------------------------------------------------------------------
/core/options.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022/03/10'
4 |
5 | 'options'
6 |
7 | # built-in library
8 | import os
9 | import os.path as osp
10 | import argparse
11 | from builtins import print as b_print
12 |
13 | # self-defined library
14 | from configs import model_save_dir, loss_save_dir, log_save_path, log_save_dir, MODEL_NAME, MODE_NAME, exp_result_dir
15 | from tools.tools import generate_filename
16 |
17 |
18 | def task_id_int2str(int_id):
19 |
20 | if int_id < 10:
21 | str_id = '000' + str(int_id)
22 | elif int_id < 100:
23 | str_id = '00' + str(int_id)
24 | elif int_id < 1000:
25 | str_id = '0' + str(int_id)
26 | else:
27 | str_id = str(int_id)
28 |
29 | return str_id
30 |
31 |
32 | def print(*args, file='./log.txt', end='\n', terminate=True):
33 |
34 | with open(file=file, mode='a', encoding='utf-8') as console:
35 | b_print(*args, file=console, end=end)
36 | if terminate:
37 | b_print(*args, end=end)
38 |
39 |
40 | def parse_args(script='main'):
41 |
42 | parser = argparse.ArgumentParser(description='Time Series Forecasting script %s.py' % script)
43 |
44 | # training arguments
45 | parser.add_argument('--model', default=None,
46 | help='the model name is used to train.[mlp, lstm, cnn, cnnConlstm, lstm+maml, cnn+maml, cnnConlstm+maml]')
47 | parser.add_argument('--epoch', type=int, default=1, help='the iteration number for training data.')
48 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate.')
49 |
50 | # data arguments
51 | parser.add_argument('--dataset', default='few_shot_data', help='the data path for training and testing.')
52 | parser.add_argument('--ratio', type=float, default=0.8, help='the ratio of training set for all data set')
53 | parser.add_argument('--trainSet', default='', help='the path of the training data.')
54 | parser.add_argument('--testSet', default='', help='the path of the testing data.')
55 | parser.add_argument('--UCR', action='store_true', default=False, help='for UCR data.')
56 | parser.add_argument('--ppn', type=int, default=10, help='predict point num.')
57 |
58 | # save-path arguments
59 | parser.add_argument('--msd', default=model_save_dir, help='the model save dir.')
60 | parser.add_argument('--lsd', default=loss_save_dir, help='the loss save dir.')
61 | parser.add_argument('--log', default=log_save_path, help='the log save path.')
62 | parser.add_argument('--maml_log', default=log_save_path, help='the log save path.')
63 | parser.add_argument('--rmse_path', default=log_save_path, help='the log save metric rmse.')
64 | parser.add_argument('--mape_path', default=log_save_path, help='the log save metric mape.')
65 | parser.add_argument('--smape_path', default=log_save_path, help='the log save metric smape.')
66 | # the arguments for LSTM model
67 | parser.add_argument('--time_size', type=int, default=1, help='the time_size for lstm input')
68 |
69 | # the arguments for cnn model
70 | parser.add_argument('--add_dim_pos', type=int, default=1, help='the position for add dimension when change sequence data to img data')
71 |
72 | # the implement mode
73 | parser.add_argument('--mode', default='together',
74 | help='the implement mode for script, [training, testing, together] can be chosen')
75 |
76 | # for testing mode
77 | parser.add_argument('--model_state', default='', help='the path of trained model')
78 |
79 | # for maml
80 | parser.add_argument('--user_id', type=str, default='none', help='the id of true target task')
81 | parser.add_argument('--update_step_train', type=int, default=5, help='the train task update step')
82 | parser.add_argument('--update_step_target', type=int, default=50, help='the target task update step')
83 | parser.add_argument('--meta_lr', type=float, default=1e-4, help='the learning rate of meta network')
84 | parser.add_argument('--base_lr', type=float, default=1e-3, help='the learning rate of base network')
85 | parser.add_argument('--fine_lr', type=float, default=0.03, help='the learning rate of fine tune target network')
86 | parser.add_argument('--baseNet', default='cnn', help='the base network for maml training, [lstm, cnn, cnnConlstm] can be chosen')
87 | parser.add_argument('--maml', action='store_true', default=False, help='whether using maml algorithm to train a network.')
88 | parser.add_argument('--all_data_dir', default=None, help='the directory for all load data.')
89 | parser.add_argument('--begin_task', type=int, default=1, help='the beginning task id used for batch training with maml')
90 | parser.add_argument('--end_task', type=int, default=12, help='the ending task id used for batch training with maml.')
91 | parser.add_argument('--batch_task_num', type=int, default=5, help='batch training for maml')
92 | # for hardware setting
93 | parser.add_argument('--device', default='cuda', help='the calculate device for torch Tensor, [cpu, cuda] can be chosen')
94 |
95 | # for new settings
96 | parser.add_argument('--new_settings', action='store_true', default=False, help='training scheme in new settings.')
97 | parser.add_argument('--ft_step', type=int, default=100, help='epoch number of fine-tuning in new settings')
98 |
99 | params = parser.parse_args()
100 |
101 | # maml log
102 | if params.maml:
103 | maml_log_path = osp.join(log_save_dir, generate_filename('.txt', *['log'], timestamp=True))
104 | params.maml_log = maml_log_path
105 | params.log = maml_log_path
106 | params_show(params)
107 |
108 | # dynamically generate the log file
109 | log_path = osp.join('./log', generate_filename('.txt', *['log'], timestamp=True))
110 | params.log = log_path
111 | params_show(params)
112 |
113 | # generate the experiment result logs
114 | if params.new_settings:
115 | result_dir_name = params.baseNet + '_' + str(params.ppn) + 'new_settings'
116 | else:
117 | result_dir_name = params.baseNet + '_' + str(params.ppn)
118 | if params.maml:
119 | result_dir_name = 'M_' + result_dir_name
120 |
121 | result_dir = osp.join(exp_result_dir, result_dir_name)
122 | if not osp.exists(result_dir):
123 | os.mkdir(result_dir)
124 | params.rmse_path = osp.join(result_dir, generate_filename('.txt', *['rmse'], timestamp=True))
125 | params.mape_path = osp.join(result_dir, generate_filename('.txt', *['mape'], timestamp=True))
126 | params.smape_path = osp.join(result_dir, generate_filename('.txt', *['smape'], timestamp=True))
127 |
128 | # argument validity checks
129 |
130 | if params.mode == 'together':
131 | assert float(0) < params.ratio < float(1) # check the dataset split ratio
132 | elif params.mode == 'training':
133 | assert osp.exists(params.trainSet)
134 | elif params.mode == 'testing':
135 | assert osp.exists(params.testSet) and osp.exists(params.model_state)
136 | else:
137 | raise Exception('Unknown implement mode: %s' % params.mode)
138 |
139 | if params.baseNet in MODEL_NAME:
140 | assert params.epoch > 0 and isinstance(params.epoch, int) # check epoch
141 | # if params.model[:4] == 'lstm' and params.model[-4:] == 'lstm': # if model is 'lstm', the time_size id needed parameters
142 | # assert params.time_size > 0 and isinstance(params.time_size, int)
143 | else:
144 | raise Exception('Unknown model name: %s' % params.baseNet)
145 |
146 | return params
147 |
148 |
149 | def params_show(params):
150 |
151 | if params:
152 | print('Parameters Show', file=params.log)
153 | print('=======================================', file=params.log)
154 | print('About model:', file=params.log)
155 | print(' model: %s' % params.baseNet, file=params.log)
156 | # print(' epoch: %s' % params.epoch, file=params.log)
157 |
158 | # print(' learning rate: %s' % str(params.lr), file=params.log)
159 | # if params.model == 'lstm':
160 | # print(' time size: %d' % params.time_size, file=params.log)
161 | # if params.mode == 'testing':
162 | # print(' trained model path: %s' % params.model_state, file=params.log)
163 | print('About data:', file=params.log)
164 | print(' data file: %s' % params.dataset, file=params.log)
165 | # print(' training data file: %s' % params.trainSet, file=params.log)
166 | # print(' testing data file: %s' % params.testSet, file=params.log)
167 | # print(' data split rate: %s' % params.ratio, file=params.log)
168 | print(' predict point num: %d' % params.ppn, file=params.log)
169 | # print('implement mode: %s' % params.mode, file=params.log)
170 |
171 | if params.maml:
172 | print('=======================================', file=params.log)
173 | print('MAML Show', file=params.log)
174 | print('=======================================', file=params.log)
175 | if params.user_id != '0':
176 | #target_task = task_id_int2str(params.user_id)
177 | target_task = params.user_id
178 | print(' target task: %s' % target_task, file=params.log)
179 | else:
180 | begin_task = task_id_int2str(params.begin_task)
181 | end_task = task_id_int2str(params.end_task)
182 | print(' begin task: %s' % begin_task, file=params.log)
183 | print(' end task: %s' % end_task, file=params.log)
184 | print(' update step train: %d' % params.update_step_train, file=params.log)
185 | print(' update step target: %d' % params.update_step_target, file=params.log)
186 | print(' meta lr: %.4f' % params.meta_lr, file=params.log)
187 | print(' base lr: %.4f' % params.base_lr, file=params.log)
188 | print(' fine lr: %.4f' % params.fine_lr, file=params.log)
189 | print(' device: %s' % params.device, file=params.log)
190 | print('=======================================', file=params.log)
191 | else:
192 | raise Exception('params is None!', file=params.log)
193 | pass
194 |
195 |
196 |
197 |
198 |
199 | if __name__ == '__main__':
200 | pass
201 |
202 |
203 |
204 |
205 |
206 |
--------------------------------------------------------------------------------
/core/task_split.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022/03/10'
4 |
5 | 'task split'
6 |
7 | # built-in library
8 | import os.path as osp
9 |
10 | # third-party library
11 | from torch import from_numpy
12 | import numpy as np
13 |
14 | # self-defined tools
15 | from configs import DATA_DIR as data_dir
16 | from tools.tools import obj_unserialization
17 | from configs import TRAINING_TASK_SET
18 |
19 |
20 | class LoadData:
21 |
22 | def __init__(self, maml, test_user_index, add_dim_pos, data_path=None, ppn=None):
23 |
24 | '''
25 | :param maml: whether to use the maml algorithm for training
26 | :param test_user_index: one user (task) is chosen for testing; the other users are used to train the meta-network parameters
27 | :param add_dim_pos: position at which a dimension is added to the tensor; when add_dim_pos=-1, no dimension is added
28 | :param data_path: path of the raw data
29 | ppn: predict point num
30 | '''
31 |
32 | if data_path:
33 | self.indexes = {'train': int(0), 'test': int(0)}
34 | self.task_id = {'train': [], 'test': []}
35 | self.features = None
36 | self.task_num = None
37 | self.datasets_cache = None
38 | self.output = 94
39 | self.outputs = []
40 | if maml:
41 | self.datasets_cache = self.UCR_data_cache_maml(data_path, test_user_index, add_dim_pos, ppn)
42 | else:
43 | self.UCR_data_cache(data_path, add_dim_pos, ppn)
44 | else:
45 | raise Exception('data dir is None!')
46 | pass
47 |
48 | def UCR_data_cache(self, data_path, add_dim_pos, ppn):
49 |
50 | train_data_path = osp.join(data_path, 'train_data_embedding_%s.pkl' % str(ppn))
51 | test_data_path = osp.join(data_path, 'test_data_embedding_%s.pkl' % str(ppn))
52 | self.datasets_cache = {'train': [], 'test': []}
53 |
54 | train_data = obj_unserialization(train_data_path)
55 | test_data = obj_unserialization(test_data_path)
56 | for key in train_data.keys():
57 | train_x, train_y, _ = self.split_x_y(train_data[key], add_dim_pos, ppn=ppn)
58 | test_x, test_y, _ = self.split_x_y(test_data[key], add_dim_pos, ppn=ppn)
59 | self.datasets_cache['train'].append([train_x, train_y])
60 | self.datasets_cache['test'].append([test_x, test_y])
61 | self.task_id['train'].append(key)
62 | self.task_id['test'].append(key)
63 |
64 | self.task_num = len(self.datasets_cache['train'])
65 |
66 | pass
67 |
68 | def UCR_data_cache_maml(self, data_path, test_task_index, add_dim_pos, ppn):
69 |
70 | datasets_cache = {'train': [], 'test': []}
71 | train_data_path = osp.join(data_path, 'train_data_embedding_%s.pkl' % str(ppn))
72 | test_data_path = osp.join(data_path, 'test_data_embedding_%s.pkl' % str(ppn))
73 |
74 | train_data = obj_unserialization(train_data_path)
75 | test_data = obj_unserialization(test_data_path)
76 |
77 | # print(test_task_index)
78 | for i, data_name in enumerate(TRAINING_TASK_SET):
79 | if test_task_index == data_name:
80 | test_task_index = i + 1
81 | # print(test_task_index)
82 | for i, key in enumerate(train_data.keys()):
83 | if i == test_task_index - 1:
84 | test_spt_x, test_spt_y, _ = self.split_x_y(train_data[key], add_dim_pos, ppn=ppn)
85 | test_qry_x, test_qry_y, _ = self.split_x_y(test_data[key], add_dim_pos, ppn=ppn)
86 | datasets_cache['test'].append([test_spt_x, test_spt_y, test_qry_x, test_qry_y])
87 | self.task_id['test'].append(key)
88 | continue
89 |
90 | train_spt_x, train_spt_y, _ = self.split_x_y(train_data[key], add_dim_pos, ppn=ppn)
91 | train_qry_x, train_qry_y, _ = self.split_x_y(test_data[key], add_dim_pos, ppn=ppn)
92 | datasets_cache['train'].append([train_spt_x, train_spt_y, train_qry_x, train_qry_y])
93 | self.task_id['train'].append(key)
94 |
95 | self.task_num = len(datasets_cache['train'])
96 |
97 | return datasets_cache
98 |
99 | def get_data(self, task_id=None):
100 |
101 | if task_id in self.task_id['train']:
102 |
103 | # task_id = task_id_int2str(task_id)
104 | pos_train = self.task_id['train'].index(task_id)
105 | pos_test = self.task_id['test'].index(task_id)
106 | return self.datasets_cache['train'][pos_train], self.datasets_cache['test'][pos_test], task_id
107 | else:
108 | raise Exception('Unknown the task id [%s]!' % task_id)
109 | pass
110 |
111 | def split_spt_qry(self, data, rate):
112 |
113 | spt = []
114 | qry = []
115 | for task in data:
116 | pos = int(len(task) * rate)
117 | spt.append(task[:pos])
118 | qry.append(task[pos:])
119 | # print('split train & val:')
120 | # print('train: %d, %d' % (len(spt[0]), len(spt[0][0])))
121 | # print('val: %d, %d' % (len(qry[0]), len(qry[0][0])))
122 | return spt, qry
123 |
124 | def split_x_y(self, data, add_dim_pos, ppn=10):
125 |
126 | forecast_point_num = ppn
127 | position = len(data[0]) - forecast_point_num
128 | xs = np.array(data)[:, :position]
129 | ys = np.array(data)[:, position:]
130 | # ==================================================== #
131 | # update time: 2021-12-10
132 | if add_dim_pos == -1:
133 | np_xs = np.array(xs)
134 | else:
135 | np_xs = from_numpy(xs).unsqueeze(dim=1).numpy()
136 | self.features = len(xs[0])
137 | # ==================================================== #
138 | self.output = forecast_point_num
139 | self.outputs.append(forecast_point_num)
140 | # print('split x & y')
141 | # print(xs.shape)
142 | # print(ys.shape)
143 |
144 | return np_xs, np.array(ys), self.features
145 |
146 | def next(self, mode='train'):
147 | if self.indexes[mode] == len(self.datasets_cache[mode]):
148 | self.indexes[mode] = 0
149 |
150 | next_batch = self.datasets_cache[mode][self.indexes[mode]]
151 | task_id = self.task_id[mode][self.indexes[mode]]
152 | self.indexes[mode] += 1
153 | return next_batch, task_id
154 |
155 |
156 | if __name__ == '__main__':
157 |
158 | pass
--------------------------------------------------------------------------------
/core/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022/03/10'
4 |
5 | 'train base network + maml'
6 |
7 | # built-in library
8 | import os
9 | import os.path as osp
10 | import time
11 | import copy
12 | from datetime import datetime
13 |
14 | # third-party library
15 | import torch
16 | import numpy as np
17 |
18 | # self-defined library
19 | from core.base_nets import MLP, BaseLSTM, BaseCNN, BaseCNNConLSTM
20 | from core.meta_nets import MetaNet
21 | from core.task_split import LoadData
22 | from configs import TRAINING_TASK_SET
23 | from tools.tools import generate_filename, obj_serialization
24 | from core.options import print as write_log
25 |
26 | torch.set_default_tensor_type(torch.DoubleTensor)
27 |
28 |
29 |
30 | def train(epoch_num, test_user_index, add_dim_pos, data_path, update_step_train,
31 | update_step_target, meta_lr, base_lr, fine_lr, device, baseNet,
32 | maml, log, maml_log, lsd, ppn, rmse_path, mape_path,
33 | smape_path, ft_step, batch_task_num=5, new_settings=False):
34 |
35 | # set random seeds so that results are reproducible
36 | torch.manual_seed(1)
37 | np.random.seed(1)
38 | if device == 'cuda':
39 | torch.cuda.manual_seed_all(1)
40 |
41 | device = torch.device(device) # device object used for torch.Tensor computation ['cpu', 'cuda']
42 | if baseNet == 'mlp':
43 | #print('no add dimension')
44 | add_dim_pos = -1 # do not add a dimension to the tensor
45 |
46 | # get data
47 | data = LoadData(maml, test_user_index, add_dim_pos, data_path=data_path, ppn=ppn)
48 |
49 | if baseNet == 'cnn':
50 | BaseNet = BaseCNN(output=data.output)
51 | elif baseNet == 'lstm':
52 | BaseNet = BaseLSTM(n_features=data.features, n_hidden=100, n_output=data.output)
53 | elif baseNet == 'cnnConlstm':
54 | BaseNet = BaseCNNConLSTM(n_features=data.features, n_hidden=100, n_output=data.output, cnn_feature=200)
55 | # ========================================= #
56 | # update time: 2021-12-10
57 | elif baseNet == 'mlp':
58 | # print(data.features, data.output)
59 | BaseNet = MLP(n_input=data.features, n_hidden=100, n_output=data.output)
60 | # ======================================== #
61 | else:
62 | raise Exception('Unknown baseNet: %s' % baseNet)
63 |
64 |
65 | metaNet = MetaNet(
66 | baseNet=BaseNet,
67 | update_step_train=update_step_train,
68 | update_step_target=update_step_target,
69 | meta_lr=meta_lr,
70 | base_lr=base_lr,
71 | fine_lr=fine_lr
72 | ).to(device)
73 |
74 | training_result = {
75 | 'target task': None,
76 | 'qry_loss': None,
77 | 'rmse': None,
78 | 'mape': None,
79 | 'smape': None,
80 | 'rmse_best_epoch': None,
81 | 'mape_best_epoch': None,
82 | 'smape_best_epoch': None,
83 | 'training time': None,
84 | 'date': None,
85 | 'log': log,
86 | 'maml_log': None
87 | }
88 |
89 | # training
90 | start = time.time()
91 | step = 0
92 | batch_num = 0
93 | train_loss = {}
94 | test_loss = {}
95 | # print(data.task_num)
96 |
97 | if maml:
98 | while step < epoch_num:
99 |
100 | (spt_x, spt_y), (qry_x, qry_y), task_id = batch_task(data, batch_task_num=batch_task_num)
101 |
102 | batch_num += 1
103 |
104 | print('[%d]===================== training meta network ========================= :[%s]' % (step, task_id))
105 | metrics = metaNet(spt_x, spt_y, qry_x, qry_y, device=device)
106 | print('| train_task: %s | qry_loss: %.4f | qry_rmse: %.4f | qry_mape: %.4f | qry_smape: %.4f |'
107 | % (task_id, metrics['loss'], metrics['rmse'], metrics['mape'], metrics['smape']))
108 |
109 |
110 | if batch_num % (data.task_num // batch_task_num) == 0:
111 | step += 1
112 | (spt_x, spt_y, qry_x, qry_y), task_id = data.next('test')
113 |
114 | spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), \
115 | torch.from_numpy(spt_y).to(device), \
116 | torch.from_numpy(qry_x).to(device), \
117 | torch.from_numpy(qry_y).to(device)
118 | print('===================== fine tuning target network ========================= :[%s]' % task_id)
119 | metrics = metaNet.fine_tuning(spt_x, spt_y, qry_x, qry_y)
120 |
121 | if training_result['qry_loss'] is None:
122 | training_result['target task'] = task_id
123 | training_result['qry_loss'] = metrics['test_loss']
124 | training_result['mape'] = metrics['mape']
125 | training_result['smape'] = metrics['smape']
126 | training_result['rmse'] = metrics['rmse']
127 | training_result['rmse_best_epoch'] = metrics['rmse_best_epoch']
128 | training_result['mape_best_epoch'] = metrics['mape_best_epoch']
129 | training_result['smape_best_epoch'] = metrics['smape_best_epoch']
130 | else:
131 | if training_result['rmse'] > metrics['rmse']:
132 | training_result['qry_loss'] = metrics['test_loss']
133 | training_result['rmse'] = metrics['rmse']
134 | training_result['rmse_best_epoch'] = metrics['rmse_best_epoch']
135 | if training_result['mape'] > metrics['mape']:
136 | training_result['mape'] = metrics['mape']
137 | training_result['mape_best_epoch'] = metrics['mape_best_epoch']
138 | if training_result['smape'] > metrics['smape']:
139 | training_result['smape'] = metrics['smape']
140 | training_result['smape_best_epoch'] = metrics['smape_best_epoch']
141 | write_log(
142 | 'Epoch [%d] | '
143 | 'target_task_id: %s | '
144 | 'qry_loss: %.4f | '
145 | 'rmse: %.4f(%d) | '
146 | 'smape: %.4f(%d) |'
147 | % (
148 | step,
149 | task_id,
150 | metrics['test_loss'],
151 | metrics['rmse'], metrics['rmse_best_epoch'],
152 | metrics['smape'], metrics['smape_best_epoch']
153 | ),
154 | file=log
155 | )
156 | else:
157 |
158 | # updated date: 2021-10-20
159 | if not new_settings:
160 | (train_x, train_y), (test_x, test_y), task_id = data.get_data(test_user_index)
161 | else:
162 | (train_x, train_y), (test_x, test_y), task_id = new_settings_get_training_data(data,test_user_index)
163 | # ============================================= #
164 |
165 |
166 | spt_x, spt_y, qry_x, qry_y = torch.from_numpy(train_x).to(device), \
167 | torch.from_numpy(train_y).to(device), \
168 | torch.from_numpy(test_x).to(device), \
169 | torch.from_numpy(test_y).to(device)
170 |
171 | print('===================== training %s ========================= :[%s]' % (baseNet, task_id))
172 | metrics = metaNet.fine_tuning(spt_x, spt_y, qry_x, qry_y, naive=True)
173 | # ============================================= #
174 | # updated date: 2021-10-20
175 | if new_settings:
176 | print('=====================new settings finetuning=====================')
177 | (train_x, train_y), (test_x, test_y), task_id = new_settings_get_finetune_data(data,test_user_index)
178 | spt_x, spt_y, qry_x, qry_y = torch.from_numpy(train_x).to(device), \
179 | torch.from_numpy(train_y).to(device), \
180 | torch.from_numpy(test_x).to(device), \
181 | torch.from_numpy(test_y).to(device)
182 | metaNet.update_step_target = ft_step
183 | metrics = metaNet.fine_tuning(spt_x, spt_y, qry_x, qry_y, naive=True)
184 | # ============================================= #
185 |
186 | save_loss(lsd, metrics['loss_set'], *[task_id, baseNet, 'loss_set'])
187 | train_loss.setdefault(task_id, metrics['train_loss'])
188 | test_loss.setdefault(task_id, metrics['test_loss'])
189 | write_log(
190 | 'target_task_id: %s | '
191 | 'spt_loss: %.4f |'
192 | 'qry_loss: %.4f | '
193 | 'rmse: %.4f(%d)| '
194 | 'smape: %.4f(%d)|'
195 | % (
196 | task_id,
197 | metrics['train_loss'],
198 | metrics['test_loss'],
199 | metrics['rmse'], metrics['rmse_best_epoch'],
200 | metrics['smape'], metrics['smape_best_epoch']
201 | ),
202 | file=log
203 | )
204 | # save metrics rmse, mape, smape
205 | write_log('%.4f (%d)' % (metrics['rmse'], metrics['rmse_best_epoch']), file=rmse_path, terminate=False)
206 | write_log('%.4f (%d)' % (metrics['mape'], metrics['mape_best_epoch']), file=mape_path, terminate=False)
207 | write_log('%.4f (%d)' % (metrics['smape'], metrics['smape_best_epoch']), file=smape_path, terminate=False)
208 | pass
209 | end = time.time()
210 | # save loss
211 | save_loss(lsd, train_loss, *[baseNet, 'train', 'loss'])
212 | save_loss(lsd, test_loss, *[baseNet, 'test', 'loss'])
213 | training_result['training time'] = '%s Min' % str((end - start) / 60)
214 | training_result['date'] = datetime.strftime(datetime.now(), '%Y/%m/%d %H:%M:%S')
215 | if maml:
216 | training_result['maml_log'] = maml_log
217 | train_result(training_result, file=maml_log)
218 | # save metrics rmse, mape, smape
219 | write_log('%.4f (%d)' % (training_result['rmse'], training_result['rmse_best_epoch']), file=rmse_path, terminate=False)
220 | write_log('%.4f (%d)' % (training_result['mape'], training_result['mape_best_epoch']), file=mape_path, terminate=False)
221 | write_log('%.4f (%d)' % (training_result['smape'], training_result['smape_best_epoch']), file=smape_path, terminate=False)
222 | else:
223 | # train_result(training_result, file=log)
224 | pass
225 |
226 | return metrics['smape'], metrics['rmse']
227 |
228 |
229 | def save_loss(lsd, obj, *others):
230 |
231 | if not osp.exists(lsd):
232 | os.mkdir(lsd)
233 | loss_path = osp.join(lsd, generate_filename('.pkl', *others, timestamp=False))
234 | obj_serialization(loss_path, obj)
235 | print('loss serialization is finished!')
236 |
237 |
238 | def train_result(data_dict, file='./log.txt'):
239 | write_log('training result:============================', file=file)
240 | for key, value in data_dict.items():
241 | write_log(' %s: %s' % (key, value), file=file)
242 | write_log('============================================', file=file)
243 |
244 |
245 | def batch_task(data, batch_task_num=1, ablation=1):
246 |
247 | # ablation == 1: use all tasks as the training-task set
248 | # ablation == 0: use only the UCR (named) tasks as the training-task set
249 |
250 | (spt_x, spt_y, qry_x, qry_y), task_id = data.next('train')
251 | train_x = list([spt_x])
252 | train_y = list([spt_y])
253 | test_x = list([qry_x])
254 | test_y = list([qry_y])
255 |
256 | if ablation == 1:
257 | while batch_task_num > 1:
258 | (x1, y1, x2, y2), temp = data.next('train')
259 | train_x.append(x1)
260 | train_y.append(y1)
261 | test_x.append(x2)
262 | test_y.append(y2)
263 | task_id += ('-' + temp)
264 | batch_task_num -= 1
265 | elif ablation == 0:
266 | while batch_task_num > 1:
267 | (x1, y1, x2, y2), temp = data.next('train')
268 | if temp.isdigit():
269 | continue
270 | train_x.append(x1)
271 | train_y.append(y1)
272 | test_x.append(x2)
273 | test_y.append(y2)
274 | task_id += ('-' + temp)
275 | batch_task_num -= 1
276 | else:
277 | raise Exception('Unknown ablation code: [%d]' % ablation)
278 |
279 | return (train_x, train_y), (test_x, test_y), task_id
280 | # ==================================================================================== #
281 | # updated date: 2021-10-20
282 | # in light of the reviewer's suggestions, add a group of experimental settings
283 |
284 | def new_settings_get_training_data(data, target_task):
285 |
286 | # data: UCR
287 | # pack all datasets except target_task together for training
288 |
289 | dataset_list = copy.deepcopy(TRAINING_TASK_SET)
290 | dataset_list.remove(target_task)
291 |
292 | (train_x, train_y), (test_x, test_y), _ = data.get_data(dataset_list[0])
293 | #print(dataset_list[0])
294 | #print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)
295 | for dataset in dataset_list[1:]:
296 | # print(dataset)
297 | (temp_1, temp_2), (temp_3, temp_4), _ = data.get_data(dataset)
298 | # print(temp_1.shape, temp_2.shape, temp_3.shape, temp_4.shape)
299 | train_x = np.concatenate((train_x, temp_1), axis=0)
300 | train_y = np.concatenate((train_y, temp_2), axis=0)
301 | test_x = np.concatenate((test_x, temp_3), axis=0)
302 | test_y = np.concatenate((test_y, temp_4), axis=0)
303 | # print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)
304 | return (train_x, train_y), (test_x, test_y), target_task
305 |
306 |
307 | def new_settings_get_finetune_data(data, target_task):
308 |
309 | # fine-tune with the target task
310 |
311 | return data.get_data(target_task)
312 |
313 | # ================================================================================== #
314 |
315 |
316 |
317 | if __name__ == '__main__':
318 |
319 | pass
320 |
321 |
--------------------------------------------------------------------------------
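A hedged sketch of calling the training routine above directly from Python (main.py is the usual entry point and builds these arguments with core.options.parse_args). The paths and hyper-parameters are illustrative, and the ./log and ./results/loss, ./results/model directories are assumed to exist:

```python
from core.train import train

smape, rmse = train(
    epoch_num=10,
    test_user_index='Wine',            # target task; see TRAINING_TASK_SET in configs.py
    add_dim_pos=1,
    data_path='data/few_shot_data',
    update_step_train=10,
    update_step_target=10,
    meta_lr=0.01, base_lr=0.01, fine_lr=0.01,
    device='cpu',
    baseNet='cnn',
    maml=True,
    log='log/logs.txt', maml_log='log/logs.txt',
    lsd='results/loss',
    ppn=10,
    rmse_path='results/rmse.txt',
    mape_path='results/mape.txt',
    smape_path='results/smape.txt',
    ft_step=100,
)
```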
/data/few_shot_data/test_data_embedding_10.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_10.pkl
--------------------------------------------------------------------------------
/data/few_shot_data/test_data_embedding_20.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_20.pkl
--------------------------------------------------------------------------------
/data/few_shot_data/test_data_embedding_30.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_30.pkl
--------------------------------------------------------------------------------
/data/few_shot_data/test_data_embedding_40.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_40.pkl
--------------------------------------------------------------------------------
/data/few_shot_data/train_data_embedding_10.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_10.pkl
--------------------------------------------------------------------------------
/data/few_shot_data/train_data_embedding_20.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_20.pkl
--------------------------------------------------------------------------------
/data/few_shot_data/train_data_embedding_30.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_30.pkl
--------------------------------------------------------------------------------
/data/few_shot_data/train_data_embedding_40.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_40.pkl
--------------------------------------------------------------------------------
/embedding/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022-07-11'
4 |
5 | '''
6 | the scripts for time series data processing.
7 | '''
8 |
9 | import os
10 | import os.path as osp
11 | import torch
12 | import numpy as np
13 | from collections import OrderedDict
14 | from configs import DATA_DIR as DATADIR
15 | from configs import few_shot_dataset_name
16 | from tools.tools import obj_serialization, read_tsv, obj_unserialization, generate_filename
17 | from sklearn import preprocessing
18 |
19 |
20 | def few_shot_data(path=None):
21 |
22 | if path is None:
23 | raise Exception('The parameter "path" is None!')
24 | dataset_file_names = os.listdir(path)
25 | train_dataset = OrderedDict()
26 | test_dataset = OrderedDict()
27 | process_traindata_num = 0
28 | process_testdata_num = 0
29 | for dir in dataset_file_names:
30 | dataset_dir = osp.join(path, dir)
31 | if os.path.isdir(dataset_dir):
32 | train_file_path = osp.join(dataset_dir, '%s_TRAIN.tsv' % dir)
33 | test_file_path = osp.join(dataset_dir, '%s_TEST.tsv' % dir)
34 | if os.path.isfile(train_file_path):
35 | train_dataset.setdefault(dir, read_tsv(train_file_path).loc[:, 1:].values.astype(np.float64))
36 | process_traindata_num += 1
37 | else:
38 | print('"%s" is not a file!' % train_file_path)
39 | if os.path.isfile(test_file_path):
40 | test_dataset.setdefault(dir, read_tsv(test_file_path).loc[:, 1:].values.astype(np.float64))
41 | process_testdata_num += 1
42 | else:
43 | print('"%s" is not a file!' % test_file_path)
44 |
45 | obj_serialization(osp.join(DATADIR, 'train_data.pkl'), train_dataset)
46 | obj_serialization(osp.join(DATADIR, 'test_data.pkl'), test_dataset)
47 | print('train_process_num: %d' % process_traindata_num)
48 | print('test_process_num: %d' % process_testdata_num)
49 |
50 |
51 | def split_data(data_path=None, ratio=0.1, shuffle=False, data=None):
52 |
53 | if data is None:
54 | data = obj_unserialization(data_path)
55 | if int(ratio) >= len(data):
56 | return [], []
57 | if 0 < ratio < 1:
58 | train_data_size = int(len(data) * ratio)
59 | elif 1 <= ratio < len(data):
60 | train_data_size = int(ratio)
61 | else:
62 | raise Exception('Invalid value about "ratio" --> [%s]' % str(ratio))
63 |
64 | if train_data_size == 0:
65 | val_data = data
66 | train_data = []
67 | else:
68 | train_data = data[:train_data_size]
69 | val_data = data[train_data_size:]
70 | return train_data, val_data
71 |
72 |
73 | def create_sequence(data, ratio=0.1):
74 |
75 | forecast_point_num = int(len(data[0]) * ratio)
76 | position = len(data[0]) - forecast_point_num
77 | xs = np.array(data)[:, :position]
78 | ys = np.array(data)[:, position:]
79 |
80 | return torch.from_numpy(xs).float().unsqueeze(dim=2), torch.from_numpy(ys).float(), position, forecast_point_num
81 |
82 |
83 | def construct_dataset(size=100):
84 |
85 | file_list = os.listdir(DATADIR)
86 |
87 | counter = 0
88 | data_dict = {}
89 | file_num = len(file_list)
90 | for file in file_list:
91 | data_path = osp.join(DATADIR, file)
92 | if os.path.isdir(osp.join(DATADIR, file)):
93 | file_num -= 1
94 | continue
95 | data_dict.setdefault(file.split('.')[0], obj_unserialization(data_path))
96 | counter += 1
97 | if counter % size == 0:
98 |             save_path = osp.join(DATADIR, 'dataset',
99 |                                  generate_filename('.pkl', *['UCR', str(counter - size + 1), str(counter)])
100 |                                  )
101 | obj_serialization(save_path, data_dict)
102 | print(save_path)
103 | data_dict.clear()
104 | elif counter == file_num:
105 |             save_path = osp.join(DATADIR, 'dataset',
106 |                                  generate_filename('.pkl', *['UCR', str(size * (counter // size) + 1), str(counter)])
107 |                                  )
108 | obj_serialization(save_path, data_dict)
109 | print(save_path)
110 | data_dict.clear()
111 |
112 |
113 | def normalizer(data=None):
114 |
115 | # z-score normalization
116 |
117 |     if data is None:
118 |         raise ValueError('data is None!')
119 | 
120 |     # scale each series (each row) to zero mean and unit variance
121 |     return preprocessing.scale(data, axis=1)
122 |
123 |
124 | def get_basic_data():
125 |
126 |     load_data_path = osp.join(DATADIR, 'few_shot_data', 'few_shot_load_data.pkl')
127 | UCR_train_data_path = osp.join(DATADIR, 'train_data.pkl')
128 | UCR_test_data_path = osp.join(DATADIR, 'test_data.pkl')
129 |
130 | load_data = obj_unserialization(load_data_path)
131 | UCR_train_data = obj_unserialization(UCR_train_data_path)
132 | UCR_test_data = obj_unserialization(UCR_test_data_path)
133 |
134 | few_shot_train_data = OrderedDict()
135 | few_shot_test_data = OrderedDict()
136 | for key, value in load_data.items():
137 | # if key in DIRTY_DATA_ID:
138 | # continue
139 | train_data, test_data = split_data(data=value, ratio=0.5)
140 | few_shot_train_data.setdefault(key, np.array(train_data))
141 | few_shot_test_data.setdefault(key, np.array(test_data))
142 | print('valid load dataset: %d' % len(few_shot_train_data))
143 |
144 | for key, value in UCR_train_data.items():
145 | if key in few_shot_dataset_name:
146 | few_shot_train_data.setdefault(key, value)
147 | few_shot_test_data.setdefault(key, UCR_test_data[key])
148 |
149 | print('all the few shot dataset: %d' % len(few_shot_train_data))
150 |
151 |     obj_serialization(osp.join(DATADIR, 'few_shot_data', 'train_data.pkl'), few_shot_train_data)
152 |     obj_serialization(osp.join(DATADIR, 'few_shot_data', 'test_data.pkl'), few_shot_test_data)
153 |
154 |
155 | if __name__ == '__main__':
156 |
157 | pass
--------------------------------------------------------------------------------
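A minimal sketch of what `create_sequence` produces for one few-shot dataset; the toy array, the `ratio` value, and the variable names are illustrative assumptions, not part of the repository:

    import numpy as np
    import torch

    # toy batch: 3 equal-length series of 20 points (stand-in for one dataset)
    data = np.random.randn(3, 20).tolist()

    ratio = 0.1                                             # assumed forecast ratio
    forecast_point_num = int(len(data[0]) * ratio)          # 2 points to forecast
    position = len(data[0]) - forecast_point_num            # 18 observed points

    xs = np.array(data)[:, :position]                       # (3, 18) inputs
    ys = np.array(data)[:, position:]                       # (3, 2) targets

    # create_sequence returns the same tensors: inputs get a trailing feature dim
    x_tensor = torch.from_numpy(xs).float().unsqueeze(dim=2)    # (3, 18, 1)
    y_tensor = torch.from_numpy(ys).float()                     # (3, 2)
    print(x_tensor.shape, y_tensor.shape, position, forecast_point_num)
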
/embedding/embedding.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022-07-11'
4 | '''
5 | This script embeds time series data so that series of different lengths map to representations of a unified length.
6 | '''
7 |
8 | # builtins library
9 | import os
10 | from collections import OrderedDict
11 |
12 | # third-party library
13 | import torch
14 | import numpy as np
15 | import torch.nn as nn
16 |
17 | # self-defined tools
18 | from data_preprocessing import normalizer
19 | from tools.tools import obj_serialization, obj_unserialization
20 | from configs import DATA_DIR as DATADIR
21 |
22 | class EmbeddingBiGRU(nn.Module):
23 |
24 | def __init__(self, n_input, n_hidden, batch_size=100, bidirectional=True, forecasting_point_num=10):
25 | super().__init__()
26 | self.n_input = n_input
27 | self.n_hidden = n_hidden
28 | self.batch_size = batch_size
29 | self.bidirectional = bidirectional
30 | self.forecasting_point_num = forecasting_point_num
31 |
32 | self.bigru = nn.GRU(
33 | input_size=self.n_input,
34 | hidden_size=self.n_hidden,
35 | bidirectional=self.bidirectional,
36 | num_layers=1,
37 |
38 | )
39 |
40 |     def batch_train(self, x):
41 | 
42 |         # x: (record_size, seq_length, dim_feature)
43 |         # split x into mini-batches along the record dimension; each slice
44 |         # already covers the tail of the data, so no extra remainder batch is appended
45 |         record_size = x.shape[0]
46 |         batch_data = []
47 |         if record_size <= self.batch_size:
48 |             batch_data.append(x)
49 |         else:
50 |             for pos in range(0, record_size, self.batch_size):
51 |                 batch_data.append(x[pos:pos + self.batch_size, :])
52 |         return batch_data
53 |
54 | def forward(self, x):
55 |
56 | # x: (batch size, seq_length, dim_feature) --> (seq_length, batch size, input size)
57 | batch_data = self.batch_train(x[:, :x.shape[1] - self.forecasting_point_num, :])
58 | forecasting_data = x[:, x.shape[1] - self.forecasting_point_num:, :].squeeze(dim=2).numpy()
59 | embedding = []
60 | for batch in batch_data:
61 |             batch = batch.permute(1, 0, 2).contiguous()  # (seq_length, batch_size, input_size)
62 | gru_out, h_n = self.bigru(batch)
63 |
64 | forward_embedding = h_n[0, :, :].detach().numpy()
65 | backward_embedding = h_n[1, :, :].detach().numpy()
66 | embedding.append(np.concatenate((forward_embedding, backward_embedding), axis=1))
67 | embedding = np.concatenate(embedding, axis=0)
68 | embedding = np.concatenate((embedding, forecasting_data), axis=1)
69 | return embedding.astype(np.float64)
70 |
71 | pass
72 |
73 |
74 | if __name__ == '__main__':
75 |
76 |     train_data_path = os.path.join(DATADIR, 'few_shot_data', 'train_data.pkl')
77 |     test_data_path = os.path.join(DATADIR, 'few_shot_data', 'test_data.pkl')
78 | forecasting_point_num = 40
79 | model = EmbeddingBiGRU(n_input=1, n_hidden=100, forecasting_point_num=forecasting_point_num)
80 |
81 |     # train data embedding
82 |     print('train data embedding ...')
83 | train_data = obj_unserialization(train_data_path)
84 | train_data_embedding = OrderedDict()
85 |
86 | for key, value in train_data.items():
87 | input_data = torch.from_numpy(normalizer(value)).float().unsqueeze(dim=2)
88 | embedding_data = model(input_data)
89 | train_data_embedding.setdefault(key, embedding_data)
90 | print('train data dimension: %d' % len(train_data_embedding['0001'][0]))
91 |     obj_serialization(os.path.join(DATADIR, 'few_shot_data', 'train_data_embedding_%s.pkl' % forecasting_point_num), train_data_embedding)
92 |
93 |     # test data embedding
94 |     print('test data embedding ...')
95 | test_data = obj_unserialization(test_data_path)
96 | test_data_embedding = OrderedDict()
97 |
98 | for key, value in test_data.items():
99 | input_data = torch.from_numpy(normalizer(value)).float().unsqueeze(dim=2)
100 | embedding_data = model(input_data)
101 | test_data_embedding.setdefault(key, embedding_data)
102 | print('test data dimension: %d' % len(test_data_embedding['0001'][0]))
103 |     obj_serialization(os.path.join(DATADIR, 'few_shot_data', 'test_data_embedding_%s.pkl' % forecasting_point_num), test_data_embedding)
104 |
105 | print('OK!')
106 | pass
107 |
--------------------------------------------------------------------------------
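The key point of EmbeddingBiGRU is that the final hidden states of a bidirectional GRU have a fixed size regardless of how long the input series is, so concatenating both directions maps series of different lengths to vectors of the same dimension (2 * n_hidden). A minimal, self-contained sketch of that mechanism with plain PyTorch, using illustrative sizes rather than values from the repository:

    import torch
    import torch.nn as nn

    n_hidden = 100
    gru = nn.GRU(input_size=1, hidden_size=n_hidden, bidirectional=True, num_layers=1)

    # two toy batches whose series lengths differ (96 vs. 150 time steps)
    short_series = torch.randn(96, 8, 1)     # (seq_length, batch_size, input_size)
    long_series = torch.randn(150, 8, 1)

    for batch in (short_series, long_series):
        _, h_n = gru(batch)                  # h_n: (2 directions, batch_size, n_hidden)
        embedding = torch.cat((h_n[0], h_n[1]), dim=1)
        print(embedding.shape)               # torch.Size([8, 200]) in both cases
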
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022/03/10'
4 |
5 | # begin from here
6 |
7 |
8 | # built-in library
9 | import os.path as osp
10 | import time
11 |
12 | # third-party library
13 |
14 | # self-defined tools
15 | from core.options import parse_args
16 | from core.train import train
17 | from configs import TRAINING_TASK_SET, DATA_DIR
18 |
19 | if __name__ == '__main__':
20 |
21 | # python train_maml.py --mode together --model cnn+maml --epoch 100 --dataSet maml_load_data(14).pkl --ratio 0.9 --time_size 3 --update_step_train 10 --update_step_target 20 --meta_lr 0.001 --base_lr 0.01
22 |
23 | print('--------------------------------------------Time Series Forecasting------------------------------------------')
24 | params = parse_args('main')
25 | start = time.time()
26 | if params.user_id == 'none':
27 | # conducting all tasks
28 | for user_id in TRAINING_TASK_SET:
29 | train(
30 | epoch_num=params.epoch,
31 | test_user_index=user_id,
32 | add_dim_pos=params.add_dim_pos,
33 | data_path=osp.join(DATA_DIR, params.dataset),
34 | update_step_train=params.update_step_train,
35 | update_step_target=params.update_step_target,
36 | meta_lr=params.meta_lr,
37 | base_lr=params.base_lr,
38 | fine_lr=params.fine_lr,
39 | device=params.device,
40 | baseNet=params.baseNet,
41 | maml=params.maml,
42 | log=params.log,
43 | maml_log=params.maml_log,
44 | lsd=params.lsd,
45 | ppn=params.ppn,
46 | batch_task_num=params.batch_task_num,
47 | rmse_path=params.rmse_path,
48 | mape_path=params.mape_path,
49 | smape_path=params.smape_path,
50 | ft_step=params.ft_step,
51 | new_settings=params.new_settings
52 | )
53 | end = time.time()
54 |         print('elapsed time: %.4f hours' % ((end - start) / 3600.0))
55 | print('training is over!')
56 | elif params.user_id in TRAINING_TASK_SET:
57 |         # conducting a single task
58 | smape, rmse = train(
59 | epoch_num=params.epoch,
60 | test_user_index=params.user_id,
61 | add_dim_pos=params.add_dim_pos,
62 | data_path=osp.join(DATA_DIR, params.dataset),
63 | update_step_train=params.update_step_train,
64 | update_step_target=params.update_step_target,
65 | meta_lr=params.meta_lr,
66 | base_lr=params.base_lr,
67 | fine_lr=params.fine_lr,
68 | device=params.device,
69 | baseNet=params.baseNet,
70 | maml=params.maml,
71 | log=params.log,
72 | maml_log=params.maml_log,
73 | lsd=params.lsd,
74 | ppn=params.ppn,
75 | rmse_path=params.rmse_path,
76 | mape_path=params.mape_path,
77 | smape_path=params.smape_path,
78 | ft_step=params.ft_step,
79 | new_settings=params.new_settings
80 | )
81 | end = time.time()
82 |         print('elapsed time: %.4f minutes' % ((end - start) / 60.0))
83 | print('training is over!')
84 | print('smape: %.4f' % smape)
85 | print('rmse: %.4f' % rmse)
86 | else:
87 | raise Exception('Unknown user id!')
88 |
89 |
90 |
--------------------------------------------------------------------------------
/tools/tools.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | __author__ = 'XF'
3 | __date__ = '2022-07-11'
4 | '''
5 | This script supplies some utility functions.
6 | '''
7 |
8 | import os
9 | import time
10 | import pickle
11 | import pandas as pds
12 |
13 |
14 |
15 | def read_tsv(path=None, header=None):
16 |
17 | if path is None:
18 |         raise ValueError('The path is None!')
19 | 
20 |     content = pds.read_csv(path, sep='\t', header=header)
21 | return content
22 |
23 |
24 | # object serialization
25 | def obj_serialization(path, obj):
26 |
27 | if obj is not None:
28 | with open(path, 'wb') as file:
29 | pickle.dump(obj, file)
30 | else:
31 | print('object is None!')
32 |
33 |
34 | # object deserialization
35 | def obj_unserialization(path):
36 |
37 | if os.path.exists(path):
38 | with open(path, 'rb') as file:
39 | obj = pickle.load(file)
40 | return obj
41 | else:
42 |         raise FileNotFoundError('no such path: %s' % path)
43 |
44 |
45 | def generate_filename(suffix, *args, sep='_', timestamp=False):
46 |
47 |     '''
48 | 
49 |     :param suffix: file suffix, with or without a leading '.'
50 |     :param args: name components to join
51 |     :param sep: separator, default '_'
52 |     :param timestamp: append a timestamp for uniqueness
53 |     :return: the generated file name
54 |     '''
55 |
56 | filename = sep.join(args).replace(' ', '_')
57 | if timestamp:
58 | filename += time.strftime('_%Y%m%d%H%M%S')
59 | if suffix[0] == '.':
60 | filename += suffix
61 | else:
62 | filename += ('.' + suffix)
63 |
64 | return filename
65 |
66 |
67 | def metrics(y, y_hat):
68 |
69 | assert y.shape == y_hat.shape # Tensor y and Tensor y_hat must have the same shape
70 | y = y.cpu()
71 | y_hat = y_hat.cpu()
72 | # mape
73 | _mape = mape(y, y_hat)
74 |
75 | # smape
76 | _smape = smape(y, y_hat)
77 |
78 | # rmse
79 | _rmse = rmse(y, y_hat)
80 |
81 | return _rmse, _mape, _smape
82 |
83 |
84 | def mape(Y, Y_hat):
85 |
86 | temp = [abs((y - y_hat) / y) for y, y_hat in zip(Y.view(-1).numpy(), Y_hat.view(-1).numpy())]
87 | return (sum(temp) / len(temp)) * 100
88 |
89 |
90 | def smape(Y, Y_hat):
91 |
92 |     temp = [abs(y - y_hat) / (abs(y) + abs(y_hat)) for y, y_hat in zip(Y.view(-1).numpy(), Y_hat.view(-1).numpy())]
93 | return (sum(temp) / len(temp)) * 200
94 |
95 |
96 | def rmse(Y, Y_hat):
97 |
98 | temp = [pow(y - y_hat, 2) for y, y_hat in zip(Y.view(-1).numpy(), Y_hat.view(-1).numpy())]
99 | return pow(sum(temp) / len(temp), 0.5)
100 |
101 |
102 | if __name__ == '__main__':
103 | pass
--------------------------------------------------------------------------------
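A brief usage sketch of the metric helpers, assuming the repository root is on the Python path; the toy tensors are illustrative. `metrics` expects `y` and `y_hat` with identical shapes and returns RMSE, MAPE, and sMAPE in that order:

    import torch
    from tools.tools import metrics

    y = torch.tensor([[1.0, 2.0, 4.0]])
    y_hat = torch.tensor([[1.1, 1.8, 4.4]])

    rmse_val, mape_val, smape_val = metrics(y, y_hat)
    print('RMSE: %.4f  MAPE: %.2f%%  sMAPE: %.2f%%' % (rmse_val, mape_val, smape_val))
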