├── .idea ├── .gitignore ├── Meta-Learning4FSTSF.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md ├── configs.py ├── core ├── base_nets.py ├── meta_nets.py ├── options.py ├── task_split.py └── train.py ├── data └── few_shot_data │ ├── test_data_embedding_10.pkl │ ├── test_data_embedding_20.pkl │ ├── test_data_embedding_30.pkl │ ├── test_data_embedding_40.pkl │ ├── train_data_embedding_10.pkl │ ├── train_data_embedding_20.pkl │ ├── train_data_embedding_30.pkl │ └── train_data_embedding_40.pkl ├── embedding ├── data_preprocessing.py └── embedding.py ├── main.py └── tools └── tools.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /.idea/Meta-Learning4FSTSF.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 xf-git 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 
 1 | # Meta-Learning4FSTSF 
 2 | Meta-Learning for Few-Shot Time Series Forecasting 
 3 | 
 4 | # Usage 
 5 | 
 6 | This section of the README walks through how to train the models. 
 7 | 
 8 | ## Data preparation 
 9 | > data_preprocessing.py + embedding.py 
 10 | 
 11 | **notes**: 
 12 | The time-series data provided in '/data/few_shot_data/...' have already been through this step. For new raw time-series data, these two scripts can be used to perform it. 
 13 | 
 14 | 
 15 | ## Training Base_{model} 
 16 | ### In this phase, each dataset is a time-series task, and every task is trained separately. 
 17 | 
 18 | >**main.py** 
 19 | >>**Argument help:** 
 20 | 
 21 | --baseNet: [mlp/cnn/lstm/cnnConlstm] 
 22 | --dataset: the directory containing the pre-processed time-series data 
 23 | --update_step_target: number of update steps of the network 
 24 | --fine_lr: learning rate in this phase 
 25 | --ppn: number of points to predict [10/20/30/40] 
 26 | --device: [cpu/cuda] 
 27 | --user_id: the name of the task to train; it can be found in TRAINING_TASK_SET in ./configs.py 
 28 | 
 29 | >**Training a single task:** 
 30 | 
 31 | ''' 
 32 | python main.py --baseNet [mlp/cnn/lstm/cnnConlstm] --dataset [few_shot_data/your defined data dir] --update_step_target 10 --fine_lr 0.001 --ppn [10/20/30/40] --device [cpu/cuda] --user_id 0001 
 33 | ''' 
 34 | 
 35 | >**Training all tasks:** 
 36 | 
 37 | ''' 
 38 | python main.py --baseNet [mlp/cnn/lstm/cnnConlstm] --dataset few_shot_data --update_step_target 10 --fine_lr 0.001 --ppn [10/20/30/40] --device [cpu/cuda] 
 39 | ''' 
 40 | 
 41 | 
 42 | ## Training Meta_{model} 
 43 | ### In this phase, one task is selected as the target task and the remaining tasks form the training-task set. The baseNet is first trained on the support sets of the training tasks, the MetaNet is then trained on their query sets, and finally the support set of the target task is used to fine-tune the MetaNet.
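The meta-update at the heart of this phase can be summarised by the following minimal, first-order sketch (illustrative only: the toy linear model, shapes and random tensors are placeholders, and the actual implementation lives in core/meta_nets.py and core/train.py):

'''
# Toy first-order MAML step: adapt on a task's support set, then update the
# shared parameters with the query-set loss of the adapted weights.
import torch
import torch.nn.functional as F

w = torch.randn(10, 94, requires_grad=True)               # toy baseNet parameters
meta_optim = torch.optim.Adam([w], lr=0.01)                # outer-loop optimizer (meta_lr)
base_lr = 0.01                                             # inner-loop step size

spt_x, spt_y = torch.randn(16, 94), torch.randn(16, 10)    # support set of one training task
qry_x, qry_y = torch.randn(16, 94), torch.randn(16, 10)    # query set of the same task

# inner loop: theta' = theta - alpha * grad, computed on the support set
spt_loss = F.mse_loss(spt_x @ w.t(), spt_y)
grad, = torch.autograd.grad(spt_loss, w)
fast_w = w - base_lr * grad

# outer loop: the query loss of the adapted parameters updates the shared parameters
qry_loss = F.mse_loss(qry_x @ fast_w.t(), qry_y)
meta_optim.zero_grad()
qry_loss.backward()
meta_optim.step()
'''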
44 | 45 | >**main.py** 46 | >>**Argument help:** 47 | 48 | --maml: using 'maml mode' to training model 49 | --update_step_train: the update times of baseNet on training-task set 50 | --update_step_target: the update times of MetaNet on target task 51 | --epoch: iteration times 52 | --base_lr: the learning rate of baseNet 53 | --meta_lr: the learning rate of MetaNet 54 | --fine_lr: the learning rate of MetaNet during fine-tuing 55 | 56 | >**training single task:** 57 | 58 | ''' 59 | python main.py --baseNet [cnn/lstm/cnnConlstm] --maml --dataset few_shot_data --epoch 10 --update_step_train 10 --update_step_target 10 --base_lr 0.01 --meta_lr 0.01 --fine_lr 0.01 --ppn 10 --device [cpu/cuda] --user_id Wine 60 | ''' 61 | 62 | >**training all task:** 63 | 64 | ''' 65 | python main.py --baseNet [cnn/lstm/cnnConlstm] --maml --dataset few_shot_data --epoch 10 --update_step_train 10 --update_step_target 10 --base_lr 0.01 --meta_lr 0.01 --fine_lr 0.01 --ppn 10 --device [cpu/cuda] 66 | ''' 67 | 68 | ## results 69 | ### All the trained models and evaluating metrics would be saved in dir ./results/ 70 | 71 | ## log 72 | ## Some useful log information would be saved in dir ./log/ -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022/03/10' 4 | 5 | 'default configuration for this project' 6 | 7 | import os.path as osp 8 | 9 | 10 | BASE_DIR = osp.dirname(osp.abspath(__file__)) 11 | 12 | dataName = 'data' 13 | model_save_dir = osp.join(BASE_DIR, 'model') 14 | loss_save_dir = osp.join(BASE_DIR, 'results/loss') 15 | log_save_path = osp.join(BASE_DIR, 'log/logs.txt') 16 | log_save_dir = osp.join(BASE_DIR, 'log') 17 | DATA_DIR = osp.join(BASE_DIR, 'data') 18 | MODEL_PATH = osp.join(BASE_DIR, 'results/model') 19 | console_file = osp.join(BASE_DIR, 'log/logs.txt') 20 | exp_result_dir = osp.join(BASE_DIR, 'results') 21 | 22 | MODEL_NAME = ['mlp', 'lstm', 'cnn', 'cnnConlstm', 'lstm+maml', 'cnn+maml', 'cnnConlstm+maml'] 23 | MODE_NAME = ['training', 'testing', 'together'] 24 | 25 | few_shot_dataset_name = [ 26 | 'Beef', 27 | 'BeetleFly', 28 | 'BirdChicken', 29 | 'Car', 30 | 'Coffee', 31 | 'FaceFour', 32 | 'Herring', 33 | 'Lightning2', 34 | 'Lightning7', 35 | 'Meat', 36 | 'OliveOil', 37 | 'Rock', 38 | 'Wine' 39 | ] 40 | 41 | TRAINING_TASK_SET = [ 42 | '0001', 43 | '0002', 44 | '0003', 45 | '0004', 46 | '0005', 47 | '0006', 48 | '0007', 49 | '0008', 50 | '0009', 51 | '0010', 52 | '0011', 53 | '0012', 54 | '0013', 55 | '0014', 56 | '0015', 57 | '0016', 58 | '0022', 59 | '0023', 60 | '0024', 61 | '0025', 62 | '0026', 63 | '0029', 64 | '0030', 65 | '0031', 66 | '0032', 67 | '0037', 68 | '0046', 69 | '0047', 70 | '0048', 71 | '0049', 72 | '0050', 73 | '0051', 74 | '0054', 75 | '0055', 76 | '0056', 77 | '0066', 78 | '0069', 79 | '0070', 80 | '0071', 81 | '0082', 82 | '0085', 83 | '0088', 84 | '0089', 85 | '0090', 86 | '0091', 87 | '0092', 88 | '0093', 89 | '0094', 90 | '0095', 91 | '0096', 92 | '0097', 93 | '0098', 94 | '0099', 95 | '0100', 96 | '0102', 97 | '0103', 98 | '0104', 99 | '0106', 100 | '0107', 101 | '0108', 102 | '0110', 103 | '0111', 104 | '0112', 105 | '0113', 106 | '0114', 107 | '0115', 108 | '0116', 109 | '0118', 110 | '0119', 111 | '0120', 112 | '0121', 113 | '0122', 114 | '0123', 115 | '0124', 116 | '0125', 117 | '0126', 118 | '0127', 119 | '0128', 120 | '0129', 121 | '0130', 122 | '0131', 123 | '0132', 124 | '0133', 
125 | '0134', 126 | '0135', 127 | '0136', 128 | '0137', 129 | '0138', 130 | '0139', 131 | '0140', 132 | '0141', 133 | '0142', 134 | '0143', 135 | '0144', 136 | '0145', 137 | '0146', 138 | '0147', 139 | '0148', 140 | '0149', 141 | '0150', 142 | '0151', 143 | '0152', 144 | '0153', 145 | '0154', 146 | '0155', 147 | '0156', 148 | 'Beef', 149 | 'BeetleFly', 150 | 'BirdChicken', 151 | 'Car', 152 | 'Coffee', 153 | 'FaceFour', 154 | 'Herring', 155 | 'Lightning2', 156 | 'Lightning7', 157 | 'Meat', 158 | 'OliveOil', 159 | 'Rock', 160 | 'Wine' 161 | ] 162 | 163 | if __name__ == '__main__': 164 | 165 | print(BASE_DIR) 166 | pass 167 | -------------------------------------------------------------------------------- /core/base_nets.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022/03/10' 4 | 5 | 'base networks' 6 | 7 | # built-in library 8 | import os 9 | import math 10 | import sys 11 | from copy import deepcopy 12 | 13 | # third-party library 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | # self-defined library 19 | from tools.tools import metrics as Metrics 20 | 21 | torch.set_default_tensor_type(torch.DoubleTensor) 22 | 23 | 24 | class MLP(nn.Module): 25 | 26 | def __init__(self, n_input, n_hidden, n_output): 27 | super(MLP, self).__init__() 28 | self.name = 'MLP' 29 | 30 | self.hidden_size = n_hidden 31 | # this list contains all tensor needed to be optimized 32 | self.params = nn.ParameterList() 33 | 34 | # linear input layer 35 | weight = nn.Parameter(torch.ones(n_hidden, n_input)) 36 | bias = nn.Parameter(torch.zeros(n_hidden)) 37 | 38 | self.params.extend([weight, bias]) 39 | 40 | # linear output layer 41 | weight = nn.Parameter(torch.ones(n_output, n_hidden)) 42 | bias = nn.Parameter(torch.zeros(n_output)) 43 | 44 | self.params.extend([weight, bias]) 45 | 46 | self.init() 47 | 48 | def parameters(self): 49 | return self.params 50 | 51 | def init(self): 52 | stdv = 1.0 / math.sqrt(self.hidden_size) 53 | for weight in self.parameters(): 54 | weight.data.uniform_(-stdv, stdv) 55 | 56 | def forward(self, x, vars=None): 57 | 58 | if vars is None: 59 | params = self.params 60 | else: 61 | params = vars 62 | 63 | # input layer 64 | (weight_input, bias_input) = (params[0].to(x.device), params[1].to(x.device)) 65 | x = F.linear(x, weight_input, bias_input) 66 | 67 | # output layer 68 | (weight_output, bias_output) = (params[2].to(x.device), params[3].to(x.device)) 69 | out = F.linear(x, weight_output, bias_output) 70 | 71 | return out 72 | 73 | 74 | class BaseCNN(nn.Module): 75 | 76 | def __init__(self, output=10): 77 | super(BaseCNN, self).__init__() 78 | self.name = 'BASECNN' 79 | self.output = output 80 | 81 | # this list contains all tensor needed to be optimized 82 | self.vars = nn.ParameterList() 83 | 84 | # running_mean and running var 85 | self.vars_bn = nn.ParameterList() 86 | 87 | # 填充需要训练的网络的参数 88 | 89 | # Conv1d layer 90 | # [channel_out, channel_in, kernel-size] 91 | weight = nn.Parameter(torch.ones(64, 1, 3)) 92 | 93 | nn.init.kaiming_normal_(weight) 94 | 95 | bias = nn.Parameter(torch.zeros(64)) 96 | 97 | self.vars.extend([weight, bias]) 98 | 99 | # linear layer 100 | weight = nn.Parameter(torch.ones(self.output, 64*100)) 101 | bias = nn.Parameter(torch.zeros(self.output)) 102 | 103 | self.vars.extend([weight, bias]) 104 | 105 | def forward(self, x, vars=None, bn_training=True): 106 | 107 | ''' 108 | 109 | :param x: [batch size, 1, 3, 
94] 110 | :param vars: 111 | :param bn_training: set false to not update 112 | :return: 113 | ''' 114 | 115 | if vars is None: 116 | vars = self.vars 117 | 118 | # x = x.squeeze(dim=2) 119 | # x = x.unsqueeze(dim=1) 120 | # Conv1d layer 121 | weight, bias = vars[0].to(x.device), vars[1].to(x.device) 122 | # x ==> (batch size, 1, 200) 123 | 124 | x = F.conv1d(x, weight, bias, stride=1, padding=1) # ==>(batch size, 64, 200) 125 | x = F.relu(x, inplace=True) # ==> (batch_size, 64, 200) 126 | x = F.max_pool1d(x, kernel_size=2) # ==> (batch_size, 64, 100) 127 | 128 | # linear layer 129 | x = x.view(x.size(0), -1) # flatten ==> (batch_size, 16*12) 130 | weight, bias = vars[-2].to(x.device), vars[-1].to(x.device) 131 | x = F.linear(x, weight, bias) 132 | 133 | return x 134 | 135 | def parameters(self): 136 | return self.vars 137 | 138 | def zero_grad(self): 139 | pass 140 | 141 | pass 142 | 143 | 144 | class BaseLSTM(nn.Module): 145 | 146 | def __init__(self, n_features, n_hidden, n_output, n_layer=1): 147 | super().__init__() 148 | self.name = 'BaseLSTM' 149 | 150 | # this list contains all tensor needed to be optimized 151 | self.params = nn.ParameterList() 152 | 153 | self.input_size = n_features 154 | # print(n_features) 155 | self.hidden_size = n_hidden 156 | self.output_size = n_output 157 | self.layer_size = n_layer 158 | 159 | # 输入层 160 | W_i = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.input_size)) 161 | bias_i = nn.Parameter(torch.Tensor(self.hidden_size * 4)) 162 | self.params.extend([W_i, bias_i]) 163 | 164 | # 隐含层 165 | W_h = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.hidden_size)) 166 | bias_h = nn.Parameter(torch.Tensor(self.hidden_size * 4)) 167 | self.params.extend([W_h, bias_h]) 168 | 169 | if self.layer_size > 1: 170 | for _ in range(self.layer_size - 1): 171 | 172 | # 第i层lstm 173 | # 输入层 174 | W_i = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.hidden_size)) 175 | bias_i = nn.Parameter(torch.Tensor(self.hidden_size * 4)) 176 | self.params.extend([W_i, bias_i]) 177 | # 隐含层 178 | W_h = nn.Parameter(torch.Tensor(self.hidden_size * 4, self.hidden_size)) 179 | bias_h = nn.Parameter(torch.Tensor(self.hidden_size * 4)) 180 | self.params.extend([W_h, bias_h]) 181 | 182 | 183 | # 输出层 184 | W_linear = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size)) 185 | bias_linear = nn.Parameter(torch.Tensor(self.output_size)) 186 | self.params.extend([W_linear, bias_linear]) 187 | 188 | self.init() 189 | pass 190 | 191 | def parameters(self): 192 | return self.params 193 | 194 | def init(self): 195 | stdv = 1.0 / math.sqrt(self.hidden_size) 196 | for weight in self.parameters(): 197 | weight.data.uniform_(-stdv, stdv) 198 | 199 | def forward(self, x, vars=None, init_state=None): 200 | 201 | if vars is None: 202 | params = self.params 203 | else: 204 | params = vars 205 | 206 | # assume the shape of x is (batch_size, time_size, feature_size) 207 | 208 | batch_size, time_size, _ = x.size() 209 | hidden_seq = [] 210 | # with torch.autograd.set_detect_anomaly(True): 211 | if init_state is None: 212 | h_t, c_t = ( 213 | torch.zeros(batch_size, self.hidden_size).to(x.device), 214 | torch.zeros(batch_size, self.hidden_size).to(x.device) 215 | ) 216 | else: 217 | h_t, c_t = init_state 218 | 219 | HS = self.hidden_size 220 | 221 | for t in range(time_size): 222 | x_t = x[:, t, :] 223 | W_i, bias_i = (params[0].to(x.device), params[1].to(x.device)) 224 | W_h, bias_h = (params[2].to(x.device), params[3].to(x.device)) 225 | 226 | # gates = x_t @ W_i + h_t @ W_h + bias_h 
+ bias_i 227 | gates = F.linear(x_t, W_i, bias_i) + F.linear(h_t, W_h, bias_h) 228 | 229 | i_t, f_t, g_t, o_t = ( 230 | torch.sigmoid(gates[:, :HS]), # input 231 | torch.sigmoid(gates[:, HS:HS * 2]), # forget 232 | torch.tanh(gates[:, HS * 2:HS * 3]), 233 | torch.sigmoid(gates[:, HS * 3:]) # output 234 | ) 235 | c_t = f_t * c_t + i_t * g_t 236 | h_t = o_t * torch.tanh(c_t) 237 | hidden_seq.append(h_t) 238 | 239 | W_linear, bias_linear = (params[-2].to(x.device), params[-1].to(x.device)) 240 | out = F.linear(hidden_seq[-1], W_linear, bias_linear) 241 | # out = hidden_seq[-1] @ W_linear + bias_linear 242 | return out 243 | 244 | 245 | class BaseCNNConLSTM(nn.Module): 246 | 247 | def __init__(self, n_features, n_hidden, n_output, n_layer=1, time_size=1, cnn_feature=200): 248 | super(BaseCNNConLSTM, self).__init__() 249 | self.name = 'BaseCNNConLSTM' 250 | self.time_size = time_size 251 | 252 | # this list contain all tensor needed to be optimized 253 | self.params = nn.ParameterList() 254 | self.cnn = BaseCNN(output=cnn_feature) 255 | self.lstm = BaseLSTM(n_features=n_features, n_hidden=n_hidden, n_output=n_output, n_layer=n_layer) 256 | self.cnn_tensor_num = 0 257 | self.lstm_tensor_num = 0 258 | self.init() 259 | 260 | def init(self): 261 | 262 | self.cnn_tensor_num = len(self.cnn.parameters()) 263 | self.lstm_tensor_num = len(self.lstm.parameters()) 264 | for param in self.cnn.parameters(): 265 | self.params.append(param) 266 | for param in self.lstm.parameters(): 267 | self.params.append(param) 268 | 269 | def sequence(self, data): 270 | 271 | dim_1, dim_2 = data.shape 272 | new_dim_1 = dim_1 - self.time_size + 1 273 | 274 | x = torch.zeros((new_dim_1, self.time_size, dim_2)) 275 | 276 | for i in range(dim_1 - self.time_size + 1): 277 | x[i] = data[i: i + self.time_size] 278 | return x.to(data.device) 279 | 280 | def forward(self, x, vars=None, init_states=None): 281 | 282 | if vars is None: 283 | params = self.params 284 | else: 285 | params = vars 286 | 287 | x = self.cnn(x, params[: self.cnn_tensor_num]) 288 | x = x.unsqueeze(dim=1) 289 | output = self.lstm(x, params[self.cnn_tensor_num:], init_states) 290 | return output 291 | pass 292 | pass 293 | -------------------------------------------------------------------------------- /core/meta_nets.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date___ = '2022/03/10' 4 | 5 | 'meta networks' 6 | 7 | # built-in library 8 | import os.path as osp 9 | from copy import deepcopy 10 | 11 | # third-party library 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn import functional as F 15 | 16 | # self-defined library 17 | from tools.tools import metrics as Metrics, generate_filename 18 | from configs import MODEL_PATH 19 | 20 | torch.set_default_tensor_type(torch.DoubleTensor) 21 | 22 | 23 | class MetaNet(nn.Module): 24 | 25 | def __init__(self, baseNet=None, update_step_train=10, update_step_target=20, meta_lr=0.001, base_lr=0.01, fine_lr=0.01): 26 | 27 | super(MetaNet, self).__init__() 28 | self.update_step_train = update_step_train 29 | self.update_step_target = update_step_target 30 | self.meta_lr = meta_lr 31 | self.base_lr = base_lr 32 | self.fine_tune_lr = fine_lr 33 | 34 | if baseNet is not None: 35 | self.net = baseNet 36 | else: 37 | raise Exception('baseNet is None') 38 | self.meta_optim = torch.optim.Adam(self.net.parameters(), lr=self.meta_lr) 39 | # self.meta_optim = torch.optim.SGD(self.net.parameters(), lr=self.meta_lr) 40 | 
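        # base_lr is the inner-loop step size used to adapt the baseNet on each training
        # task's support set, meta_lr is the Adam learning rate for the outer-loop update
        # of the shared parameters, and fine_tune_lr is the step size used in fine_tuning()
        # when adapting to the target task; update_step_train and update_step_target are
        # the numbers of gradient steps taken in those two stages.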
pass 41 | 42 | def save_model(self, model_name='model'): 43 | torch.save(self.net, osp.join(MODEL_PATH, generate_filename('pth',*[model_name,]))) 44 | 45 | def forward(self, spt_x, spt_y, qry_x, qry_y, device='cpu'): 46 | ''' 47 | 48 | :param spt_x: if baseNet is cnn: [ spt size, in_channel, height, width], lstm [spt_size, time_size, feature_size] 49 | :param spt_y: [ spt size] 50 | :param qry_x: if baseNet is cnn: [ qry size, in_channel, height, width], lstm [qry size, time_size, feature_size] 51 | :param qry_y: [ qry size] 52 | :param min_max_data_path: 用来进行数据反归一化的min,max值的存储路径 53 | :return: 54 | batch size 在本任务中设置为1, 即每次采样一个任务进行训练 55 | ''' 56 | 57 | # spt_size, channel, height, width = spt_x.size() 58 | # qry_size = spt_y.size(0) 59 | task_num = len(spt_x) 60 | loss_list_qry = [] 61 | mape_list = [] 62 | rmse_list = [] 63 | smape_list = [] 64 | qry_loss_sum = 0 65 | # print('更新任务网络===============================================') 66 | # 第0步更新 67 | for i in range(task_num): 68 | x_spt = torch.from_numpy(spt_x[i]).to(device) 69 | y_spt = torch.from_numpy(spt_y[i]).to(device) 70 | x_qry = torch.from_numpy(qry_x[i]).to(device) 71 | y_qry = torch.from_numpy(qry_y[i]).to(device) 72 | 73 | y_hat = self.net(x_spt, vars=None) 74 | loss = F.mse_loss(y_hat, y_spt) 75 | grad = torch.autograd.grad(loss, self.net.parameters()) 76 | grads_params = zip(grad, self.net.parameters()) # 将梯度和参数一一对应起来 77 | 78 | # fast_weights 这一步相当于求了一个 theta - alpha * nabla(L) 79 | fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], grads_params)) 80 | 81 | # 在query集上测试,计算准确率 82 | # 使用更新后的参数在query集上测试 83 | with torch.no_grad(): 84 | y_hat = self.net(x_qry, fast_weights) 85 | loss_qry = F.mse_loss(y_hat, y_qry) 86 | loss_list_qry.append(loss_qry) 87 | 88 | # 计算评价指标 89 | rmse, mape, smape = Metrics(y_qry, y_hat) 90 | 91 | rmse_list.append(rmse) 92 | mape_list.append(mape) 93 | smape_list.append(smape) 94 | 95 | for step in range(1, self.update_step_train): 96 | y_hat = self.net(x_spt, fast_weights) 97 | loss = F.mse_loss(y_hat, y_spt) 98 | grad = torch.autograd.grad(loss, fast_weights) 99 | grads_params = zip(grad, fast_weights) 100 | fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], grads_params)) 101 | 102 | if step < self.update_step -1: 103 | with torch.no_grad(): 104 | y_hat = self.net(x_qry, fast_weights) 105 | loss_qry = F.mse_loss(y_hat, y_qry) 106 | loss_list_qry.append(loss_qry) 107 | else: 108 | y_hat = self.net(x_qry, fast_weights) 109 | loss_qry = F.mse_loss(y_hat, y_qry) 110 | loss_list_qry.append(loss_qry) 111 | qry_loss_sum += loss_qry 112 | 113 | with torch.no_grad(): 114 | rmse, mape, smape = Metrics(y_qry, y_hat) 115 | 116 | rmse_list.append(rmse) 117 | mape_list.append(mape) 118 | smape_list.append(smape) 119 | pass 120 | 121 | # 更新元网络 122 | loss_qry = qry_loss_sum / task_num # 表示在经过update_step之后,learner在当前任务query set上的损失 123 | self.meta_optim.zero_grad() # 梯度清零 124 | loss_qry.backward() 125 | self.meta_optim.step() 126 | 127 | 128 | return { 129 | 'loss': loss_list_qry[-1].item(), 130 | 'rmse': rmse_list[-1], 131 | 'mape': mape_list[-1], 132 | 'smape': smape_list[-1] 133 | } 134 | 135 | def fine_tuning(self, spt_x, spt_y, qry_x, qry_y, naive=False): 136 | 137 | ''' 138 | 139 | :param spt_x: if baseNet is cnn:[set size, channel, height, width] if baseNet is lstm: [batch_size, seq_size, feature_size] 140 | :param spt_y: 141 | :param qry_x: 142 | :param qry_y: 143 | :return: 144 | ''' 145 | 146 | # 评价指标 147 | loss_qry_list = [] 148 | rmse_list = [] 149 | mape_list = [] 150 | smape_list 
= [] 151 | min_loss = 0 152 | best_epoch = 0 153 | min_train_loss = 1000000 154 | loss_set = { 155 | 'train_loss': [], 156 | 'validation_loss': [] 157 | } 158 | 159 | # new_net = deepcopy(self.net) 160 | # new_net = self.net 161 | y_hat = self.net(spt_x) 162 | # with torch.autograd.set_detect_anomaly(True): 163 | loss = F.mse_loss(y_hat, spt_y) 164 | loss_set['train_loss'].append(loss.item()) 165 | if loss.item() < min_train_loss: 166 | min_train_loss = loss.item() 167 | grad = torch.autograd.grad(loss, self.net.parameters()) 168 | grads_params = zip(grad, self.net.parameters()) 169 | fast_weights = list(map(lambda p: p[1] - self.fine_tune_lr * p[0], grads_params)) 170 | 171 | # 在query集上测试,计算评价指标 172 | # 使用更新后的参数进行测试 173 | with torch.no_grad(): 174 | y_hat = self.net(qry_x, fast_weights) 175 | loss_qry = F.mse_loss(y_hat, qry_y) 176 | loss_set['validation_loss'].append(loss_qry.item()) 177 | loss_qry_list.append(loss_qry) 178 | # 计算评价指标mape 179 | rmse, mape, smape = Metrics(qry_y, y_hat) 180 | 181 | rmse_list.append(rmse) 182 | mape_list.append(mape) 183 | smape_list.append(smape) 184 | min_rmse = rmse 185 | min_mape = mape 186 | min_smape = smape 187 | min_loss = loss_qry.item() 188 | rmse_best_epoch = 1 189 | mape_best_epoch = 1 190 | smape_best_epcoh = 1 191 | 192 | if naive: 193 | print(' Epoch [1] | train_loss: %.4f | test_loss: %.4f | rmse: %.4f | mape: %.4f | smape: %.4f |' 194 | % (loss.item(), loss_qry.item(), rmse, mape, smape)) 195 | 196 | for step in range(1, self.update_step_target): 197 | y_hat = self.net(spt_x, fast_weights) 198 | loss = F.mse_loss(y_hat, spt_y) 199 | loss_set['train_loss'].append(loss.item()) 200 | if loss.item() < min_train_loss: 201 | min_train_loss = loss.item() 202 | grad = torch.autograd.grad(loss, fast_weights) 203 | grads_params = zip(grad, fast_weights) 204 | fast_weights = list(map(lambda p: p[1] - self.fine_tune_lr * p[0], grads_params)) 205 | 206 | # 在query测试 207 | with torch.no_grad(): 208 | # 计算评价指标 209 | y_hat = self.net(qry_x, fast_weights) 210 | loss_qry = F.mse_loss(y_hat, qry_y) 211 | loss_set['validation_loss'].append(loss_qry.item()) 212 | loss_qry_list.append(loss_qry) 213 | 214 | rmse, mape, smape = Metrics(qry_y, y_hat) 215 | 216 | rmse_list.append(rmse) 217 | mape_list.append(mape) 218 | smape_list.append(smape) 219 | if min_rmse > rmse: 220 | min_rmse = rmse 221 | rmse_best_epoch = step + 1 222 | if min_smape > smape: 223 | min_smape = smape 224 | smape_best_epcoh = step + 1 225 | min_rmse = rmse 226 | self.save_model(model_name=self.net.name) 227 | print(' Epoch [%d] | train_loss: %.4f | test_loss: %.4f | rmse: %.4f | smape: %.4f |' 228 | % (step + 1, loss.item(), loss_qry.item(), rmse, smape)) 229 | 230 | return { 231 | 'test_loss': min_loss, 232 | 'train_loss': min_train_loss, 233 | 'rmse': min_rmse, 234 | 'mape': min_mape, 235 | 'smape': min_smape, 236 | 'rmse_best_epoch': rmse_best_epoch, 237 | 'mape_best_epoch': mape_best_epoch, 238 | 'smape_best_epoch': smape_best_epcoh, 239 | 'loss_set': loss_set 240 | } 241 | pass -------------------------------------------------------------------------------- /core/options.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022/03/10' 4 | 5 | 'options' 6 | 7 | # built-in library 8 | import os 9 | import os.path as osp 10 | import argparse 11 | from builtins import print as b_print 12 | 13 | # self-defined library 14 | from configs import model_save_dir, loss_save_dir, log_save_path, 
log_save_dir, MODEL_NAME, MODE_NAME, exp_result_dir 15 | from tools.tools import generate_filename 16 | 17 | 18 | def task_id_int2str(int_id): 19 | 20 | if int_id < 10: 21 | str_id = '000' + str(int_id) 22 | elif int_id < 100: 23 | str_id = '00' + str(int_id) 24 | elif int_id < 1000: 25 | str_id = '0' + str(int_id) 26 | else: 27 | str_id = str(int_id) 28 | 29 | return str_id 30 | 31 | 32 | def print(*args, file='./log.txt', end='\n', terminate=True): 33 | 34 | with open(file=file, mode='a', encoding='utf-8') as console: 35 | b_print(*args, file=console, end=end) 36 | if terminate: 37 | b_print(*args, end=end) 38 | 39 | 40 | def parse_args(script='main'): 41 | 42 | parser = argparse.ArgumentParser(description='Time Seriess Forecasting script %s.py' % script) 43 | 44 | # training arguments 45 | parser.add_argument('--model', default=None, 46 | help='the model name is used to train.[mlp, lstm, cnn, cnnConlstm, lstm+maml, cnn+maml, cnnConlstm+maml]') 47 | parser.add_argument('--epoch', type=int, default=1, help='the iteration number for training data.') 48 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate.') 49 | 50 | # data arguments 51 | parser.add_argument('--dataset', default='few_shot_data', help='the data path for training and testing.') 52 | parser.add_argument('--ratio', type=float, default=0.8, help='the ratio of training set for all data set') 53 | parser.add_argument('--trainSet', default='', help='the path of the training data.') 54 | parser.add_argument('--testSet', default='', help='the path of the testing data.') 55 | parser.add_argument('--UCR', action='store_true', default=False, help='for UCR data.') 56 | parser.add_argument('--ppn', type=int, default=10, help='predict point num.') 57 | 58 | # save-path arguments 59 | parser.add_argument('--msd', default=model_save_dir, help='the model save dir.') 60 | parser.add_argument('--lsd', default=loss_save_dir, help='the loss save dir.') 61 | parser.add_argument('--log', default=log_save_path, help='the log save path.') 62 | parser.add_argument('--maml_log', default=log_save_path, help='the log save path.') 63 | parser.add_argument('--rmse_path', default=log_save_path, help='the log save metric rmse.') 64 | parser.add_argument('--mape_path', default=log_save_path, help='the log save metric mape.') 65 | parser.add_argument('--smape_path', default=log_save_path, help='the log save metric smape.') 66 | # the arguments for LSTM model 67 | parser.add_argument('--time_size', type=int, default=1, help='the time_size for lstm input') 68 | 69 | # the arguments for cnn model 70 | parser.add_argument('--add_dim_pos', type=int, default=1, help='the position for add dimension when change sequence data to img data') 71 | 72 | # the implement mode 73 | parser.add_argument('--mode', default='together', 74 | help='the implement mode for script, [training, testing, together] can be chosen') 75 | 76 | # for testing mode 77 | parser.add_argument('--model_state', default='', help='the path of trained model') 78 | 79 | # for maml 80 | parser.add_argument('--user_id', type=str, default='none', help='the id of true target task') 81 | parser.add_argument('--update_step_train', type=int, default=5, help='the train task update step') 82 | parser.add_argument('--update_step_target', type=int, default=50, help='the target task update step') 83 | parser.add_argument('--meta_lr', type=float, default=1e-4, help='the learning rate of meta network') 84 | parser.add_argument('--base_lr', type=float, default=1e-3, help='the learning rate of 
base network') 85 | parser.add_argument('--fine_lr', type=float, default=0.03, help='the learning rate of fine tune target network') 86 | parser.add_argument('--baseNet', default='cnn', help='the base network for maml training, [lstm, cnn, cnnConlstm] can be chosen') 87 | parser.add_argument('--maml', action='store_true', default=False, help='whether using maml algorithm to train a network.') 88 | parser.add_argument('--all_data_dir', default=None, help='the directory for all load data.') 89 | parser.add_argument('--begin_task', type=int, default=1, help='the begining task id that be used to batch training when having maml') 90 | parser.add_argument('--end_task', type=int, default=12, help='the ending task id that be used to batch training when having maml.') 91 | parser.add_argument('--batch_task_num', type=int, default=5, help='batch training for maml') 92 | # for hardware setting 93 | parser.add_argument('--device', default='cuda', help='the calculate device for torch Tensor, [cpu, cuda] can be chosen') 94 | 95 | # for new settings 96 | parser.add_argument('--new_settings', action='store_true', default=False, help='training scheme in new settings.') 97 | parser.add_argument('--ft_step', type=int, default=100, help='epoch number of fine-tuning in new settings') 98 | 99 | params = parser.parse_args() 100 | 101 | # maml log 102 | if params.maml: 103 | maml_log_path = osp.join(log_save_dir, generate_filename('.txt', *['log'], timestamp=True)) 104 | params.maml_log = maml_log_path 105 | params.log = maml_log_path 106 | params_show(params) 107 | 108 | # 动态生成日志文件 109 | log_path = osp.join('./log', generate_filename('.txt', *['log'], timestamp=True)) 110 | params.log = log_path 111 | params_show(params) 112 | 113 | # 生成实验结果日志 114 | if params.new_settings: 115 | result_dir_name = params.baseNet + '_' + str(params.ppn) + 'new_settings' 116 | else: 117 | result_dir_name = params.baseNet + '_' + str(params.ppn) 118 | if params.maml: 119 | result_dir_name = 'M_' + result_dir_name 120 | 121 | result_dir = osp.join(exp_result_dir, result_dir_name) 122 | if not osp.exists(result_dir): 123 | os.mkdir(result_dir) 124 | params.rmse_path = osp.join(result_dir, generate_filename('.txt', *['rmse'], timestamp=True)) 125 | params.mape_path = osp.join(result_dir, generate_filename('.txt', *['mape'], timestamp=True)) 126 | params.smape_path = osp.join(result_dir, generate_filename('.txt', *['smape'], timestamp=True)) 127 | 128 | # 参数合法性检查 129 | 130 | if params.mode == 'together': 131 | assert float(0) < params.ratio < float(1) # 对数据集拆分比例的检查 132 | elif params.mode == 'training': 133 | assert osp.exists(params.trainSet) 134 | elif params.mode == 'testing': 135 | assert osp.exists(params.testSet) and osp.exists(params.model_state) 136 | else: 137 | raise Exception('Unknown implement mode: %s' % params.mode) 138 | 139 | if params.baseNet in MODEL_NAME: 140 | assert params.epoch > 0 and isinstance(params.epoch, int) # 对epoch的检查 141 | # if params.model[:4] == 'lstm' and params.model[-4:] == 'lstm': # if model is 'lstm', the time_size id needed parameters 142 | # assert params.time_size > 0 and isinstance(params.time_size, int) 143 | else: 144 | raise Exception('Unknown model name: %s' % params.baseNet) 145 | 146 | return params 147 | 148 | 149 | def params_show(params): 150 | 151 | if params: 152 | print('Parameters Show', file=params.log) 153 | print('=======================================', file=params.log) 154 | print('About model:', file=params.log) 155 | print(' model: %s' % params.baseNet, file=params.log) 156 
| # print(' epoch: %s' % params.epoch, file=params.log) 157 | 158 | # print(' learning rate: %s' % str(params.lr), file=params.log) 159 | # if params.model == 'lstm': 160 | # print(' time size: %d' % params.time_size, file=params.log) 161 | # if params.mode == 'testing': 162 | # print(' trained model path: %s' % params.model_state, file=params.log) 163 | print('About data:', file=params.log) 164 | print(' data file: %s' % params.dataset, file=params.log) 165 | # print(' training data file: %s' % params.trainSet, file=params.log) 166 | # print(' testing data file: %s' % params.testSet, file=params.log) 167 | # print(' data split rate: %s' % params.ratio, file=params.log) 168 | print(' predict point num: %d' % params.ppn, file=params.log) 169 | # print('implement mode: %s' % params.mode, file=params.log) 170 | 171 | if params.maml: 172 | print('=======================================', file=params.log) 173 | print('MAML Show', file=params.log) 174 | print('=======================================', file=params.log) 175 | if params.user_id != '0': 176 | #target_task = task_id_int2str(params.user_id) 177 | target_task = params.user_id 178 | print(' target task: %s' % target_task, file=params.log) 179 | else: 180 | begin_task = task_id_int2str(params.begin_task) 181 | end_task = task_id_int2str(params.end_task) 182 | print(' begin task: %s' % begin_task, file=params.log) 183 | print(' end task: %s' % end_task, file=params.log) 184 | print(' update step train: %d' % params.update_step, file=params.log) 185 | print(' update step target: %d' % params.update_step_test, file=params.log) 186 | print(' meta lr: %.4f' % params.meta_lr, file=params.log) 187 | print(' base lr: %.4f' % params.base_lr, file=params.log) 188 | print(' fine lr: %.4f' % params.fine_lr, file=params.log) 189 | print(' device: %s' % params.device, file=params.log) 190 | print('=======================================', file=params.log) 191 | else: 192 | raise Exception('params is None!', file=params.log) 193 | pass 194 | 195 | 196 | 197 | 198 | 199 | if __name__ == '__main__': 200 | pass 201 | 202 | 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /core/task_split.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022/03/10' 4 | 5 | 'task split' 6 | 7 | # built-in library 8 | import os.path as osp 9 | 10 | # third-party library 11 | from torch import from_numpy 12 | import numpy as np 13 | 14 | # self-defined tools 15 | from configs import DATA_DIR as data_dir 16 | from tools.tools import obj_unserialization 17 | from configs import TRAINING_TASK_SET 18 | 19 | 20 | class LoadData: 21 | 22 | def __init__(self, maml, test_user_index, add_dim_pos, data_path=None, ppn=None): 23 | 24 | ''' 25 | :param maml: 是否使用 maml algorithm to train 26 | :param test_user_index: 选择一个用户测试,其他用户用来训练meta-network的参数 27 | :param add_dim_pos: 给tensor增加维度的位置, 当add_dim_pos=-1时,不加维度 28 | :param data_path: 原始数据路径 29 | ppn: predict point num 30 | ''' 31 | 32 | if data_path: 33 | self.indexes = {'train': int(0), 'test': int(0)} 34 | self.task_id = {'train': [], 'test': []} 35 | self.features = None 36 | self.task_num = None 37 | self.datasets_cache = None 38 | self.output = 94 39 | self.outputs = [] 40 | if maml: 41 | self.datasets_cache = self.UCR_data_cache_maml(data_path, test_user_index, add_dim_pos, ppn) 42 | else: 43 | self.UCR_data_cache(data_path, add_dim_pos, ppn) 44 | else: 45 | raise Exception('data 
dir is None!') 46 | pass 47 | 48 | def UCR_data_cache(self, data_path, add_dim_pos, ppn): 49 | 50 | train_data_path = osp.join(data_path, 'train_data_embedding_%s.pkl' % str(ppn)) 51 | test_data_path = osp.join(data_path, 'test_data_embedding_%s.pkl' % str(ppn)) 52 | self.datasets_cache = {'train': [], 'test': []} 53 | 54 | train_data = obj_unserialization(train_data_path) 55 | test_data = obj_unserialization(test_data_path) 56 | for key in train_data.keys(): 57 | train_x, train_y, _ = self.split_x_y(train_data[key], add_dim_pos, ppn=ppn) 58 | test_x, test_y, _ = self.split_x_y(test_data[key], add_dim_pos, ppn=ppn) 59 | self.datasets_cache['train'].append([train_x, train_y]) 60 | self.datasets_cache['test'].append([test_x, test_y]) 61 | self.task_id['train'].append(key) 62 | self.task_id['test'].append(key) 63 | 64 | self.task_num = len(self.datasets_cache['train']) 65 | 66 | pass 67 | 68 | def UCR_data_cache_maml(self, data_path, test_task_index, add_dim_pos, ppn): 69 | 70 | datasets_cache = {'train': [], 'test': []} 71 | train_data_path = osp.join(data_path, 'train_data_embedding_%s.pkl' % str(ppn)) 72 | test_data_path = osp.join(data_path, 'test_data_embedding_%s.pkl' % str(ppn)) 73 | 74 | train_data = obj_unserialization(train_data_path) 75 | test_data = obj_unserialization(test_data_path) 76 | 77 | # print(test_task_index) 78 | for i, data_name in enumerate(TRAINING_TASK_SET): 79 | if test_task_index == data_name: 80 | test_task_index = i + 1 81 | # print(test_task_index) 82 | for i, key in enumerate(train_data.keys()): 83 | if i == test_task_index - 1: 84 | test_spt_x, test_spt_y, _ = self.split_x_y(train_data[key], add_dim_pos, ppn=ppn) 85 | test_qry_x, test_qry_y, _ = self.split_x_y(test_data[key], add_dim_pos, ppn=ppn) 86 | datasets_cache['test'].append([test_spt_x, test_spt_y, test_qry_x, test_qry_y]) 87 | self.task_id['test'].append(key) 88 | continue 89 | 90 | train_spt_x, train_spt_y, _ = self.split_x_y(train_data[key], add_dim_pos, ppn=ppn) 91 | train_qry_x, train_qry_y, _ = self.split_x_y(test_data[key], add_dim_pos, ppn=ppn) 92 | datasets_cache['train'].append([train_spt_x, train_spt_y, train_qry_x, train_qry_y]) 93 | self.task_id['train'].append(key) 94 | 95 | self.task_num = len(datasets_cache['train']) 96 | 97 | return datasets_cache 98 | 99 | def get_data(self, task_id=None): 100 | 101 | if task_id in self.task_id['train']: 102 | 103 | # task_id = task_id_int2str(task_id) 104 | pos_train = self.task_id['train'].index(task_id) 105 | pos_test = self.task_id['test'].index(task_id) 106 | return self.datasets_cache['train'][pos_train], self.datasets_cache['test'][pos_test], task_id 107 | else: 108 | raise Exception('Unknown the task id [%s]!' 
% task_id) 109 | pass 110 | 111 | def split_spt_qry(self, data, rate): 112 | 113 | spt = [] 114 | qry = [] 115 | for task in data: 116 | pos = int(len(task) * rate) 117 | spt.append(task[:pos]) 118 | qry.append(task[pos:]) 119 | # print('split train & val:') 120 | # print('train: %d, %d' % (len(spt[0]), len(spt[0][0]))) 121 | # print('val: %d, %d' % (len(qry[0]), len(qry[0][0]))) 122 | return spt, qry 123 | 124 | def split_x_y(self, data, add_dim_pos, ppn=10): 125 | 126 | forecast_point_num = ppn 127 | position = len(data[0]) - forecast_point_num 128 | xs = np.array(data)[:, :position] 129 | ys = np.array(data)[:, position:] 130 | # ==================================================== # 131 | # update time: 2021-12-10 132 | if add_dim_pos == -1: 133 | np_xs = np.array(xs) 134 | else: 135 | np_xs = from_numpy(xs).unsqueeze(dim=1).numpy() 136 | self.features = len(xs[0]) 137 | # ==================================================== # 138 | self.output = forecast_point_num 139 | self.outputs.append(forecast_point_num) 140 | # print('split x & y') 141 | # print(xs.shape) 142 | # print(ys.shape) 143 | 144 | return np_xs, np.array(ys), self.features 145 | 146 | def next(self, mode='train'): 147 | if self.indexes[mode] == len(self.datasets_cache[mode]): 148 | self.indexes[mode] = 0 149 | 150 | next_batch = self.datasets_cache[mode][self.indexes[mode]] 151 | task_id = self.task_id[mode][self.indexes[mode]] 152 | self.indexes[mode] += 1 153 | return next_batch, task_id 154 | 155 | 156 | if __name__ == '__main__': 157 | 158 | pass -------------------------------------------------------------------------------- /core/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022/03/10' 4 | 5 | 'train base network + maml' 6 | 7 | # built-in library 8 | import os 9 | import os.path as osp 10 | import time 11 | import copy 12 | from datetime import datetime 13 | 14 | # third-party library 15 | import torch 16 | import numpy as np 17 | 18 | # self-defined library 19 | from core.base_nets import MLP, BaseLSTM, BaseCNN, BaseCNNConLSTM 20 | from core.meta_nets import MetaNet 21 | from core.task_split import LoadData 22 | from configs import TRAINING_TASK_SET 23 | from tools.tools import generate_filename, obj_serialization 24 | from core.options import print as write_log 25 | 26 | torch.set_default_tensor_type(torch.DoubleTensor) 27 | 28 | 29 | 30 | def train(epoch_num, test_user_index, add_dim_pos, data_path, update_step_train, 31 | update_step_target, meta_lr, base_lr, fine_lr, device, baseNet, 32 | maml, log, maml_log, lsd, ppn, rmse_path, mape_path, 33 | smape_path, ft_step, batch_task_num=5, new_settings=False): 34 | 35 | # 设置随机数种子,保证运行结果可复现 36 | torch.manual_seed(1) 37 | np.random.seed(1) 38 | if device == 'cuda': 39 | torch.cuda.manual_seed_all(1) 40 | 41 | device = torch.device(device) # 选择 torch.Tensor 进行运算时的设备对象['cpu', 'cuda'] 42 | if baseNet == 'mlp': 43 | #print('no add dimension') 44 | add_dim_pos = -1 # 不给tensor添加维度 45 | 46 | # get data 47 | data = LoadData(maml, test_user_index, add_dim_pos, data_path=data_path, ppn=ppn) 48 | 49 | if baseNet == 'cnn': 50 | BaseNet = BaseCNN(output=data.output) 51 | elif baseNet == 'lstm': 52 | BaseNet = BaseLSTM(n_features=data.features, n_hidden=100, n_output=data.output) 53 | elif baseNet == 'cnnConlstm': 54 | BaseNet = BaseCNNConLSTM(n_features=data.features, n_hidden=100, n_output=data.output, cnn_feature=200) 55 | # ========================================= # 
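    # data.features and data.output are filled in by LoadData.split_x_y: the input size is
    # the embedding length minus ppn, and the output size equals ppn (points to predict).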
56 | # update time: 2021-12-10 57 | elif baseNet == 'mlp': 58 | # print(data.features, data.output) 59 | BaseNet = MLP(n_input=data.features, n_hidden=100, n_output=data.output) 60 | # ======================================== # 61 | else: 62 | raise Exception('Unknown baseNet: %s' % baseNet) 63 | 64 | 65 | metaNet = MetaNet( 66 | baseNet=BaseNet, 67 | update_step_train=update_step_train, 68 | update_step_target=update_step_target, 69 | meta_lr=meta_lr, 70 | base_lr=base_lr, 71 | fine_lr=fine_lr 72 | ).to(device) 73 | 74 | training_result = { 75 | 'target task': None, 76 | 'qry_loss': None, 77 | 'rmse': None, 78 | 'mape': None, 79 | 'smape': None, 80 | 'rmse_best_epoch': None, 81 | 'mape_best_epoch': None, 82 | 'smape_best_epoch': None, 83 | 'training time:': None, 84 | 'date': None, 85 | 'log': log, 86 | 'maml_log': None 87 | } 88 | 89 | # training 90 | start = time.time() 91 | step = 0 92 | batch_num = 0 93 | train_loss = {} 94 | test_loss = {} 95 | # print(data.task_num) 96 | 97 | if maml: 98 | while step < epoch_num: 99 | 100 | (spt_x, spt_y), (qry_x, qry_y), task_id = batch_task(data, batch_task_num=batch_task_num) 101 | 102 | batch_num += 1 103 | 104 | print('[%d]===================== training 元网络========================= :[%s]' % (step, task_id)) 105 | metrics = metaNet(spt_x, spt_y, qry_x, qry_y,device=device) 106 | print('| train_task: %s | qry_loss: %.4f | qry_rmse: %.4f | qry_mape: %.4f | qry_smape: %.4f |' 107 | % (task_id, metrics['loss'], metrics['rmse'], metrics['mape'], metrics['smape'])) 108 | 109 | 110 | if batch_num % (data.task_num // batch_task_num) == 0: 111 | step += 1 112 | (spt_x, spt_y, qry_x, qry_y), task_id = data.next('test') 113 | 114 | spt_x, spt_y, qry_x, qry_y = torch.from_numpy(spt_x).to(device), \ 115 | torch.from_numpy(spt_y).to(device), \ 116 | torch.from_numpy(qry_x).to(device), \ 117 | torch.from_numpy(qry_y).to(device) 118 | print('===================== fine tuning 目标网络 ========================= :[%s]' % task_id) 119 | metrics = metaNet.fine_tuning(spt_x, spt_y, qry_x, qry_y) 120 | 121 | if training_result['qry_loss'] is None: 122 | training_result['target task'] = task_id 123 | training_result['qry_loss'] = metrics['test_loss'] 124 | training_result['mape'] = metrics['mape'] 125 | training_result['smape'] = metrics['smape'] 126 | training_result['rmse'] = metrics['rmse'] 127 | training_result['rmse_best_epoch'] = metrics['rmse_best_epoch'] 128 | training_result['mape_best_epoch'] = metrics['mape_best_epoch'] 129 | training_result['smape_best_epoch'] = metrics['smape_best_epoch'] 130 | else: 131 | if training_result['rmse'] > metrics['rmse']: 132 | training_result['qry_loss'] = metrics['test_loss'] 133 | training_result['rmse'] = metrics['rmse'] 134 | training_result['rmse_best_epoch'] = metrics['rmse_best_epoch'] 135 | if training_result['mape'] > metrics['mape']: 136 | training_result['mape'] = metrics['mape'] 137 | training_result['mape_best_epoch'] = metrics['mape_best_epoch'] 138 | if training_result['smape'] > metrics['smape']: 139 | training_result['smape'] = metrics['smape'] 140 | training_result['smape_best_epoch'] = metrics['smape_best_epoch'] 141 | write_log( 142 | 'Epoch [%d] | ' 143 | 'target_task_id: %s | ' 144 | 'qry_loss: %.4f | ' 145 | 'rmse: %.4f(%d) | ' 146 | 'smape: %.4f(%d) |' 147 | % ( 148 | step, 149 | task_id, 150 | metrics['test_loss'], 151 | metrics['rmse'], metrics['rmse_best_epoch'], 152 | metrics['smape'], metrics['smape_best_epoch'] 153 | ), 154 | file=log 155 | ) 156 | else: 157 | 158 | # updated date: 2021-10-20 159 | 
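        # Default setting: the base network is trained and evaluated on the target task only.
        # With --new_settings, all remaining tasks are first pooled into one training set
        # (new_settings_get_training_data) and the model is then fine-tuned on the target
        # task for ft_step steps (new_settings_get_finetune_data).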
if not new_settings: 160 | (train_x, train_y), (test_x, test_y), task_id = data.get_data(test_user_index) 161 | else: 162 | (train_x, train_y), (test_x, test_y), task_id = new_settings_get_training_data(data,test_user_index) 163 | # ============================================= # 164 | 165 | 166 | spt_x, spt_y, qry_x, qry_y = torch.from_numpy(train_x).to(device), \ 167 | torch.from_numpy(train_y).to(device), \ 168 | torch.from_numpy(test_x).to(device), \ 169 | torch.from_numpy(test_y).to(device) 170 | 171 | print('===================== training %s ========================= :[%s]' % (baseNet, task_id)) 172 | metrics = metaNet.fine_tuning(spt_x, spt_y, qry_x, qry_y, naive=True) 173 | # ============================================= # 174 | # updated date: 2021-10-20 175 | if new_settings: 176 | print('=====================new settings finetuning=====================') 177 | (train_x, train_y), (test_x, test_y), task_id = new_settings_get_finetune_data(data,test_user_index) 178 | spt_x, spt_y, qry_x, qry_y = torch.from_numpy(train_x).to(device), \ 179 | torch.from_numpy(train_y).to(device), \ 180 | torch.from_numpy(test_x).to(device), \ 181 | torch.from_numpy(test_y).to(device) 182 | metaNet.update_step_target = ft_step 183 | metrics = metaNet.fine_tuning(spt_x, spt_y, qry_x, qry_y, naive=True) 184 | # ============================================= # 185 | 186 | save_loss(lsd, metrics['loss_set'], *[task_id, baseNet, 'loss_set']) 187 | train_loss.setdefault(task_id, metrics['train_loss']) 188 | test_loss.setdefault(task_id, metrics['test_loss']) 189 | write_log( 190 | 'target_task_id: %s | ' 191 | 'spt_loss: %.4f |' 192 | 'qry_loss: %.4f | ' 193 | 'rmse: %.4f(%d)| ' 194 | 'smape: %.4f(%d)|' 195 | % ( 196 | task_id, 197 | metrics['train_loss'], 198 | metrics['test_loss'], 199 | metrics['rmse'], metrics['rmse_best_epoch'], 200 | metrics['smape'], metrics['smape_best_epoch'] 201 | ), 202 | file=log 203 | ) 204 | # save metrics rmse, mape, smape 205 | write_log('%.4f (%d)' % (metrics['rmse'], metrics['rmse_best_epoch']), file=rmse_path, terminate=False) 206 | write_log('%.4f (%d)' % (metrics['mape'], metrics['mape_best_epoch']), file=mape_path, terminate=False) 207 | write_log('%.4f (%d)' % (metrics['smape'], metrics['smape_best_epoch']), file=smape_path, terminate=False) 208 | pass 209 | end = time.time() 210 | # save loss 211 | save_loss(lsd, train_loss, *[baseNet, 'train', 'loss']) 212 | save_loss(lsd, test_loss, *[baseNet, 'test', 'loss']) 213 | training_result['training time'] = '%s Min' % str((end - start) / 60) 214 | training_result['date'] = datetime.strftime(datetime.now(), '%Y/%m/%d %H:%M:%S') 215 | if maml: 216 | training_result['maml_log'] = maml_log 217 | train_result(training_result, file=maml_log) 218 | # save metrics rmse, mape, smape 219 | write_log('%.4f (%d)' % (training_result['rmse'], training_result['rmse_best_epoch']), file=rmse_path, terminate=False) 220 | write_log('%.4f (%d)' % (training_result['mape'], training_result['mape_best_epoch']), file=mape_path, terminate=False) 221 | write_log('%.4f (%d)' % (training_result['smape'], training_result['smape_best_epoch']), file=smape_path, terminate=False) 222 | else: 223 | # train_result(training_result, file=log) 224 | pass 225 | 226 | return metrics['smape'], metrics['rmse'] 227 | 228 | 229 | def save_loss(lsd, obj, *others): 230 | 231 | if not osp.exists(lsd): 232 | os.mkdir(lsd) 233 | loss_path = osp.join(lsd, generate_filename('.pkl', *others, timestamp=False)) 234 | obj_serialization(loss_path, obj) 235 | print('loss 
serialization is finished!') 236 | 237 | 238 | def train_result(data_dict, file='./log.txt'): 239 | write_log('training result:============================', file=file) 240 | for key, value in data_dict.items(): 241 | write_log(' %s: %s' % (key, value), file=file) 242 | write_log('============================================', file=file) 243 | 244 | 245 | def batch_task(data, batch_task_num=1, ablation=1): 246 | 247 | # abalation == 1: means that uses all tasks as training task set 248 | # abalation == 0: menas that only uses UCR tasks as training task set 249 | 250 | (spt_x, spt_y, qry_x, qry_y), task_id = data.next('train') 251 | train_x = list([spt_x]) 252 | train_y = list([spt_y]) 253 | test_x = list([qry_x]) 254 | test_y = list([qry_y]) 255 | 256 | if ablation == 1: 257 | while batch_task_num > 1: 258 | (x1, y1, x2, y2), temp = data.next('train') 259 | train_x.append(x1) 260 | train_y.append(y1) 261 | test_x.append(x2) 262 | test_y.append(y2) 263 | task_id += ('-' + temp) 264 | batch_task_num -= 1 265 | elif ablation == 0: 266 | while batch_task_num > 1: 267 | (x1, y1, x2, y2), temp = data.next('train') 268 | if temp.isdigit(): 269 | continue 270 | train_x.append(x1) 271 | train_y.append(y1) 272 | test_x.append(x2) 273 | test_y.append(y2) 274 | task_id += ('-' + temp) 275 | batch_task_num -= 1 276 | else: 277 | raise Exception('UnKnown abalaion code: [%d]' % ablation) 278 | 279 | return (train_x, train_y), (test_x, test_y), task_id 280 | # ==================================================================================== # 281 | # updated date: 2021-10-20 282 | # in light of reviewer's suggestions, add a group of experiments settings 283 | 284 | def new_settings_get_training_data(data, target_task): 285 | 286 | # data: UCR 287 | # 将除target_task以外的其他的数据集打包到一起进行训练 288 | 289 | dataset_list = copy.deepcopy(TRAINING_TASK_SET) 290 | dataset_list.remove(target_task) 291 | 292 | (train_x, train_y), (test_x, test_y), _ = data.get_data(dataset_list[0]) 293 | #print(dataset_list[0]) 294 | #print(train_x.shape, train_y.shape, test_x.shape, test_y.shape) 295 | for dataset in dataset_list[1:]: 296 | # print(dataset) 297 | (temp_1, temp_2), (temp_3, temp_4), _ = data.get_data(dataset) 298 | # print(temp_1.shape, temp_2.shape, temp_3.shape, temp_4.shape) 299 | train_x = np.concatenate((train_x, temp_1), axis=0) 300 | train_y = np.concatenate((train_y, temp_2), axis=0) 301 | test_x = np.concatenate((test_x, temp_3), axis=0) 302 | test_y = np.concatenate((test_y, temp_4), axis=0) 303 | # print(train_x.shape, train_y.shape, test_x.shape, test_y.shape) 304 | return (train_x, train_y), (test_x, test_y), target_task 305 | 306 | 307 | def new_settings_get_finetune_data(data, target_task): 308 | 309 | # 用target task 进行微调 310 | 311 | return data.get_data(target_task) 312 | 313 | # ================================================================================== # 314 | 315 | 316 | 317 | if __name__ == '__main__': 318 | 319 | pass 320 | 321 | -------------------------------------------------------------------------------- /data/few_shot_data/test_data_embedding_10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_10.pkl -------------------------------------------------------------------------------- /data/few_shot_data/test_data_embedding_20.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_20.pkl -------------------------------------------------------------------------------- /data/few_shot_data/test_data_embedding_30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_30.pkl -------------------------------------------------------------------------------- /data/few_shot_data/test_data_embedding_40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/test_data_embedding_40.pkl -------------------------------------------------------------------------------- /data/few_shot_data/train_data_embedding_10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_10.pkl -------------------------------------------------------------------------------- /data/few_shot_data/train_data_embedding_20.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_20.pkl -------------------------------------------------------------------------------- /data/few_shot_data/train_data_embedding_30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_30.pkl -------------------------------------------------------------------------------- /data/few_shot_data/train_data_embedding_40.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaofeng-github/Meta-Learning4FSTSF/724ed0b3093b1416d176fcb033b04125de988685/data/few_shot_data/train_data_embedding_40.pkl -------------------------------------------------------------------------------- /embedding/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022-07-11' 4 | 5 | ''' 6 | the scripts for time series data processing. 
7 | ''' 8 | 9 | import os 10 | import os.path as osp 11 | import torch 12 | import numpy as np 13 | from collections import OrderedDict 14 | from configs import DATA_DIR as DATADIR 15 | from configs import few_shot_dataset_name 16 | from tools.tools import obj_serialization, read_tsv, obj_unserialization, generate_filename 17 | from sklearn import preprocessing 18 | 19 | 20 | def few_shot_data(path=None): 21 | 22 | if path is None: 23 | raise Exception('The parameter "path" is None!') 24 | dataset_file_names = os.listdir(path) 25 | train_dataset = OrderedDict() 26 | test_dataset = OrderedDict() 27 | process_traindata_num = 0 28 | process_testdata_num = 0 29 | for dir in dataset_file_names: 30 | dataset_dir = osp.join(path, dir) 31 | if os.path.isdir(dataset_dir): 32 | train_file_path = osp.join(dataset_dir, '%s_TRAIN.tsv' % dir) 33 | test_file_path = osp.join(dataset_dir, '%s_TEST.tsv' % dir) 34 | if os.path.isfile(train_file_path): 35 | train_dataset.setdefault(dir, read_tsv(train_file_path).loc[:, 1:].values.astype(np.float64)) 36 | process_traindata_num += 1 37 | else: 38 | print('"%s" is not a file!' % train_file_path) 39 | if os.path.isfile(test_file_path): 40 | test_dataset.setdefault(dir, read_tsv(test_file_path).loc[:, 1:].values.astype(np.float64)) 41 | process_testdata_num += 1 42 | else: 43 | print('"%s" is not a file!' % test_file_path) 44 | 45 | obj_serialization(osp.join(DATADIR, 'train_data.pkl'), train_dataset) 46 | obj_serialization(osp.join(DATADIR, 'test_data.pkl'), test_dataset) 47 | print('train_process_num: %d' % process_traindata_num) 48 | print('test_process_num: %d' % process_testdata_num) 49 | 50 | 51 | def split_data(data_path=None, ratio=0.1, shuffle=False, data=None): 52 | 53 | if data is None: 54 | data = obj_unserialization(data_path) 55 | if int(ratio) >= len(data): 56 | return [], [] 57 | if 0 < ratio < 1: 58 | train_data_size = int(len(data) * ratio) 59 | elif 1 <= ratio < len(data): 60 | train_data_size = int(ratio) 61 | else: 62 | raise Exception('Invalid value about "ratio" --> [%s]' % str(ratio)) 63 | 64 | if train_data_size == 0: 65 | val_data = data 66 | train_data = [] 67 | else: 68 | train_data = data[:train_data_size] 69 | val_data = data[train_data_size:] 70 | return train_data, val_data 71 | 72 | 73 | def create_sequence(data, ratio=0.1): 74 | 75 | forecast_point_num = int(len(data[0]) * ratio) 76 | position = len(data[0]) - forecast_point_num 77 | xs = np.array(data)[:, :position] 78 | ys = np.array(data)[:, position:] 79 | 80 | return torch.from_numpy(xs).float().unsqueeze(dim=2), torch.from_numpy(ys).float(), position, forecast_point_num 81 | 82 | 83 | def construct_dataset(size=100): 84 | 85 | file_list = os.listdir(DATADIR) 86 | 87 | counter = 0 88 | data_dict = {} 89 | file_num = len(file_list) 90 | for file in file_list: 91 | data_path = osp.join(DATADIR, file) 92 | if os.path.isdir(osp.join(DATADIR, file)): 93 | file_num -= 1 94 | continue 95 | data_dict.setdefault(file.split('.')[0], obj_unserialization(data_path)) 96 | counter += 1 97 | if counter % size == 0: 98 | save_path = osp.join(DATADIR, 99 | 'dataset\\%s' % generate_filename('.pkl', *['UCR', str(counter - size + 1), str(counter)]) 100 | ) 101 | obj_serialization(save_path, data_dict) 102 | print(save_path) 103 | data_dict.clear() 104 | elif counter == file_num: 105 | save_path = osp.join(DATADIR, 106 | 'dataset\\%s' % generate_filename('.pkl', *['UCR', str(size * (counter // size) + 1), str(counter)]) 107 | ) 108 | obj_serialization(save_path, data_dict) 109 | print(save_path) 
110 | data_dict.clear() 111 | 112 | 113 | def normalizer(data=None): 114 | 115 | # z-score normalization 116 | 117 | if data is not None: 118 | return preprocessing.scale(data, axis=1) 119 | else: 120 | raise Exception('data is None!') 121 | pass 122 | 123 | 124 | def get_basic_data(): 125 | 126 | load_data_path = osp.join(DATADIR, 'few_shot_data\\few_shot_load_data.pkl') 127 | UCR_train_data_path = osp.join(DATADIR, 'train_data.pkl') 128 | UCR_test_data_path = osp.join(DATADIR, 'test_data.pkl') 129 | 130 | load_data = obj_unserialization(load_data_path) 131 | UCR_train_data = obj_unserialization(UCR_train_data_path) 132 | UCR_test_data = obj_unserialization(UCR_test_data_path) 133 | 134 | few_shot_train_data = OrderedDict() 135 | few_shot_test_data = OrderedDict() 136 | for key, value in load_data.items(): 137 | # if key in DIRTY_DATA_ID: 138 | # continue 139 | train_data, test_data = split_data(data=value, ratio=0.5) 140 | few_shot_train_data.setdefault(key, np.array(train_data)) 141 | few_shot_test_data.setdefault(key, np.array(test_data)) 142 | print('valid load dataset: %d' % len(few_shot_train_data)) 143 | 144 | for key, value in UCR_train_data.items(): 145 | if key in few_shot_dataset_name: 146 | few_shot_train_data.setdefault(key, value) 147 | few_shot_test_data.setdefault(key, UCR_test_data[key]) 148 | 149 | print('all the few shot dataset: %d' % len(few_shot_train_data)) 150 | 151 | obj_serialization(osp.join(DATADIR, 'few_shot_data\\train_data.pkl'), few_shot_train_data) 152 | obj_serialization(osp.join(DATADIR, 'few_shot_data\\test_data.pkl'), few_shot_test_data) 153 | 154 | 155 | if __name__ == '__main__': 156 | 157 | pass -------------------------------------------------------------------------------- /embedding/embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022-07-11' 4 | ''' 5 | This script is set to finish time series data embedding for uniting the length of data. 
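Each z-score normalized series keeps its last forecasting_point_num points as forecast targets; the remaining points are fed to a bidirectional GRU, and the final forward and backward hidden states (2 * n_hidden values) are concatenated with those targets, so every record becomes a fixed-length vector of 2 * n_hidden + forecasting_point_num values.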
6 | ''' 7 | 8 | # built-in library 9 | import os 10 | from collections import OrderedDict 11 | 12 | # third-party library 13 | import torch 14 | import numpy as np 15 | import torch.nn as nn 16 | 17 | # self-defined tools 18 | from data_preprocessing import normalizer 19 | from tools.tools import obj_serialization, obj_unserialization 20 | from configs import DATA_DIR as DATADIR 21 | 22 | class EmbeddingBiGRU(nn.Module): 23 | 24 | def __init__(self, n_input, n_hidden, batch_size=100, bidirectional=True, forecasting_point_num=10): 25 | super().__init__() 26 | self.n_input = n_input 27 | self.n_hidden = n_hidden 28 | self.batch_size = batch_size 29 | self.bidirectional = bidirectional 30 | self.forecasting_point_num = forecasting_point_num 31 | 32 | self.bigru = nn.GRU( 33 | input_size=self.n_input, 34 | hidden_size=self.n_hidden, 35 | bidirectional=self.bidirectional, 36 | num_layers=1, 37 | 38 | ) 39 | 40 | def batch_train(self, x): 41 | 42 | # x: (record_size, seq_length, dim_feature); split along the record axis into chunks of at most batch_size 43 | record_size = x.shape[0] 44 | batch_data = [] 45 | if record_size <= self.batch_size: 46 | batch_data.append(x) 47 | else: 48 | for pos in range(0, record_size, self.batch_size): 49 | batch_data.append(x[pos:pos + self.batch_size, :]) 50 | # the slice above already yields the final partial batch, 51 | # so the tail must not be appended a second time here 52 | return batch_data 53 | 54 | def forward(self, x): 55 | 56 | # x: (batch size, seq_length, dim_feature) --> (seq_length, batch size, input size) 57 | batch_data = self.batch_train(x[:, :x.shape[1] - self.forecasting_point_num, :]) 58 | forecasting_data = x[:, x.shape[1] - self.forecasting_point_num:, :].squeeze(dim=2).numpy() 59 | embedding = [] 60 | for batch in batch_data: 61 | batch = batch.permute(1, 0, 2).contiguous()  # transpose to (seq_length, batch size, input size) as nn.GRU expects 62 | gru_out, h_n = self.bigru(batch) 63 | 64 | forward_embedding = h_n[0, :, :].detach().numpy() 65 | backward_embedding = h_n[1, :, :].detach().numpy() 66 | embedding.append(np.concatenate((forward_embedding, backward_embedding), axis=1)) 67 | embedding = np.concatenate(embedding, axis=0) 68 | embedding = np.concatenate((embedding, forecasting_data), axis=1) 69 | return embedding.astype(np.float64) 70 | 71 | pass 72 | 73 | 74 | if __name__ == '__main__': 75 | 76 | train_data_path = os.path.join(DATADIR, 'few_shot_data\\train_data.pkl') 77 | test_data_path = os.path.join(DATADIR, 'few_shot_data\\test_data.pkl') 78 | forecasting_point_num = 40 79 | model = EmbeddingBiGRU(n_input=1, n_hidden=100, forecasting_point_num=forecasting_point_num) 80 | 81 | # train data embedding …… 82 | print('train data embedding .......') 83 | train_data = obj_unserialization(train_data_path) 84 | train_data_embedding = OrderedDict() 85 | 86 | for key, value in train_data.items(): 87 | input_data = torch.from_numpy(normalizer(value)).float().unsqueeze(dim=2) 88 | embedding_data = model(input_data) 89 | train_data_embedding.setdefault(key, embedding_data) 90 | print('train data dimension: %d' % len(train_data_embedding['0001'][0])) 91 | obj_serialization(os.path.join(DATADIR, 'few_shot_data\\train_data_embedding_%s.pkl' % str(forecasting_point_num)), train_data_embedding) 92 | 93 | # test data embedding …… 94 | print('test data embedding ......') 95 | test_data = obj_unserialization(test_data_path) 96 | test_data_embedding = OrderedDict() 97 | 98 | for key, value in test_data.items(): 99 | input_data = torch.from_numpy(normalizer(value)).float().unsqueeze(dim=2) 100 | embedding_data = model(input_data) 101 | test_data_embedding.setdefault(key, 
embedding_data) 102 | print('test data dimension: %d' % len(test_data_embedding['0001'][0])) 103 | obj_serialization(os.path.join(DATADIR, 'few_shot_data\\test_data_embedding_%s.pkl' % str(forecasting_point_num)), test_data_embedding) 104 | 105 | print('OK!') 106 | pass 107 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022/03/10' 4 | 5 | 'begin from here' 6 | 7 | 8 | # built-in library 9 | import os.path as osp 10 | import time 11 | 12 | # third-party library 13 | 14 | # self-defined tools 15 | from core.options import parse_args 16 | from core.train import train 17 | from configs import TRAINING_TASK_SET, DATA_DIR 18 | 19 | if __name__ == '__main__': 20 | 21 | # python train_maml.py --mode together --model cnn+maml --epoch 100 --dataSet maml_load_data(14).pkl --ratio 0.9 --time_size 3 --update_step_train 10 --update_step_target 20 --meta_lr 0.001 --base_lr 0.01 22 | 23 | print('--------------------------------------------Time Series Forecasting------------------------------------------') 24 | params = parse_args('main') 25 | start = time.time() 26 | if params.user_id == 'none': 27 | # conducting all tasks 28 | for user_id in TRAINING_TASK_SET: 29 | train( 30 | epoch_num=params.epoch, 31 | test_user_index=user_id, 32 | add_dim_pos=params.add_dim_pos, 33 | data_path=osp.join(DATA_DIR, params.dataset), 34 | update_step_train=params.update_step_train, 35 | update_step_target=params.update_step_target, 36 | meta_lr=params.meta_lr, 37 | base_lr=params.base_lr, 38 | fine_lr=params.fine_lr, 39 | device=params.device, 40 | baseNet=params.baseNet, 41 | maml=params.maml, 42 | log=params.log, 43 | maml_log=params.maml_log, 44 | lsd=params.lsd, 45 | ppn=params.ppn, 46 | batch_task_num=params.batch_task_num, 47 | rmse_path=params.rmse_path, 48 | mape_path=params.mape_path, 49 | smape_path=params.smape_path, 50 | ft_step=params.ft_step, 51 | new_settings=params.new_settings 52 | ) 53 | end = time.time() 54 | print('using time: %.4f Hour' % ((end - start) / 3600.0)) 55 | print('training is over!') 56 | elif params.user_id in TRAINING_TASK_SET: 57 | # conducting single task 58 | smape, rmse = train( 59 | epoch_num=params.epoch, 60 | test_user_index=params.user_id, 61 | add_dim_pos=params.add_dim_pos, 62 | data_path=osp.join(DATA_DIR, params.dataset), 63 | update_step_train=params.update_step_train, 64 | update_step_target=params.update_step_target, 65 | meta_lr=params.meta_lr, 66 | base_lr=params.base_lr, 67 | fine_lr=params.fine_lr, 68 | device=params.device, 69 | baseNet=params.baseNet, 70 | maml=params.maml, 71 | log=params.log, 72 | maml_log=params.maml_log, 73 | lsd=params.lsd, 74 | ppn=params.ppn, 75 | rmse_path=params.rmse_path, 76 | mape_path=params.mape_path, 77 | smape_path=params.smape_path, 78 | ft_step=params.ft_step, 79 | new_settings=params.new_settings 80 | ) 81 | end = time.time() 82 | print('using time: %.4f Min' % ((end - start) / 60.0)) 83 | print('training is over!') 84 | print('smape: %.4f' % smape) 85 | print('rmse: %.4f' % rmse) 86 | else: 87 | raise Exception('Unknown user id!') 88 | 89 | 90 | -------------------------------------------------------------------------------- /tools/tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | __author__ = 'XF' 3 | __date__ = '2022-07-11' 4 | ''' 5 | The script is set for supplying 
some tool functions. 6 | ''' 7 | 8 | import os 9 | import time 10 | import pickle 11 | import pandas as pds 12 | 13 | 14 | 15 | def read_tsv(path=None, header=None): 16 | 17 | if path is None: 18 | raise ValueError('The path is None!') 19 | 20 | content = pds.read_csv(path, sep='\t', header=header) 21 | return content 22 | 23 | 24 | # object serialization 25 | def obj_serialization(path, obj): 26 | 27 | if obj is not None: 28 | with open(path, 'wb') as file: 29 | pickle.dump(obj, file) 30 | else: 31 | print('object is None!') 32 | 33 | 34 | # object deserialization 35 | def obj_unserialization(path): 36 | 37 | if os.path.exists(path): 38 | with open(path, 'rb') as file: 39 | obj = pickle.load(file) 40 | return obj 41 | else: 42 | raise OSError('no such path: %s' % path) 43 | 44 | 45 | def generate_filename(suffix, *args, sep='_', timestamp=False): 46 | 47 | ''' 48 | 49 | :param suffix: suffix of the file, with or without the leading dot 50 | :param sep: separator, default '_' 51 | :param timestamp: append a timestamp for uniqueness 52 | :param args: name components to be joined by sep 53 | :return: the generated file name 54 | ''' 55 | 56 | filename = sep.join(args).replace(' ', '_') 57 | if timestamp: 58 | filename += time.strftime('_%Y%m%d%H%M%S') 59 | if suffix[0] == '.': 60 | filename += suffix 61 | else: 62 | filename += ('.' + suffix) 63 | 64 | return filename 65 | 66 | 67 | def metrics(y, y_hat): 68 | 69 | assert y.shape == y_hat.shape # Tensor y and Tensor y_hat must have the same shape 70 | y = y.cpu() 71 | y_hat = y_hat.cpu() 72 | # mape: mean absolute percentage error (%) 73 | _mape = mape(y, y_hat) 74 | 75 | # smape: symmetric mean absolute percentage error (%) 76 | _smape = smape(y, y_hat) 77 | 78 | # rmse: root mean squared error 79 | _rmse = rmse(y, y_hat) 80 | 81 | return _rmse, _mape, _smape 82 | 83 | 84 | def mape(Y, Y_hat): 85 | 86 | temp = [abs((y - y_hat) / y) for y, y_hat in zip(Y.view(-1).numpy(), Y_hat.view(-1).numpy())] 87 | return (sum(temp) / len(temp)) * 100 88 | 89 | 90 | def smape(Y, Y_hat): 91 | 92 | temp = [abs(y - y_hat) / (abs(y) + abs(y_hat)) for y, y_hat in zip(Y.view(-1).numpy(), Y_hat.view(-1).numpy())] 93 | return (sum(temp) / len(temp)) * 200 94 | 95 | 96 | def rmse(Y, Y_hat): 97 | 98 | temp = [pow(y - y_hat, 2) for y, y_hat in zip(Y.view(-1).numpy(), Y_hat.view(-1).numpy())] 99 | return pow(sum(temp) / len(temp), 0.5) 100 | 101 | 102 | if __name__ == '__main__': 103 | pass --------------------------------------------------------------------------------
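Each record that embedding.py writes into the train_data_embedding_*.pkl / test_data_embedding_*.pkl files is the concatenation of a BiGRU embedding (2 * n_hidden values, i.e. 200 with the n_hidden=100 used in its __main__ block) and the last forecasting_point_num raw points of the series, which act as the forecast targets. A minimal inspection sketch, assuming it is run from the repository root and using a forward-slash relative path in place of the DATADIR-based paths in the scripts (an assumption about where the pickles live):

'''
# Sketch only: loads one embedded pickle and splits each record back into
# the 200-dimensional BiGRU input part and the forecast targets.
import numpy as np
from tools.tools import obj_unserialization

ppn = 40  # forecasting_point_num, matching the value used in embedding.py's __main__
tasks = obj_unserialization('data/few_shot_data/train_data_embedding_%d.pkl' % ppn)

for task_id, records in tasks.items():
    records = np.asarray(records)                      # shape: (num_records, 200 + ppn)
    inputs, targets = records[:, :-ppn], records[:, -ppn:]
    print(task_id, inputs.shape, targets.shape)
    break
'''

With n_hidden=100 this makes each record 240-dimensional for forecasting_point_num=40; other values of forecasting_point_num change only the width of the target part, and the same split applies to both the train and the test pickles.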
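The metric helpers at the bottom of tools/tools.py flatten both tensors and average element-wise errors, so y and y_hat only need matching shapes. A toy illustration of what metrics() returns (the numbers are made up for the example; note that mape divides by the ground-truth values, so zeros in y would need special handling):

'''
# Toy example of the metrics helpers; values are illustrative only.
import torch
from tools.tools import metrics

y = torch.tensor([[10.0, 20.0, 30.0]])
y_hat = torch.tensor([[12.0, 18.0, 33.0]])

rmse_val, mape_val, smape_val = metrics(y, y_hat)
# rmse  = sqrt((2**2 + 2**2 + 3**2) / 3)        ~ 2.38
# mape  = mean(2/10, 2/20, 3/30) * 100          ~ 13.33
# smape = mean(2/22, 2/38, 3/63) * 200          ~ 12.74
print('rmse=%.2f, mape=%.2f, smape=%.2f' % (rmse_val, mape_val, smape_val))
'''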