├── README.md
├── generate_data.py
├── main.py
├── model.py
├── prepareData.py
├── requirements
└── util.py

/README.md:
--------------------------------------------------------------------------------
1 | # AGC-Net: Adaptive Graph Convolution Networks for Traffic Flow Forecasting
2 | 
3 | 
4 | 
5 | AGC-Net (Adaptive Graph Convolution Networks) is a deep learning model for traffic flow forecasting. The paper is available [here](https://arxiv.org/abs/2307.05517).
6 | 
7 | 
8 | ## Citation
9 | 
10 | If you find this work useful for your research, please cite our paper:
11 | 
12 | 
13 | ```bibtex
14 | 
15 | @article{li2023adaptive,
16 | title={Adaptive Graph Convolution Networks for Traffic Flow Forecasting},
17 | author={Zhengdao Li and Wei Li and Kai Hwang},
18 | year={2023},
19 | eprint={2307.05517},
20 | archivePrefix={arXiv},
21 | primaryClass={cs.LG}
22 | }
23 | 
24 | ```
25 | 
26 | 
27 | ## Installation
28 | 
29 | Before training the model, make sure all required packages are installed. To install the requirements, run the following command (the Aliyun mirror flags are optional; a plain `pip install -r requirements` also works):
30 | 
31 | 
32 | 
33 | ```bash
34 | 
35 | pip install -r requirements -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
36 | 
37 | ```
38 | 
39 | 
40 | ## Data Preparation
41 | 
42 | 
43 | 
44 | Place the METR-LA dataset inside the `./data/` directory, in a sub-directory named `METR-LA-12`. If you don't have the dataset, you can download it from the [DCRNN repository](https://github.com/liyaguang/DCRNN); `generate_data.py` then builds the `train`/`val`/`test` splits (see the example command after the Training section). The `feature_len` parameter denotes the number of input features per node; here we use a `feature_len` of 3.
45 | 
46 | 
47 | ## Training
48 | 
49 | 
50 | 
51 | Once the data is prepared, you can train the AGC-Net model. The following is a sample command to start training:
52 | 
53 | 
54 | 
55 | ```bash
56 | 
57 | python main.py --predict_len=12 --cuda --att --data_path=./data/METR-LA-12 --feature_len=3 --wavelets_num=20 --transpose --epochs=1 --best_model_save_path=best_model_12_30w
58 | 
59 | ```
60 | 
61 | 
62 | 
63 | Here is a brief explanation of the command-line arguments:
64 | 
65 | 
66 | 
67 | * `--predict_len`: The number of future time steps to be predicted by the model (12 in this case).
68 | 
69 | * `--cuda`: If present, use the GPU for training.
70 | 
71 | * `--att`: If present, use the attention mechanism in the model.
72 | 
73 | * `--data_path`: The path to the directory where the dataset is stored.
74 | 
75 | * `--feature_len`: The feature length of the dataset.
76 | 
77 | * `--wavelets_num`: The number of wavelet functions to be used (20 in this case).
78 | 
79 | * `--transpose`: If present, transpose the input data.
80 | 
81 | * `--epochs`: The number of epochs to train the model.
82 | 
83 | * `--best_model_save_path`: The path where the model with the best validation performance is saved.
84 | 
85 | 
86 | 
87 | The best checkpoint, `best_model_12_30w`, will be written to that path once training completes.
88 | 
89 | 
90 | 
91 | Feel free to explore and adapt the model to your own requirements; contributions and feedback are welcome.
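
## Generating the Dataset

`generate_data.py` turns a raw `.h5` traffic file into the `train`/`val`/`test` `.npz` splits that `main.py` loads. Its defaults point at PEMS-BAY, so the METR-LA paths below are only an illustrative sketch; the exact flags used to produce the `METR-LA-12` folder are not spelled out in this repository:

```bash

python generate_data.py --traffic_df_filename=data/metr-la.h5 --output_dir=data/METR-LA-12 --seq_length_x=12 --seq_length_y=12

```

## Loading a Trained Checkpoint

Training saves the best model's `state_dict` to `--best_model_save_path`. The following is a minimal inference sketch, not part of the repository: it assumes you pass the same model flags (`--att`, `--wavelets_num`, `--feature_len`, ...) that were used for training, and that the DCRNN adjacency pickle sits at the default `--adj_file` location:

```python

import torch
from model import STWN
from util import get_common_args, load_adj

args = get_common_args().parse_args()
_, _, adj = load_adj(args.adj_file, 'normlap')  # same adjacency type as rnn_train()
model = STWN(adj[0], args, is_gpu=args.cuda)    # rebuild the trained architecture
model.load_state_dict(torch.load(args.best_model_save_path, map_location='cpu'))
model.eval()                                    # ready for evaluation / inference

```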
92 | 93 | 94 | 95 | 96 | 97 | 98 | ## License 99 | 100 | [MIT](https://choosealicense.com/licenses/mit/) 101 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import argparse 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | from prepareData import get_sample_indices 11 | 12 | 13 | def generate_graph_seq2seq_io_data( 14 | df, x_offsets, y_offsets, add_time_in_day=True, add_day_in_week=False, scaler=None 15 | ): 16 | """ 17 | Generate samples from 18 | :param df: 19 | :param x_offsets: 20 | :param y_offsets: 21 | :param add_time_in_day: 22 | :param add_day_in_week: 23 | :param scaler: 24 | :return: 25 | # x: (epoch_size, input_length, num_nodes, input_dim) 26 | # y: (epoch_size, output_length, num_nodes, output_dim) 27 | """ 28 | 29 | num_samples, num_nodes = df.shape 30 | data = np.expand_dims(df.values, axis=-1) 31 | feature_list = [data] 32 | if add_time_in_day: 33 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D") 34 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0)) 35 | feature_list.append(time_in_day) 36 | if add_day_in_week: 37 | dow = df.index.dayofweek 38 | dow_tiled = np.tile(dow, [1, num_nodes, 1]).transpose((2, 1, 0)) 39 | feature_list.append(dow_tiled) 40 | 41 | data = np.concatenate(feature_list, axis=-1) 42 | x, y = [], [] 43 | min_t = abs(min(x_offsets)) 44 | max_t = abs(num_samples - abs(max(y_offsets))) # Exclusive 45 | for t in range(min_t, max_t): # t is the index of the last observation. 46 | x.append(data[t + x_offsets, ...]) 47 | y.append(data[t + y_offsets, ...]) 48 | x = np.stack(x, axis=0) 49 | y = np.stack(y, axis=0) 50 | return x, y 51 | 52 | 53 | def generate_train_val_test(args): 54 | seq_len = str(args.seq_length_x) 55 | seq_length_x, seq_length_y = args.seq_length_x, args.seq_length_y 56 | df = pd.read_hdf(args.traffic_df_filename) 57 | # 0 is the latest observed sample. 58 | x_offsets = np.sort(np.concatenate((np.arange(-(seq_length_x - 1), 1, 1),))) 59 | # Predict the next one hour 60 | y_offsets = np.sort(np.arange(args.y_start, (seq_length_y + 1), 1)) 61 | # x: (num_samples, input_length, num_nodes, input_dim) 62 | # y: (num_samples, output_length, num_nodes, output_dim) 63 | x, y = generate_graph_seq2seq_io_data( 64 | df, 65 | x_offsets=x_offsets, 66 | y_offsets=y_offsets, 67 | add_time_in_day=True, 68 | add_day_in_week=True, 69 | ) 70 | 71 | print("x shape: ", x.shape, ", y shape: ", y.shape) 72 | # Write the data into npz file. 
73 | num_samples = x.shape[0] 74 | num_train = round(num_samples * 0.7) 75 | num_test = round(num_samples * 0.2) 76 | num_val = num_samples - num_test - num_train 77 | x_train, y_train = x[:num_train], y[:num_train] 78 | # x_test, y_test = ( 79 | # x[num_train: num_train + num_test], 80 | # y[num_train: num_train + num_test], 81 | # ) 82 | # x_val, y_val = x[-num_val:], y[-num_val:] 83 | x_val, y_val = ( 84 | x[num_train: num_train + num_val], 85 | y[num_train: num_train + num_val], 86 | ) 87 | x_test, y_test = x[-num_test:], y[-num_test:] 88 | 89 | for cat in ["train", "val", "test"]: 90 | _x, _y = locals()["x_" + cat], locals()["y_" + cat] 91 | print(cat, "x: ", _x.shape, "y:", _y.shape) 92 | np.savez_compressed( 93 | os.path.join(args.output_dir, f"{cat}.npz"), 94 | x=_x, 95 | y=_y, 96 | x_offsets=x_offsets.reshape(list(x_offsets.shape) + [1]), 97 | y_offsets=y_offsets.reshape(list(y_offsets.shape) + [1]), 98 | ) 99 | 100 | 101 | def read_and_generate_dataset(args, add_time_in_day=False, add_day_in_week=False): 102 | df = pd.read_hdf(args.traffic_df_filename) 103 | num_of_weeks = args.num_of_weeks 104 | num_of_days = args.num_of_days 105 | num_of_hours = args.num_of_hours 106 | num_for_predict = args.seq_length_x 107 | num_predict = args.seq_length_y 108 | points_per_hour = 12 109 | 110 | num_samples, num_nodes = df.shape 111 | data = np.expand_dims(df.values, axis=-1) 112 | 113 | feature_list = [data] 114 | if add_time_in_day: 115 | time_ind = (df.index.values - df.index.values.astype("datetime64[D]")) / np.timedelta64(1, "D") 116 | time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0)) 117 | feature_list.append(time_in_day) 118 | if add_day_in_week: 119 | dow = df.index.dayofweek 120 | dow_tiled = np.tile(dow, [1, num_nodes, 1]).transpose((2, 1, 0)) 121 | feature_list.append(dow_tiled) 122 | 123 | data = np.concatenate(feature_list, axis=-1) 124 | print(num_samples, num_nodes, data.shape) 125 | all_samples = [] 126 | for idx in range(data.shape[0]): 127 | sample = get_sample_indices(data, num_of_weeks, num_of_days, 128 | num_of_hours, idx, num_for_predict, 129 | num_predict, points_per_hour) 130 | if ((sample[0] is None) and (sample[1] is None) and (sample[2] is None)): 131 | continue 132 | 133 | week_sample, day_sample, hour_sample, target = sample 134 | 135 | sample = [] # [(week_sample),(day_sample),(hour_sample),target] 136 | 137 | if num_of_weeks > 0: 138 | week_sample = np.expand_dims(week_sample, axis=0).transpose((0, 2, 3, 1)) # (1,N,F,T) 139 | sample.append(week_sample) 140 | 141 | if num_of_days > 0: 142 | day_sample = np.expand_dims(day_sample, axis=0).transpose((0, 2, 3, 1)) # (1,N,F,T) 143 | sample.append(day_sample) 144 | 145 | if num_of_hours > 0: 146 | hour_sample = np.expand_dims(hour_sample, axis=0).transpose((0, 2, 3, 1)) # (1,N,F,T) 147 | sample.append(hour_sample) 148 | 149 | target = np.expand_dims(target, axis=0).transpose((0, 2, 3, 1))[:, :, 0, :] # (1,N,T) 150 | sample.append(target) 151 | 152 | all_samples.append( 153 | sample) # sampe:[(week_sample),(day_sample),(hour_sample),target] = [(1,N,F,Tw),(1,N,F,Td),(1,N,F,Th),(1,N,Tpre)] 154 | 155 | # print(all_samples[0]) 156 | print(all_samples[0][0].shape, all_samples[0][1].shape) 157 | split_line1 = int(len(all_samples) * 0.7) 158 | split_line2 = int(len(all_samples) * 0.8) 159 | 160 | training_set = [np.concatenate(i, axis=0) 161 | for i in zip(*all_samples[:split_line1])] # [(B,N,F,Tw),(B,N,F,Td),(B,N,F,Th),(B,N,Tpre)] 162 | validation_set = [np.concatenate(i, axis=0) 163 | for i in 
zip(*all_samples[split_line1: split_line2])]
164 |     testing_set = [np.concatenate(i, axis=0)
165 |                    for i in zip(*all_samples[split_line2:])]
166 | 
167 |     train_x = np.concatenate(training_set[:-1], axis=-1).transpose((0, 3, 1, 2))  # (B,N,F,T') -> (B,T,N,F)
168 |     val_x = np.concatenate(validation_set[:-1], axis=-1).transpose((0, 3, 1, 2))
169 |     test_x = np.concatenate(testing_set[:-1], axis=-1).transpose((0, 3, 1, 2))
170 | 
171 |     train_target = training_set[-1].transpose((0, 2, 1))  # (B,N,T) -> (B,T,N)
172 |     val_target = validation_set[-1].transpose((0, 2, 1))
173 |     test_target = testing_set[-1].transpose((0, 2, 1))
174 | 
175 |     for cat in ["train", "val", "test"]:
176 |         _x, _y = locals()[cat + "_x"], locals()[cat + "_target"]
177 |         print(cat, "x: ", _x.shape, "y:", _y.shape)
178 |         filename = os.path.join(args.output_dir, cat)
179 |         np.savez_compressed(
180 |             filename,
181 |             x=_x,
182 |             y=_y
183 |         )
184 | 
185 | 
186 | if __name__ == "__main__":
187 |     parser = argparse.ArgumentParser()
188 |     parser.add_argument("--output_dir", type=str, default="data/PEMS-BAY-6", help="Output directory.")
189 |     # parser.add_argument("--traffic_df_filename", type=str, default="data/metr-la.h5", help="Raw traffic readings.", )
190 |     parser.add_argument("--traffic_df_filename", type=str, default="data/pems-bay.h5", help="Raw traffic readings.",)
191 |     parser.add_argument("--seq_length_x", type=int, default=12, help="Input sequence length.", )
192 |     parser.add_argument("--seq_length_y", type=int, default=6, help="Output sequence length.", )
193 |     parser.add_argument("--y_start", type=int, default=1, help="Y pred start", )
194 |     parser.add_argument("--num_of_weeks", type=int, default=1)
195 |     parser.add_argument("--num_of_days", type=int, default=1)
196 |     parser.add_argument("--num_of_hours", type=int, default=1)
197 | 
198 |     args = parser.parse_args()
199 |     if os.path.exists(args.output_dir):
200 |         reply = str(input(f'{args.output_dir} exists. Do you want to overwrite it? (y/n)')).lower().strip()
201 |         if not reply.startswith('y'): exit()  # the bare `exit` here was a no-op, so answering "n" never aborted
202 |     else:
203 |         os.makedirs(args.output_dir)
204 |     # generate_train_val_test(args)
205 |     read_and_generate_dataset(args, False, False)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | 
3 | from util import *
4 | from model import *
5 | import matplotlib.pyplot as plt
6 | import time
7 | import os
8 | from fastprogress import progress_bar
9 | import util  # imported explicitly: util.get_common_args() is called in __main__
10 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
11 | 
12 | args = Args()  # module-level defaults; rebound to the parsed CLI args in __main__
13 | 
14 | 
15 | def load_dataset(dataset_dir, batch_size):
16 |     datasets = {}
17 | 
18 |     for category in ['train', 'val', 'test']:
19 |         cat_data = np.load(os.path.join(dataset_dir, category + '.npz'))
20 |         if args.dev:
21 |             datasets['x_' + category] = cat_data['x'][:args.dev_size]
22 |         else:
23 |             datasets['x_' + category] = cat_data['x']
24 |         print(category + ' x size: ', datasets['x_' + category].shape)
25 |         if args.dev:
26 |             datasets['y_' + category] = cat_data['y'][:args.dev_size, ...]
27 | else: 28 | datasets['y_' + category] = cat_data['y'] 29 | print(category + ' y size: ', datasets['y_' + category].shape) 30 | 31 | # normalization of first feature: speed 32 | scaler = StandardScaler(mean=datasets['x_train'][..., 0].mean(), std=datasets['x_train'][..., 0].std()) 33 | 34 | # construct dataloader 35 | for category in ['train', 'val', 'test']: 36 | datasets['x_' + category][..., 0] = scaler.transform(datasets['x_' + category][..., 0]) 37 | # construct data 38 | datasets[category + '_loader'] = TrafficDataLoader(datasets['x_' + category], datasets['y_' + category], 39 | batch_size, args.cuda, transpose=args.transpose) 40 | print('finish load dataset!') 41 | return datasets, scaler 42 | 43 | 44 | def rnn_train(args, datasets, scaler): 45 | _, _, adj_mx = load_adj(args.adj_file, 'normlap') 46 | # adj_mx[adj_mx > 0.01] = 1 47 | # adj_mx[adj_mx<=0.01] = 0 48 | print(adj_mx) 49 | 50 | if args.rnn: 51 | model = RNNModel(args) 52 | print('load rnnmodel') 53 | else: 54 | model = STWN(adj_mx[0], args, is_gpu=args.cuda) 55 | print('load STWN') 56 | if args.pretrain: 57 | print('pretrainok') 58 | model.load_state_dict(torch.load(args.pre_model_path)) 59 | 60 | print('args_cuda:', args.cuda) 61 | if args.cuda: 62 | print('rnn_train RNNBlock to cuda!') 63 | model.cuda() 64 | else: 65 | print('rnn_train RNNBlock to cpu!') 66 | 67 | # optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 68 | # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mo, weight_decay=args.weight_decay) 69 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 70 | # optimizer = optim.Adam(net.parameters(), lr=learning_rate) 71 | trainer = Trainer(args, model, optimizer, scaler) 72 | 73 | best_model = dict() 74 | best_val_mae = 1000 75 | best_unchanged_threshold = 500 # accumulated epochs of best val_mae unchanged 76 | best_count = 0 77 | best_index = -1 78 | train_val_metrics = [] 79 | start_time = time.time() 80 | for e in range(args.epochs): 81 | print('Starting epoch: ', e) 82 | datasets['train_loader'].shuffle() 83 | train_loss, train_mae, train_mape, train_rmse = [], [], [], [] 84 | for i, (input_data, target) in enumerate(datasets['train_loader'].get_iterator()): 85 | if args.cuda: 86 | input_data = input_data.cuda() 87 | target = target.cuda() 88 | 89 | # yspeed = target[:, 0, :, :] 90 | input_data, target = Variable(input_data), Variable(target) 91 | # mse, mae, mape, rmse = trainer.train(input_data, target) 92 | loss, mae, mape, rmse = trainer.train(input_data, target) 93 | # training metrics 94 | train_loss.append(loss) 95 | train_mae.append(mae) 96 | train_mape.append(mape) 97 | train_rmse.append(rmse) 98 | 99 | # validation metrics 100 | # TODO: pick best model with best validation evaluation. 
101 |         datasets['val_loader'].shuffle()
102 |         val_loss, val_mae, val_mape, val_rmse = [], [], [], []
103 |         for _, (input_data, target) in enumerate(datasets['val_loader'].get_iterator()):
104 |             if args.cuda:
105 |                 input_data = input_data.cuda()
106 |                 target = target.cuda()
107 |             input_data, target = Variable(input_data), Variable(target)
108 |             mae, mape, rmse = trainer.eval(input_data, target)
109 | 
110 |             # add metrics
111 |             val_mae.append(mae)
112 |             val_mape.append(mape)
113 |             val_rmse.append(rmse)
114 |             val_loss.append(mae)
115 |         m = dict(train_loss=np.mean(train_loss), train_mae=np.mean(train_mae),
116 |                  train_rmse=np.mean(train_rmse), train_mape=np.mean(train_mape),
117 |                  valid_loss=np.mean(val_loss), valid_mae=np.mean(val_mae),
118 |                  valid_mape=np.mean(val_mape), valid_rmse=np.mean(val_rmse)
119 |                  )
120 | 
121 |         m = pd.Series(m)
122 |         print(m)
123 |         train_val_metrics.append(m)
124 |         # early stopping: break once the best validation MAE has not improved for best_unchanged_threshold epochs.
125 |         if m['valid_mae'] < best_val_mae:
126 |             best_val_mae = m['valid_mae']
127 |             best_count = 0
128 |             best_model = trainer.model.state_dict()
129 |             best_index = e
130 |         else:
131 |             best_count += 1
132 |             if best_count > best_unchanged_threshold:
133 |                 print('Got best')
134 |                 break
135 |         # trainer.scheduler.step()
136 |     # test metrics
137 |     torch.save(best_model, args.best_model_save_path)
138 |     trainer.model.load_state_dict(torch.load(args.best_model_save_path))
139 |     print('best_epoch:', best_index)
140 | 
141 |     test_metrics = []
142 |     test_mae, test_mape, test_rmse = [], [], []
143 |     for i, (input_data, target) in enumerate(datasets['test_loader'].get_iterator()):
144 |         input_data, target = Variable(input_data), Variable(target)
145 |         if target.max() == 0: continue
146 |         mae, mape, rmse = trainer.eval(input_data, target)
147 |         # add metrics
148 |         test_mae.append(mae)
149 |         test_mape.append(mape)
150 |         test_rmse.append(rmse)
151 |     m = dict(test_mape=np.mean(test_mape), test_rmse=np.mean(test_rmse),
152 |              test_mae=np.mean(test_mae))
153 |     m = pd.Series(m)
154 |     print("test:")
155 |     print(m)
156 | 
157 |     test_metrics.append(m)
158 |     plot(train_val_metrics, test_metrics, args.fig_filename)
159 |     print('finished rnn_train! time cost:', time.time() - start_time)
160 |     # save the learned wavelet matrices:
161 | 
162 |     for i in range(len(model.gwblocks)):
163 |         torch.save(model.gwblocks[i].wavelets, f"{i}_wavelets_maps.pt")
164 | 
165 | 
166 | 
167 | 
168 | def plot(train_val_metrics, test_metrics, fig_filename='mae'):
169 |     epochs = len(train_val_metrics)
170 |     x = range(epochs)
171 |     train_mae = [m['train_mae'] for m in train_val_metrics]
172 |     val_mae = [m['valid_mae'] for m in train_val_metrics]
173 | 
174 |     plt.figure(figsize=(8, 6))
175 |     plt.plot(x, train_mae, label='train_mae')
176 |     plt.plot(x, val_mae, label='val_mae')
177 |     plt.title('MAE')
178 |     plt.legend(loc='upper right')  # set the legend position
179 |     plt.xlabel('epoch')
180 |     plt.ylabel('mae')
181 |     plt.grid()
182 | 
183 |     plt.savefig(fig_filename)
184 | 
185 | 
186 |     # plt.show()
187 | 
188 | 
189 | if __name__ == "__main__":
190 |     print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
191 |     parser = util.get_common_args()
192 |     parser.add_argument('--gwnnArgs', help='gwnnArgs')
193 |     parser.add_argument('--adj_mx', help='adj_mx')
194 |     args = parser.parse_args()
195 | 
196 |     print(args)
197 |     datasets, scaler = load_dataset(args.data_path, args.batch_size)
198 |     # t1 = time.time()
199 | 
200 |     rnn_train(args, datasets, scaler)
201 |     # wavenet_train(args, datasets, scaler)
202 | 
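203 | # --- Usage sketch (an addition, not part of the original script) -------------
204 | # main.py is normally driven from the CLI (see the README), but the same entry
205 | # points can be exercised by appending code like the following here, e.g. for a
206 | # dev-sized smoke test where `--dev` truncates every split to `--dev_size`:
207 | #
208 | #   parser = util.get_common_args()
209 | #   args = parser.parse_args(['--data_path=./data/METR-LA-12', '--feature_len=3',
210 | #                             '--transpose', '--dev', '--epochs=1'])
211 | #   datasets, scaler = load_dataset(args.data_path, args.batch_size)
212 | #   rnn_train(args, datasets, scaler)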
-------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import numpy as np 4 | import os 5 | 6 | import pandas as pd 7 | import scipy.sparse as sp 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from scipy.sparse import linalg 12 | from torch.autograd import Variable 13 | from torch.utils.data.dataset import Dataset 14 | import util 15 | from pygsp import graphs, filters 16 | import numpy as np 17 | import os 18 | import torch.optim as optim 19 | 20 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 21 | 22 | 23 | def createGWCBlock(adj_mx): 24 | blocks = [STWNBlock(adj_mx, 2, 4, 4), STWNBlock(adj_mx, 4, 8, 4), STWNBlock(adj_mx, 8, 1, 4)] 25 | return blocks 26 | 27 | 28 | def create_adj_kernel(N, size): 29 | Adj_kernel = nn.ParameterList( 30 | [nn.Parameter(torch.FloatTensor(N, N)) for _ in range(size)]) 31 | return Adj_kernel 32 | 33 | 34 | def create_mlp_kernel(in_channel, out_channel, kernel_size): 35 | mlp_kernel = nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1)) 36 | return mlp_kernel 37 | 38 | 39 | def exp_wavelet_kernels(Lamda, scale): 40 | kernels = [np.exp(-Lamda[i] * scale) for i in range(len(Lamda))] 41 | # kernels = [np.exp(-Lamda[i] * 1.0 / i) for i in range(len(Lamda))] 42 | # print('scale:',scale, kernels) 43 | return kernels 44 | 45 | 46 | def create_wnn_kernel_by_pygsp(adj_mx): 47 | G = graphs.Graph(adj_mx) 48 | print('{} nodes, {} edges'.format(G.N, G.Ne)) 49 | print(G) 50 | G.compute_laplacian('normalized') 51 | G.compute_fourier_basis() 52 | print('G.U', G.U) 53 | # print('G.U*G.UT', G.U ) 54 | G.set_coordinates('ring2D') 55 | G.plot() 56 | 57 | 58 | # TODO: this could be replaced by Chebyshev Polynomials 59 | def create_wnn_kernel_matrix(norm_laplacian, scale): 60 | U, Lamda, _ = torch.svd(torch.from_numpy(norm_laplacian)) 61 | kernels = exp_wavelet_kernels(Lamda, scale) 62 | # print(Lamda) 63 | G = torch.from_numpy(np.diag(kernels)) 64 | Phi = np.matmul(np.matmul(U, G), U.t()) 65 | Phi_inv = torch.inverse(Phi) 66 | # print('create_wnn_kernel_matrix: Phi:', Phi) 67 | return Phi, Phi_inv 68 | 69 | 70 | class STWN(nn.Module): 71 | def __init__(self, adj_mx, args, is_gpu=False): 72 | super(STWN, self).__init__() 73 | self.is_gpu = is_gpu 74 | self.adj_mx = adj_mx 75 | self.N = adj_mx.shape[0] 76 | self.upsampling = nn.Conv2d(args.feature_len, 32, kernel_size=(1, 1)) 77 | self.predict_len = args.predict_len 78 | self.args = args 79 | 80 | # add rnn encoder before gcn: 81 | # self.encoder = DecoderRNN(args.feature_len, 32, args.predict_len, num_layers=args.rnn_layer_num) 82 | self.encoder = EncoderRNN(args.feature_len, 32, args.predict_len,out_channel=args.predict_len, num_layers=args.rnn_layer_num) 83 | if args.att: 84 | print('using att') 85 | self.gwblocks = nn.ModuleList([AttSTWNBlock(adj_mx, args.feature_len, 32, args.wavelets_num, is_gpu=is_gpu) 86 | ]) 87 | if args.gcn_layer_num > 1: 88 | for i in range(1, args.gcn_layer_num): 89 | self.gwblocks.append(AttSTWNBlock(adj_mx, 32, 32, args.wavelets_num, is_gpu=is_gpu)) 90 | else: 91 | print('no att', args.att) 92 | self.gwblocks = nn.ModuleList( 93 | [STWNBlock(adj_mx, 32, 32, args.wavelets_num, is_gpu=is_gpu) 94 | ]) 95 | if args.gcn_layer_num > 1: 96 | for i in range(1, args.gcn_layer_num): 97 | self.gwblocks.append(STWNBlock(adj_mx, 32, 32, args.wavelets_num, is_gpu=is_gpu)) 98 | print("gcn_layer_num: ", args.gcn_layer_num) 
99 |         # self.readout = STWNBlock(adj_mx, 4, 1, 4)
100 |         # residual + input feature
101 |         self.W_W = nn.Parameter(torch.FloatTensor(self.N, self.predict_len))
102 |         self.D_W = nn.Parameter(torch.FloatTensor(self.N, self.predict_len))
103 |         self.H_W = nn.Parameter(torch.FloatTensor(self.N, self.predict_len))
104 | 
105 |         self.decoder = NewDecoderRNN(32, 32, self.predict_len, num_layers=args.rnn_layer_num)
106 |         # self.decoder = DecoderRNN(32, 32, self.predict_len, num_layers=args.rnn_layer_num)
107 | 
108 |         self.lout = nn.Conv1d(32, 1, kernel_size=1)
109 |         self.dropout = nn.Dropout(0.2)
110 |         self.weight_init()
111 | 
112 |     def forward(self, x):
113 |         """
114 |         :param x: (batch, in_channel, N, sequence)
115 |         """
116 |         seq_len = x.shape[3] // 3  # the input concatenates week/day/hour blocks of equal length
117 |         w_x, d_x, h_x = x[:, :, :, :seq_len], x[:, :, :, seq_len:seq_len * 2], x[:, :, :, seq_len * 2:]
118 | 
119 |         if self.args.fusion:
120 |             out, _ = self.forward_fusion(w_x, d_x, h_x)
121 |         else:
122 |             out, _ = self.forward_one(h_x)
123 | 
124 |         out = out.transpose(1, 2)
125 |         return out, out
126 | 
127 |     def forward_fusion(self, w_x, d_x, h_x):
128 |         # x = (B, F, N, T)
129 |         # print(w_x.shape, d_x.shape, h_x.shape)
130 |         out_w, _ = self.forward_one(w_x)
131 |         out_d, _ = self.forward_one(d_x)
132 |         out_h, _ = self.forward_one(h_x)
133 |         # (batch, N, predict_len): weight each branch per node and horizon, then sum
134 |         out_w = torch.einsum("bnt,nt->bnt", out_w, self.W_W)
135 |         out_d = torch.einsum("bnt,nt->bnt", out_d, self.D_W)
136 |         out_h = torch.einsum("bnt,nt->bnt", out_h, self.H_W)
137 | 
138 |         out = out_w + out_d + out_h
139 |         return out, out
140 | 
141 |     def forward_one(self, x):
142 |         # x = (B, F, N, T)
143 |         # encoder:
144 | 
145 |         # lstm + GCN + fc:
146 |         # x = self.encoder(x)
147 | 
148 |         # GCN + lstm:
149 |         residual = F.relu(self.upsampling(x))
150 |         h = x
151 |         for i in range(0, len(self.gwblocks)):
152 |             h = residual + self.gwblocks[i](residual)
153 | 
154 |         # skip connection
155 |         out = residual + h
156 |         out = F.relu(out)
157 | 
158 |         # GCN + LSTM, decoder:
159 |         out, h = self.decoder(out)
160 | 
161 |         # lstm + GCN + fc:
162 |         # out = self.lout(out)
163 | 
164 |         # test without fc:
165 |         # out = out.squeeze().transpose(1, 2)
166 |         # out = self.fc(out)
167 | 
168 |         # out = B, N, T
169 |         return out, out
170 | 
171 |     def weight_init(self):
172 |         for m in self.modules():
173 |             if isinstance(m, nn.Conv2d):
174 |                 nn.init.kaiming_normal_(m.weight, mode='fan_in')
175 |                 print('STWN init module with kaiming', m)
176 |             elif isinstance(m, nn.ParameterList):
177 |                 for i in m:
178 |                     nn.init.normal_(i, mean=0.0, std=0.001)
179 |                 print('STWN init parameterlist with norm')
180 |             else:
181 |                 print('STWN: no custom init for', m)
182 |         # nn.init.kaiming_normal_(self.rnn.weight, mode='fan_in')
183 |         nn.init.normal_(self.W_W, mean=0.0, std=0.001)
184 |         nn.init.normal_(self.D_W, mean=0.0, std=0.001)
185 |         nn.init.normal_(self.H_W, mean=0.0, std=0.001)
186 | 
187 | 
188 | def get_wavelet(kernel, scale, adj_mx, is_gpu):
189 |     Phi, Phi_inv = create_wnn_kernel_matrix(adj_mx, scale)
190 |     if is_gpu:  # only move to the GPU when requested; the unconditional .cuda() crashed CPU-only runs
191 |         Phi, Phi_inv = Phi.cuda(), Phi_inv.cuda()
192 |     return Phi.mm(kernel.diag()).mm(Phi_inv)
193 | 
194 | 
195 | class WaveletKernel(nn.Module):
196 |     def __init__(self, adj_mx, is_gpu=False, scale=0.1):
197 |         super(WaveletKernel, self).__init__()
198 |         self.is_gpu = is_gpu
199 |         self.adj_mx = adj_mx
200 |         self.N = adj_mx.shape[0]
201 |         self.scale = scale
202 | 
203 |         self.g = nn.Parameter(torch.ones(self.N))
204 |         self.Phi, self.Phi_inv = create_wnn_kernel_matrix(self.adj_mx, self.scale)
205 | 
206 |         if is_gpu:
207 |             self.g = nn.Parameter(torch.ones(self.N).cuda())
208 |             self.Phi = self.Phi.cuda()
209 |             self.Phi_inv = self.Phi_inv.cuda()
210 |         g_diag = self.g.diag()
211 |         self.k = self.Phi.mm(g_diag).mm(self.Phi_inv)  # note: computed once from the initial g
212 |         self.weight_init()
213 | 
214 |     def forward(self, x):
215 |         # batch, feature, N
216 |         x = torch.einsum('bfn, np -> bfp', x, self.k).contiguous()
217 |         return x
218 | 
219 |     def weight_init(self):
220 |         # this could be replaced by Chebyshev Polynomials
221 |         nn.init.uniform_(self.g)
222 | 
223 | 
224 | # TODO: wavelet attention mechanism and trans attention
225 | class STWNBlock(nn.Module):
226 |     def __init__(self, adj_mx, in_channel, out_channel, wavelets_num, is_gpu=False):
227 |         super(STWNBlock, self).__init__()
228 |         self.is_gpu = is_gpu
229 |         self.N = adj_mx.shape[0]
230 |         self.adj_mx = adj_mx
231 |         self.adj_mx_t = torch.from_numpy(adj_mx).float().cuda() if is_gpu else torch.from_numpy(adj_mx).float()
232 |         self.in_channel = in_channel
233 |         self.out_channel = out_channel
234 |         self.kernel_size = wavelets_num
235 | 
236 |         scales = [0.1 + 0.1 * 2 * i for i in range(wavelets_num)]
237 |         kernel_para = torch.ones(wavelets_num, self.N).float()
238 |         if is_gpu:
239 |             kernel_para = kernel_para.cuda()
240 |         self.kernels = nn.ParameterList([nn.Parameter(kernel_para[i]) for i in range(wavelets_num)])
241 |         self.wavelets = torch.stack(
242 |             [get_wavelet(self.kernels[i], scales[i], self.adj_mx, is_gpu) for i in range(wavelets_num)],
243 |             dim=0)  # already on the right device; get_wavelet handles is_gpu
244 |         self.randw = nn.Parameter(torch.randn(self.N, self.N).float())  # the old trailing .cuda() silently un-registered this parameter; model.cuda() moves it instead
245 |         self.krandw = nn.Parameter(torch.stack([torch.randn(self.N, self.N) for i in range(wavelets_num)], dim=0).float())  # same fix: keep the adaptive kernels registered (and thus trainable)
246 |         # self.wavelets = nn.Parameter(torch.randn(self.N, self.N).float())
247 |         print('wavelets shape', self.wavelets.shape)
248 |         self.Gate = nn.Parameter(torch.FloatTensor(wavelets_num))  # registered; moved by model.cuda()
249 |         self.SumOne = torch.ones(wavelets_num).float()  # unused in forward; kept on CPU
250 |         self.upsampling = nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1))
251 |         self.rnn = nn.GRU(input_size=32,
252 |                           hidden_size=32,
253 |                           num_layers=1,
254 |                           batch_first=True)
255 |         self.weight_init()
256 | 
257 | 
258 |     def forward(self, x):
259 |         """
260 |         :param x: (batch, in_channel, N, sequence)
261 |         :return: (batch, out_channel, N, sequence)
262 |         """
263 | 
264 |         seq_len = x.shape[3]
265 |         seqs = []
266 |         B, F, N, T = x.shape  # note: F shadows torch.nn.functional inside this method
267 | 
268 |         x = x.transpose(1, 2)
269 | 
270 | 
271 |         # learned wavelets modulated by the K adaptive random kernels
272 |         wavelets = self.wavelets * self.krandw
273 |         x = torch.einsum('bnft, knm -> bkmft', x, wavelets)
274 |         x = torch.einsum('bknft, k -> bnft', x, self.Gate)
275 | 
276 | 
277 |         # GCN + GRU:
278 |         x = x.transpose(2, 3).reshape(B * N, T, F)
279 |         outputs, last_hidden = self.rnn(x, None)
280 |         outputs = outputs.reshape(B, N, T, F).permute(0, 3, 1, 2)
281 |         return outputs
282 | 
283 |     def weight_init(self):
284 |         for m in self.modules():
285 |             if isinstance(m, nn.Conv2d):
286 |                 nn.init.kaiming_normal_(m.weight, mode='fan_in')
287 |                 print('init module with kaiming', m)
288 |             elif isinstance(m, nn.ParameterList):
289 |                 for i in m:
290 |                     nn.init.normal_(i, mean=0.0, std=0.001)
291 |             else:
292 |                 print('no custom init for', m)
293 |         nn.init.normal_(self.Gate, mean=0.0, std=0.001)
294 |         # nn.init.kaiming_normal_(self.sampling)
295 |         # nn.init.kaiming_normal_(self.rnn.weight, mode='fan_in')
296 | 
297 | 
298 | class AttSTWNBlock(nn.Module):
299 |     '''
300 |     Attention STWNBlock
301 |     '''
302 | 
303 |     def __init__(self, adj_mx, in_channel, out_channel, kernel_size, att_channel=32, bn=True, sampling=None,
304 |                  is_gpu=False):
305 |         super(AttSTWNBlock, self).__init__()
306 |         self.is_gpu = is_gpu
307 |         print('AttSTWNBlock, is_gpu', is_gpu)
308 |         self.N = adj_mx.shape[0]
309 |         self.adj_mx = adj_mx
310 |         self.in_channel = in_channel
311 |         self.out_channel = out_channel
312 |         self.kernel_size = kernel_size
313 |         self.att_channel = att_channel
314 | 
315 |         scales = [0.1 + 0.1 * 2 * i for i in range(kernel_size)]
316 |         kernel_para = torch.ones(kernel_size, self.N).float()
317 | 
318 |         if is_gpu:
319 |             kernel_para = kernel_para.cuda()
320 | 
321 |         self.kernels = nn.ParameterList([nn.Parameter(kernel_para[i]) for i in range(kernel_size)])
322 |         self.wavelets = torch.stack(
323 |             [get_wavelet(self.kernels[i], scales[i], self.adj_mx, is_gpu) for i in range(kernel_size)],
324 |             dim=0)
325 |         print('wavelets shape', self.wavelets.shape)
326 |         if is_gpu:
327 |             self.upsamplings = nn.Parameter(torch.FloatTensor(kernel_size, in_channel, out_channel).cuda())
328 |         else:
329 |             self.upsamplings = nn.Parameter(torch.FloatTensor(kernel_size, in_channel, out_channel))
330 | 
331 |         # self.upsamplings = nn.ParameterList([nn.Parameter(torch.FloatTensor(in_channel, out_channel).cuda())
332 |         #                                      for _ in range(self.kernel_size)])
333 | 
334 |         self.upsampling = nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1))
335 | 
336 |         self.Att_W = nn.Parameter(torch.FloatTensor(self.out_channel, self.att_channel))
337 |         self.Att_U = nn.Parameter(torch.FloatTensor(kernel_size, self.att_channel))
338 | 
339 |         self.weight_init()
340 | 
341 |     def forward(self, x):
342 |         """
343 |         :param x: (batch, in_channel, N, sequence)
344 |         :return: (batch, out_channel, N, sequence)
345 |         """
346 |         # TODO: change to recursive
347 |         sha = x.shape
348 |         seq_len = sha[3]
349 |         seqs = []
350 |         if self.is_gpu:
351 |             wavelets = self.wavelets.cuda()
352 |         else:
353 |             wavelets = self.wavelets
354 | 
355 |         # x = self.upsampling(x)
356 |         # print('samp',self.upsampling.weight)
357 | 
358 |         # attention:
359 | 
360 |         # NOTE: a "wavelet gated, sum directly" variant used to live here, but it
361 |         # referenced names this class never defines (`self.Gate`, `xs`) and called
362 |         # torch.einsum without operands, so it crashed on the first forward pass.
363 |         # It has been removed; the working attention path is the per-time-step
364 |         # loop below.
365 | 
366 | 
367 | 
368 | 
369 |         for i in range(seq_len):
370 |             xs = x[..., i].transpose(1, 2)
371 |             # wavelet transform:
372 |             # bnf, knn -> bknf
373 |             xs = torch.einsum('bnf, ksn -> bknf', xs, wavelets)
374 | 
375 |             # in_channel to out_channel, bknf , kfs -> bkns
376 |             xs = torch.einsum('bknf , kfo -> bkno', xs, self.upsamplings)
377 | 
378 |             mask = xs == float('-inf')
379 |             xs = xs.data.masked_fill(mask, 0)
380 |             # attention:
381 | 
382 |             # fs, bknf -> bkns
383 |             # bkns, s -> bkn
384 |             # bkn -> a = softmax(bkn, k)
385 |             # bknf, bkn -> bknf
386 |             a = torch.einsum('fs, bknf -> bkns', self.Att_W, xs)
387 |             a = torch.einsum('bkns, ks -> bkn', a, self.Att_U)  # was `xs`, which silently bypassed the Att_W projection above
388 |             a = util.norm(a, dim=1)
389 |             a = F.softmax(a, dim=1)
390 |             a = a.transpose(1, 2)
391 | 
392 |             # mock attention:
393 |             # xsshape = xs.shape
394 |             # a = torch.ones(xsshape[0], xsshape[2], xsshape[1]).float().cuda()
395 | 
396 |             # xs * attention
397 |             out = torch.einsum('bnk, bkno -> bno', a, xs).transpose(1, 2)
398 |             # h = out
399 |             seqs.append(out)
400 | 
401 |         # stack all sequences
402 |         x = torch.stack(seqs, dim=3)
403 |         return x
404 | 
405 |     def weight_init(self):
406 |         for m in self.modules():
407 |             if isinstance(m, nn.Conv2d):
408 |                 nn.init.kaiming_normal_(m.weight.data, mode='fan_in')
409 |                 print('init module with kaiming', m)
410 |             elif isinstance(m, nn.ParameterList):
411 |                 for i in m:
412 |                     # nn.init.kaiming_normal_(i.data, mode='fan_out')
413 |                     nn.init.normal_(i, mean=0.0, std=0.001)
414 |                     # nn.init.kaiming_normal_(i.weight.data, mode='fan_in')
415 |                 print('init parameterlist with normal', m)
416 |             else:
417 |                 print('no custom init for', m)
418 |         # nn.init.kaiming_normal_(self.upsamplings.data, mode='fan_in')
419 |         # nn.init.normal_(self.upsamplings, mean=0.0, std=0.001)
420 |         for m in self.upsamplings:
421 |             nn.init.kaiming_normal_(m.data, mode='fan_in')
422 | 
423 |         nn.init.normal_(self.Att_W, mean=0.0, std=0.001)
424 |         nn.init.kaiming_normal_(self.Att_U.data, mode='fan_in')
425 | 
426 | 
427 |         # nn.init.normal_(self.Att_U, mean=0.0, std=0.001)
428 | 
429 | 
430 | class DecoderRNN(nn.Module):
431 |     def __init__(self, feature_len, hidden_len, predict_len, out_channel=1, num_layers=1):
432 |         super(DecoderRNN, self).__init__()
433 |         self.feature_len = feature_len
434 |         self.hidden_len = hidden_len
435 |         self.num_layers = num_layers
436 |         self.predict_len = predict_len
437 |         self.out_channel = out_channel
438 |         # RNN layer
439 |         self.rnn = nn.RNN(
440 |             input_size=feature_len,  # feature len
441 |             hidden_size=hidden_len,  # hidden state size
442 |             num_layers=num_layers,  # number of stacked layers
443 |             batch_first=True  # inputs are fed as [batch, seq_len, feature_len]
444 |         )
445 |         print("DecoderRNN out_channel: ", out_channel)
446 |         self.l1 = nn.Conv1d(hidden_len, out_channel, kernel_size=1)
447 |         self.dropout = nn.Dropout(0.2)
448 |         # initialize the RNN parameters
449 |         for p in self.rnn.parameters():
450 |             nn.init.normal_(p, mean=0.0, std=0.001)
451 | 
452 |     def forward(self, x):
453 |         """
454 |         x = (batch, feature, N, sequence)
455 |         reshaped to:
456 |         x = (batch x N, sequence, feature)
457 |         :return: out (batch, N, sequence, out_channel)
458 |         """
459 |         batch, channel, N, seq_len = x.shape
460 |         x = x.transpose(1, 2).reshape(batch * N, channel, seq_len)
461 |         h = None
462 |         x = x[:, :, seq_len - 1].unsqueeze(dim=1)
463 |         seqs = []
464 |         for _ in range(self.predict_len):
465 |             out, h = self.rnn(x, h)
466 |             # out = (batch * N, seq_len, feature) to (batch * N, feature, seq_len)
467 |             out = out.transpose(1, 2)
468 |             out = self.l1(out)
469 |             seqs.append(out)
470 | 
471 |         predict = torch.cat(seqs, dim=2)
472 |         predict = predict.reshape(batch, N, self.predict_len)
473 |         # predict = self.dropout(predict)
474 |         return predict, predict
475 | 
476 | 
477 | 
478 | class NewDecoderRNN(nn.Module):
479 |     def __init__(self, feature_len, hidden_len, predict_len, out_channel=1, num_layers=1):
480 |         super(NewDecoderRNN, self).__init__()
481 |         self.feature_len = feature_len
482 |         self.hidden_len = hidden_len
483 |         self.num_layers = num_layers
484 |         self.predict_len = predict_len
485 |         self.out_channel = out_channel
486 |         # RNN layer
487 |         self.rnn = nn.RNN(
488 |             input_size=feature_len,  # feature len
489 |             hidden_size=hidden_len,  # hidden state size
490 |             num_layers=num_layers,  # number of stacked layers
491 |             batch_first=True  # inputs are fed as [batch, seq_len, feature_len]
492 |         )
493 |         print("NewDecoderRNN out_channel: ", out_channel)
494 |         self.l1 = nn.Conv1d(hidden_len, out_channel, kernel_size=1)
495 |         self.dropout = nn.Dropout(0.2)
496 |         # initialize the RNN parameters
497 |         for p in self.rnn.parameters():
498 |             nn.init.normal_(p, mean=0.0, std=0.001)
499 | 
500 |     def forward(self, x):
501 |         """
502 |         x = (batch, feature, N, sequence)
503 |         reshaped to:
504 |         x = (batch x N, sequence, feature)
505 |         :return: out (batch, N, sequence, out_channel)
506 |         """
507 |         batch, channel, N, seq_len = x.shape
508 |         x = x.transpose(1, 2).reshape(batch * N, channel, seq_len).transpose(1, 2)
509 |         # last seq:
510 |         x = x[:, -1, :].unsqueeze(dim=1)
511 |         h = None
512 |         seqs = []
513 |         for i in range(self.predict_len):
514 |             out, h = self.rnn(x, h)  # the last input frame is re-fed each step; h carries the evolving context
515 |             # out = (batch * N, seq_len, feature) to (batch * N, feature, seq_len)
516 |             out = out[:, -1, :].unsqueeze(dim=1)
517 |             each_seq = out.transpose(1, 2)
518 |             each_seq = self.l1(each_seq)
519 |             seqs.append(each_seq)
520 | 
521 |         predict = torch.cat(seqs, dim=2).squeeze()
522 |         predict = predict.reshape(batch, N, self.predict_len)
523 |         # predict = self.dropout(predict)
524 |         return predict, predict
525 | 
526 | 
527 | 
528 | class EncoderRNN(nn.Module):
529 |     def __init__(self, feature_len, hidden_len, predict_len, out_channel=1, num_layers=1):
530 |         super(EncoderRNN, self).__init__()
531 |         self.feature_len = feature_len
532 |         self.hidden_len = hidden_len
533 |         self.num_layers = num_layers
534 |         self.predict_len = predict_len
535 |         self.out_channel = out_channel
536 |         # RNN layer
537 |         self.rnn = nn.RNN(
538 |             input_size=feature_len,  # feature len
539 |             hidden_size=hidden_len,  # hidden state size
540 |             num_layers=num_layers,  # number of stacked layers
541 |             batch_first=True  # inputs are fed as [batch, seq_len, feature_len]
542 |         )
543 |         print("EncoderRNN out_channel: ", out_channel)
544 |         self.l1 = nn.Conv1d(hidden_len, out_channel, kernel_size=1)
545 |         self.dropout = nn.Dropout(0.2)
546 |         # initialize the RNN parameters
547 |         for p in self.rnn.parameters():
548 |             nn.init.normal_(p, mean=0.0, std=0.001)
549 | 
550 |     def forward(self, x):
551 |         """
552 |         x = (batch, channel, N, sequence)
553 |         reshaped to:
554 |         x = (batch x N, sequence, feature)
555 |         :return: out (batch, in_channel, N, sequence)
556 |         """
557 |         batch, channel, N, seq_len = x.shape
558 |         x = x.transpose(1, 2).reshape(batch * N, channel, seq_len)
559 |         h = None
560 |         x = x.transpose(1, 2)
561 |         seqs = []
562 |         out, h = self.rnn(x, h)
563 |         # batch*N, seq, feature.
564 |         out = out.reshape(batch, N, seq_len, self.hidden_len)
565 |         out = out[:, :, -1, :].unsqueeze(dim=2).permute(0, 3, 1, 2)
566 |         return out
567 | 
568 | 
569 | 
570 | 
571 | class DecoderLSTM(nn.Module):
572 |     def __init__(self, feature_len, hidden_len, predict_len, num_layers=1):
573 |         super(DecoderLSTM, self).__init__()
574 |         print('init DecoderLSTM')
575 |         self.feature_len = feature_len
576 |         self.hidden_len = hidden_len
577 |         self.num_layers = num_layers
578 |         self.predict_len = predict_len
579 |         # LSTM layer
580 |         self.rnn = nn.LSTM(
581 |             input_size=feature_len,  # feature len
582 |             hidden_size=hidden_len,  # hidden state size
583 |             num_layers=num_layers,  # number of stacked layers
584 |             batch_first=True  # inputs are fed as [batch, seq_len, feature_len]
585 |         )
586 |         self.l1 = nn.Conv1d(hidden_len, 1, kernel_size=1)
587 |         self.dropout = nn.Dropout(0.2)
588 |         # initialize the LSTM parameters
589 |         for p in self.rnn.parameters():
590 |             nn.init.normal_(p, mean=0.0, std=0.001)
591 | 
592 |     def forward(self, x):
593 |         """
594 |         x = (batch, feature, N, sequence)
595 |         reshaped to:
596 |         x = (batch x N, sequence, feature)
597 |         :return: out (batch, N, sequence)
598 |         """
599 |         batch, _, N, seq_len = x.shape
600 |         x = x.transpose(1, 2).reshape(batch * N, seq_len, self.feature_len).transpose(1, 2)
601 |         h = None
602 |         x = x[:, :, seq_len - 1].unsqueeze(dim=1)
603 |         seqs = []
604 |         for _ in range(self.predict_len):
605 |             out, h = self.rnn(x, h)
606 |             # out = (batch * N, seq_len, feature) to (batch * N, feature, seq_len)
607 |             out = self.l1(out.transpose(1, 2))
608 |             seqs.append(out)
609 | 
610 |         predict = torch.cat(seqs, dim=1)
611 | 
612 |         predict = predict.reshape(batch, N, self.predict_len)
613 |         # predict = self.dropout(predict)
614 |         return predict, predict
615 | 
616 | 
617 | class DecoderGRU(nn.Module):
618 |     def __init__(self, feature_len, hidden_len, predict_len, num_layers=1):
619 |         super(DecoderGRU, self).__init__()
620 |         print('init DecoderGRU')
621 |         self.feature_len = feature_len
622 |         self.hidden_len = hidden_len
623 |         self.num_layers = num_layers
624 |         self.predict_len = predict_len
625 |         # GRU layer
626 |         self.rnn = nn.GRU(
627 |             input_size=feature_len,  # feature len
628 |             hidden_size=hidden_len,  # hidden state size
629 |             num_layers=num_layers,  # number of stacked layers
630 |             batch_first=True  # inputs are fed as [batch, seq_len, feature_len]
631 |         )
632 |         self.l1 = nn.Conv1d(hidden_len, 1, kernel_size=1)
633 |         self.dropout = nn.Dropout(0.2)
634 |         # initialize the GRU parameters
635 |         for p in self.rnn.parameters():
636 |             nn.init.normal_(p, mean=0.0, std=0.001)
637 | 
638 |     def forward(self, x):
639 |         """
640 |         x = (batch, feature, N, sequence)
641 |         reshaped to:
642 |         x = (batch x N, sequence, feature)
643 |         :return: out (batch, N, sequence)
644 |         """
645 |         batch, _, N, seq_len = x.shape
646 |         x = x.transpose(1, 2).reshape(batch * N, seq_len, self.feature_len).transpose(1, 2)
647 |         h = None
648 |         x = x[:, :, seq_len - 1].unsqueeze(dim=1)
649 |         seqs = []
650 |         for _ in range(self.predict_len):
651 |             out, h = self.rnn(x, h)
652 |             # out = (batch * N, seq_len, feature) to (batch * N, feature, seq_len)
653 |             out = self.l1(out.transpose(1, 2))
654 |             seqs.append(out)
655 | 
656 |         predict = torch.cat(seqs, dim=1)
657 | 
658 |         predict = predict.reshape(batch, N, self.predict_len)
659 |         # predict = self.dropout(predict)
660 |         return predict, predict
661 | 
662 | 
663 | class FC(nn.Module):
664 |     def __init__(self, in_channel, out_channel, N):
665 |         super(FC, self).__init__()
666 |         self.N = N
667 |         self.in_channel = in_channel
668 |         self.out_channel = out_channel
669 |         self.l1 = nn.Conv2d(in_channel, 16, kernel_size=(1, 1))
670 |         self.bn =
nn.BatchNorm2d(16) 671 | self.l2 = nn.Conv2d(16, out_channel, kernel_size=(1, 1)) 672 | 673 | def forward(self, x): 674 | """ 675 | :param x: (batch, features, Nodes, sequence) 676 | :return: (batch, sequence, Nodes, features=1).squeeze() --> (batch, sequence, Nodes) 677 | """ 678 | x = F.relu(self.bn(self.l1(x))) 679 | x = F.relu(self.l2(x)) 680 | seq_len = x.shape[3] 681 | batch_size = x.shape[0] 682 | # outs = [] 683 | # for i in range(self.N): 684 | # node = x[:, :, i, :] 685 | # out = F.relu(self.l(node)) 686 | # outs.append(out) 687 | x = x.reshape(batch_size, seq_len, self.N, self.out_channel).squeeze() 688 | # x = torch.cat(outs, dim=2).reshape(batch_size, seq_len, self.N, self.out_channel).squeeze() 689 | return x 690 | 691 | 692 | class RNNModel(nn.Module): 693 | def __init__(self, args): 694 | super(RNNModel, self).__init__() 695 | self.predict_len = args.predict_len 696 | self.args = args 697 | self.rnn_in_channel = args.rnn_in_channel 698 | self.upsampling = nn.Conv1d(args.feature_len, self.rnn_in_channel, kernel_size=1) 699 | if args.decoder_type == "rnn": 700 | self.decoder = DecoderRNN(self.rnn_in_channel, 32, self.predict_len, num_layers=args.rnn_layer_num) 701 | elif args.decoder_type == "lstm": 702 | print('lstm') 703 | self.decoder = DecoderLSTM(self.rnn_in_channel, 32, self.predict_len, num_layers=args.rnn_layer_num) 704 | else: 705 | self.decoder = DecoderGRU(self.rnn_in_channel, 32, self.predict_len, num_layers=args.rnn_layer_num) 706 | 707 | def forward(self, x): 708 | """ 709 | :param x: (batch, feature, N, sequence) 710 | :return: 711 | """ 712 | batch, feature, N, seq = x.shape 713 | seq = seq // 3 714 | _, _, x = x[:, :, :, :seq], x[:, :, :, seq:seq * 2], x[:, :, :, seq * 2:] 715 | 716 | x = x.transpose(1, 2).reshape(batch * N, feature, seq) 717 | x = F.relu(self.upsampling(x)) 718 | x = x.reshape(batch, N, self.rnn_in_channel, seq).transpose(1, 2) 719 | x, _ = self.decoder(x) 720 | x = x.transpose(1, 2) 721 | return x, x 722 | 723 | 724 | class Trainer: 725 | def __init__(self, args, model, optimizer, scaler, criterion=nn.MSELoss()): 726 | self.model = model 727 | self.args = args 728 | self.criterion = criterion 729 | self.optimizer = optimizer 730 | self.scaler = scaler 731 | self.clip = args.clip 732 | self.lr_decay_rate = args.lr_decay_rate 733 | self.epochs = args.epochs 734 | self.scheduler = optim.lr_scheduler.LambdaLR( 735 | self.optimizer, lr_lambda=lambda epochs: self.lr_decay_rate ** epochs) 736 | 737 | def train(self, input_data, target): 738 | self.model.train() 739 | self.optimizer.zero_grad() 740 | 741 | # train 742 | output, h = self.model(input_data) 743 | # print('output shape', output.shape) 744 | # h.detach() 745 | 746 | # loss, weights update 747 | # TODO: squeeze performance? 748 | output = output.squeeze() 749 | # TODO: replace the inverse_transform function by self-define or in forward function. 750 | predict = self.scaler.inverse_transform(output) 751 | # TODO: compare speed of cal all metrics here or outside. 752 | # 1. 
cal metrics here
753 |         # print('predict:', predict.shape)
754 |         # print('target:', target.shape)
755 | 
756 |         # target [batch, N, seq]
757 |         # loss = self.criterion(predict, target)
758 |         mae, mape, rmse = util.calc_metrics(predict, target)
759 |         mae.backward(retain_graph=True)  # retain_graph is needed: STWNBlock.wavelets is built once in __init__ from the kernel parameters, so its graph is reused every step
760 |         # loss.backward()
761 |         self.optimizer.step()
762 |         return mae.item(), mae.item(), mape.item(), rmse.item()  # (loss, mae, mape, rmse); the loss here is the masked MAE
763 | 
764 |     def eval(self, input_data, target):
765 |         self.model.eval()
766 | 
767 |         output, h = self.model(input_data)  # [batch_size,seq_length,num_nodes]
768 |         h.detach()  # note: detach() is not in-place, so this line has no effect
769 | 
770 |         output = output.squeeze()
771 | 
772 |         predict = self.scaler.inverse_transform(output)
773 | 
774 |         predict = torch.clamp(predict, min=0., max=70.)
775 |         mae, mape, rmse = util.calc_metrics(predict, target)
776 |         return mae.item(), mape.item(), rmse.item()
777 | 

--------------------------------------------------------------------------------
/prepareData.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | import configparser
5 | 
6 | 
7 | def search_data(sequence_length, num_of_depend, label_start_idx,
8 |                 num_for_predict, units, points_per_hour):
9 |     '''
10 |     Parameters
11 |     ----------
12 |     sequence_length: int, length of all history data
13 |     num_of_depend: int,
14 |     label_start_idx: int, the first index of the prediction target
15 |     num_for_predict: int, the number of points to be predicted for each sample
16 |     units: int, week: 7 * 24, day: 24, recent(hour): 1
17 |     points_per_hour: int, number of points per hour, depends on data
18 |     Returns
19 |     ----------
20 |     list[(start_idx, end_idx)]
21 |     '''
22 | 
23 |     if points_per_hour <= 0:  # was `< 0`, which let 0 slip through despite the message below
24 |         raise ValueError("points_per_hour should be greater than 0!")
25 | 
26 |     if label_start_idx + num_for_predict > sequence_length:
27 |         return None
28 | 
29 |     x_idx = []
30 |     for i in range(1, num_of_depend + 1):
31 |         start_idx = label_start_idx - points_per_hour * units * i
32 |         end_idx = start_idx + num_for_predict
33 |         if start_idx >= 0:
34 |             x_idx.append((start_idx, end_idx))
35 |         else:
36 |             return None
37 | 
38 |     if len(x_idx) != num_of_depend:
39 |         return None
40 | 
41 |     return x_idx[::-1]
42 | 
43 | 
44 | def get_sample_indices(data_sequence, num_of_weeks, num_of_days, num_of_hours,
45 |                        label_start_idx, num_for_predict, num_predict, points_per_hour=12):
46 |     '''
47 |     Parameters
48 |     ----------
49 |     data_sequence: np.ndarray
50 |                    shape is (sequence_length, num_of_vertices, num_of_features)
51 |     num_of_weeks, num_of_days, num_of_hours: int
52 |     label_start_idx: int, the first index of the prediction target
53 |     num_for_predict: int,
54 |                      the number of points to be predicted for each sample
55 |     points_per_hour: int, default 12, number of points per hour
56 |     Returns
57 |     ----------
58 |     week_sample: np.ndarray
59 |                  shape is (num_of_weeks * points_per_hour,
60 |                  num_of_vertices, num_of_features)
61 |     day_sample: np.ndarray
62 |                 shape is (num_of_days * points_per_hour,
63 |                 num_of_vertices, num_of_features)
64 |     hour_sample: np.ndarray
65 |                  shape is (num_of_hours * points_per_hour,
66 |                  num_of_vertices, num_of_features)
67 |     target: np.ndarray
68 |             shape is (num_for_predict, num_of_vertices, num_of_features)
69 |     '''
70 |     week_sample, day_sample, hour_sample = None, None, None
71 | 
72 |     if label_start_idx + num_for_predict > data_sequence.shape[0]:
73 |         return week_sample, day_sample, hour_sample, None
74 | 
75 |     if num_of_weeks > 0:
76 |         week_indices = search_data(data_sequence.shape[0], num_of_weeks,
77 |                                    label_start_idx, num_for_predict,
78 |                                    7 * 24, points_per_hour)
79 |         if not week_indices:
80 |             return None, None, None, None
81 | 
82 |         week_sample = np.concatenate([data_sequence[i: j]
83 |                                       for i, j in week_indices], axis=0)
84 | 
85 |     if num_of_days > 0:
86 |         day_indices = search_data(data_sequence.shape[0], num_of_days,
87 |                                   label_start_idx, num_for_predict,
88 |                                   24, points_per_hour)
89 |         if not day_indices:
90 |             return None, None, None, None
91 | 
92 |         day_sample = np.concatenate([data_sequence[i: j]
93 |                                      for i, j in day_indices], axis=0)
94 | 
95 |     if num_of_hours > 0:
96 |         hour_indices = search_data(data_sequence.shape[0], num_of_hours,
97 |                                    label_start_idx, num_for_predict,
98 |                                    1, points_per_hour)
99 |         if not hour_indices:
100 |             return None, None, None, None
101 | 
102 |         hour_sample = np.concatenate([data_sequence[i: j]
103 |                                       for i, j in hour_indices], axis=0)
104 | 
105 |     target = data_sequence[label_start_idx: label_start_idx + num_predict]
106 | 
107 |     return week_sample, day_sample, hour_sample, target

--------------------------------------------------------------------------------
/requirements:
--------------------------------------------------------------------------------
1 | pandas
2 | pygsp
3 | networkx
4 | texttable
5 | scikit-learn  # was `sklearn`; that PyPI stub is deprecated
6 | fastprogress
7 | tables
8 | torch
9 | numpy
10 | scipy
11 | matplotlib

--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import numpy as np
4 | 
5 | import pandas as pd
6 | import scipy.sparse as sp
7 | import torch
8 | from scipy.sparse import linalg
9 | 
10 | 
11 | def get_common_args():
12 |     parser = argparse.ArgumentParser()
13 |     # learning params
14 |     parser.add_argument('--dev', action='store_true', help='dev')
15 |     parser.add_argument('--dev_size', type=int, default=1000, help='dev_sample_size')
16 |     parser.add_argument('--best_model_save_path', type=str, default='.best_model', help='best_model')
17 |     parser.add_argument('--pre_model_path', type=str, default='./pre_model/best_model', help='pre_model_path')
18 |     parser.add_argument('--batch_size', type=int, default=128, help='batch size')
19 |     parser.add_argument('--epochs', type=int, default=10, help='epochs')
20 |     parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
21 |     parser.add_argument('--lr_decay_rate', type=float, default=0.97, help='lr_decay_rate')
22 |     parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight_decay')
23 |     parser.add_argument('--clip', type=int, default=3, help='clip')
24 |     parser.add_argument('--seq_length', type=int, default=3, help='seq_length')
25 |     parser.add_argument('--predict_len', type=int, default=12, help='predict_len')
26 |     parser.add_argument('--scheduler', action='store_true', help='scheduler')
27 |     parser.add_argument('--mo', type=float, default=0.1, help='momentum')
28 | 
29 |     # running params
30 |     parser.add_argument('--cuda', action='store_true', help='cuda')
31 |     parser.add_argument('--transpose', action='store_true', help='transpose sequence and feature?')
32 |     parser.add_argument('--data_path', type=str, default='./data/METR-LA', help='data path')
33 |     parser.add_argument('--adj_file', type=str, default='./data/sensor_graph/adj_mx.pkl',
34 |                         help='adj data path')
35 |     parser.add_argument('--adj_type', type=str, default='scalap', help='adj type', choices=ADJ_CHOICES)
36 |     parser.add_argument('--fig_filename', type=str, default='./mae', help='fig_filename')
37 | 
38 |     # model params
39 |     parser.add_argument('--att', action='store_true', help='attention')
40 |     parser.add_argument('--recur', action='store_true', help='recur')
41 |     parser.add_argument('--fusion', action='store_true', help='fusion')
42 |     parser.add_argument('--pretrain', action='store_true', help='pretrain')
43 |     parser.add_argument('--feature_len', type=int, default=3, help='input feature_len')
44 |     parser.add_argument('--gcn_layer_num', type=int, default=2, help='gcn_layer_num')
45 |     parser.add_argument('--wavelets_num', type=int, default=20, help='wavelets_num')
46 |     parser.add_argument('--rnn_layer_num', type=int, default=2, help='rnn_layer_num')
47 |     parser.add_argument('--rnn_in_channel', type=int, default=32, help='rnn_in_channel')
48 |     parser.add_argument('--rnn', action='store_true', help='use the plain RNN baseline (RNNModel) instead of STWN')
49 |     parser.add_argument('--decoder_type', type=str, default='rnn', help='decoder_type')
50 | 
51 | 
52 |     return parser
53 | 
54 | 
55 | # common parameters
56 | class Args:
57 |     def __init__(self,
58 |                  dev=False,
59 |                  data_path='./data/METR-LA',
60 |                  best_model_save_path='./best_model',
61 |                  batch_size=128,
62 |                  epochs=10,
63 |                  lr=0.0005,
64 |                  weight_decay=0.0001,
65 |                  cuda=False,
66 |                  transpose=False,
67 |                  params=dict(),
68 |                  adj_type='scalap',
69 |                  adj_file='',
70 |                  seq_length=12,
71 |                  n_iters=1,
72 |                  clip=3,
73 |                  lr_decay_rate=0.97,
74 |                  addaptadj=True
75 |                  ):
76 |         self.dev = dev
77 |         self.batch_size = batch_size
78 |         self.best_model_save_path = best_model_save_path
79 |         self.epochs = epochs
80 |         self.lr = lr
81 |         self.clip = clip
82 |         self.lr_decay_rate = lr_decay_rate
83 |         self.weight_decay = weight_decay
84 |         self.data_path = data_path
85 |         self.cuda = cuda
86 |         self.transpose = transpose
87 |         self.params = params
88 |         self.adj_type = adj_type
89 |         self.adj_file = adj_file
90 |         self.seq_length = seq_length
91 |         self.n_iters = n_iters
92 |         self.addaptadj = addaptadj
93 | 
94 | 
95 | class GWNNArgs:
96 |     def __init__(self, num_nodes=207, do_grap_conv=True, p=True, aptonly=False, addaptadj=True, randomadj=False,
97 |                  nhid=22, in_dim=2, dropout=0.3, apt_size=10, cat_feat_gc=False, clip=None):
98 |         self.do_graph_conv = do_grap_conv
99 |         self.p = p
100 |         self.aptonly = aptonly
101 |         self.addaptadj = addaptadj
102 |         self.randomadj = randomadj
103 |         self.nhid = nhid
104 |         self.in_dim = in_dim
105 |         self.num_nodes = num_nodes
106 |         self.dropout = dropout
107 |         # self.n_obs = n_obs
108 |         self.apt_size = apt_size
109 |         self.cat_feat_gc = cat_feat_gc
110 |         # self.fill_zeroes = fill_zeroes
111 |         # self.checkpoint = checkpoint
112 |         self.clip = clip
113 |         self.lr_decay_rate = 0.97
114 | 
115 | 
116 | class StandardScaler():
117 | 
118 |     def __init__(self, mean, std, fill_zeroes=True):
119 |         self.mean = mean
120 |         self.std = std
121 |         self.fill_zeroes = fill_zeroes
122 | 
123 |     def transform(self, data):
124 |         if self.fill_zeroes:
125 |             mask = (data == 0)
126 |             data[mask] = self.mean
127 |         return (data - self.mean) / self.std
128 | 
129 |     def inverse_transform(self, data):
130 |         return (data * self.std) + self.mean
131 | 
132 | 
133 | class TrafficDataLoader(object):
134 |     def __init__(self, xs, ys, batch_size, cuda=False, transpose=False, pad_with_last_sample=True):
135 |         """
136 |         :param xs:
137 |         :param ys:
138 |         :param batch_size:
139 |         :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size.
140 | """ 141 | self.batch_size = batch_size 142 | self.current_ind = 0 143 | if pad_with_last_sample: 144 | # batch 145 | num_padding = (batch_size - (len(xs) % batch_size)) % batch_size 146 | x_padding = np.repeat(xs[-1:], num_padding, axis=0) 147 | y_padding = np.repeat(ys[-1:], num_padding, axis=0) 148 | xs = np.concatenate([xs, x_padding], axis=0) 149 | ys = np.concatenate([ys, y_padding], axis=0) 150 | self.size = len(xs) 151 | self.num_batch = int(self.size // self.batch_size) 152 | xs = torch.Tensor(xs) 153 | ys = torch.Tensor(ys) 154 | if cuda: 155 | xs, ys = xs.cuda(), ys.cuda() 156 | # # temporal filter one sensor data as x, speed as y 157 | # xs = xs.squeeze()[:, :, [0]].squeeze() 158 | # ys = ys.squeeze()[:, :, [0]].squeeze() 159 | # ys = ys[:, :, 0] 160 | if transpose: 161 | xs = xs.transpose(1, 3) 162 | # ys = ys.transpose(1, 2) 163 | self.xs = xs 164 | self.ys = ys 165 | 166 | def shuffle(self): 167 | permutation = np.random.permutation(self.size) 168 | xs, ys = self.xs[permutation], self.ys[permutation] 169 | self.xs = xs 170 | self.ys = ys 171 | 172 | def get_iterator(self): 173 | self.current_ind = 0 174 | 175 | def _wrapper(): 176 | # TODO: Bug, we need to add more conditions. 177 | start_ind = 0 178 | end_ind = 0 179 | while self.current_ind < self.num_batch and start_ind <= end_ind and start_ind <= self.size: 180 | start_ind = self.batch_size * self.current_ind 181 | end_ind = min(self.size, self.batch_size * (self.current_ind + 1)) 182 | x_i = self.xs[start_ind: end_ind, ...] 183 | y_i = self.ys[start_ind: end_ind, ...] 184 | yield x_i, y_i 185 | self.current_ind += 1 186 | 187 | return _wrapper() 188 | 189 | def norm(tensor_data, dim=0): 190 | mu = tensor_data.mean(axis=dim, keepdim=True) 191 | std = tensor_data.std(axis=dim, keepdim=True) 192 | return (tensor_data - mu) / (std + 0.00005) 193 | 194 | def sym_adj(adj): 195 | """Symmetrically normalize adjacency matrix.""" 196 | print('origin adj:', adj) 197 | adj = sp.coo_matrix(adj) 198 | rowsum = np.array(adj.sum(1)) 199 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 200 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 201 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 202 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).astype(np.float32).todense() 203 | 204 | 205 | def sym_norm_lap(adj): 206 | N = adj.shape[0] 207 | adj_norm = sym_adj(adj) 208 | L = np.eye(N) - adj_norm 209 | return L 210 | 211 | 212 | def asym_adj(adj): 213 | adj = sp.coo_matrix(adj) 214 | rowsum = np.array(adj.sum(1)).flatten() 215 | d_inv = np.power(rowsum, -1).flatten() 216 | d_inv[np.isinf(d_inv)] = 0. 217 | d_mat = sp.diags(d_inv) 218 | return d_mat.dot(adj).astype(np.float32).todense() 219 | 220 | 221 | def calculate_normalized_laplacian(adj): 222 | """ 223 | # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2 224 | # D = diag(A 1) 225 | :param adj: 226 | :return: 227 | """ 228 | d = np.array(adj.sum(1)) 229 | d_inv_sqrt = np.power(d, -0.5).flatten() 230 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 
def calculate_normalized_laplacian(adj):
    """
    L = D^-1/2 (D - A) D^-1/2 = I - D^-1/2 A D^-1/2, where D = diag(A 1).
    :param adj: adjacency matrix
    :return: normalized Laplacian (dense)
    """
    d = np.array(adj.sum(1))
    d_inv_sqrt = np.power(d, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt).toarray()
    normalized_laplacian = sp.eye(adj.shape[0]) - np.matmul(np.matmul(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
    return normalized_laplacian


def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """Rescale the Laplacian's spectrum to [-1, 1]: L' = (2 / lambda_max) L - I."""
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    L = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        lambda_max, _ = linalg.eigsh(L, 1, which='LM')
        lambda_max = lambda_max[0]
    L = sp.csr_matrix(L)
    M, _ = L.shape
    I = sp.identity(M, format='csr', dtype=L.dtype)
    L = (2 / lambda_max * L) - I
    return L.astype(np.float32).todense()


def load_pickle(pickle_file):
    try:
        with open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f)
    except UnicodeDecodeError:
        # Fall back for pickles written under Python 2.
        with open(pickle_file, 'rb') as f:
            pickle_data = pickle.load(f, encoding='latin1')
    except Exception as e:
        print('Unable to load data ', pickle_file, ':', e)
        raise
    return pickle_data


ADJ_CHOICES = ['scalap', 'normlap', 'symnadj', 'sym_norm_lap', 'transition', 'doubletransition', 'identity']


def load_adj(pkl_filename, adjtype):
    sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)

    # Binarize the adjacency matrix: keep edges, drop weights.
    adj_mx[adj_mx > 0] = 1
    adj_mx[adj_mx < 0] = 0

    if adjtype == "scalap":
        adj = [calculate_scaled_laplacian(adj_mx)]
    elif adjtype == "normlap":
        adj = [calculate_normalized_laplacian(adj_mx).astype(np.float32)]
    elif adjtype == "symnadj":
        adj = [sym_adj(adj_mx)]
    elif adjtype == "sym_norm_lap":
        adj = [sym_norm_lap(adj_mx)]
    elif adjtype == "transition":
        adj = [asym_adj(adj_mx)]
    elif adjtype == "doubletransition":
        adj = [asym_adj(adj_mx), asym_adj(np.transpose(adj_mx))]
    elif adjtype == "identity":
        adj = [np.diag(np.ones(adj_mx.shape[0])).astype(np.float32)]
    else:
        raise ValueError("adj type '{}' not defined".format(adjtype))
    return sensor_ids, sensor_id_to_ind, adj


def calc_metrics(preds, labels, null_val=0.):
    """Masked MAE / MAPE / RMSE: entries whose label equals null_val are excluded."""
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels != null_val)
    mask = mask.float()
    mask /= torch.mean(mask)  # rescale so the mean over kept entries stays unbiased
    # Handle the all-null case, where mean(mask) == 0 produces NaNs.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    mse = (preds - labels) ** 2
    mae = torch.abs(preds - labels)
    mape = mae / labels  # infs/NaNs at null labels are zeroed out by the mask below
    mae, mape, mse = [mask_and_fillna(l, mask) for l in [mae, mape, mse]]
    rmse = torch.sqrt(mse)
    return mae, mape, rmse


def mask_and_fillna(loss, mask):
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)
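# Illustrative sketch (not part of the original file): the masked metrics on
# toy tensors, assuming 0.0 marks a missing label (the default null_val). The
# helper name _demo_masked_metrics is hypothetical.
def _demo_masked_metrics():
    preds = torch.tensor([[10.0, 20.0, 30.0]])
    labels = torch.tensor([[12.0, 0.0, 27.0]])  # the middle label is missing
    mae, mape, rmse = calc_metrics(preds, labels, null_val=0.)
    # mask = [1, 0, 1], rescaled to [1.5, 0, 1.5];
    # MAE = mean([2 * 1.5, 0, 3 * 1.5]) = 2.5 -- the missing entry contributes nothing.
    return mae, mape, rmse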
def calc_tstep_metrics(model, test_loader, scaler, realy, seq_length):
    """Evaluate the model per horizon step; returns (per-step metrics DataFrame, predictions)."""
    model.eval()
    outputs = []
    for x, _ in test_loader.get_iterator():
        testx = torch.Tensor(x).cuda().transpose(1, 3)
        with torch.no_grad():
            preds = model(testx).transpose(1, 3)
        outputs.append(preds.squeeze(1))
    yhat = torch.cat(outputs, dim=0)[:realy.size(0), ...]  # drop the padded samples
    test_met = []

    for i in range(seq_length):
        pred = scaler.inverse_transform(yhat[:, :, i])
        pred = torch.clamp(pred, min=0., max=70.)  # clip predictions to a plausible speed range
        real = realy[:, :, i]
        test_met.append([x.item() for x in calc_metrics(pred, real)])
    test_met_df = pd.DataFrame(test_met, columns=['mae', 'mape', 'rmse']).rename_axis('t')
    return test_met_df, yhat


def _to_ser(arr):
    return pd.DataFrame(arr.cpu().detach().numpy()).stack().rename_axis(['obs', 'sensor_id'])


def make_pred_df(realy, yhat, scaler, seq_length):
    df = pd.DataFrame(dict(y_last=_to_ser(realy[:, :, seq_length - 1]),
                           yhat_last=_to_ser(scaler.inverse_transform(yhat[:, :, seq_length - 1])),
                           y_3=_to_ser(realy[:, :, 2]),
                           yhat_3=_to_ser(scaler.inverse_transform(yhat[:, :, 2]))))
    return df


def make_graph_inputs(args, device):
    sensor_ids, sensor_id_to_ind, adj_mx = load_adj(args.adj_file, args.adj_type)
    if device == 'gpu':
        supports = [torch.tensor(i).cuda() for i in adj_mx]
    else:
        supports = [torch.tensor(i) for i in adj_mx]
    # aptinit seeds the adaptive adjacency; it is ignored without do_graph_conv and addaptadj.
    aptinit = None if args.gwnnArgs.randomadj else supports[0]
    # if args.aptonly:
    #     if not args.addaptadj and args.do_graph_conv: raise ValueError(
    #         'WARNING: not using adjacency matrix')
    #     supports = None
    return aptinit, supports


def get_shared_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='data/METR-LA', help='data path')
    parser.add_argument('--adjdata', type=str, default='data/sensor_graph/adj_mx.pkl',
                        help='adj data path')
    parser.add_argument('--adjtype', type=str, default='doubletransition', help='adj type', choices=ADJ_CHOICES)
    parser.add_argument('--do_graph_conv', action='store_true',
                        help='whether to add graph convolution layer')
    parser.add_argument('--aptonly', action='store_true', help='whether only adaptive adj')
    parser.add_argument('--addaptadj', action='store_true', help='whether add adaptive adj')
    parser.add_argument('--randomadj', action='store_true',
                        help='whether random initialize adaptive adj')
    parser.add_argument('--seq_length', type=int, default=12, help='sequence length')
    parser.add_argument('--nhid', type=int, default=40, help='Number of channels for internal conv')
    parser.add_argument('--in_dim', type=int, default=2, help='inputs dimension')
    parser.add_argument('--num_nodes', type=int, default=325, help='number of nodes')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--dropout', type=float, default=0.3, help='dropout rate')
    parser.add_argument('--n_obs', default=None, help='Only use this many observations. For unit testing.')
    parser.add_argument('--apt_size', default=10, type=int)
    parser.add_argument('--cat_feat_gc', action='store_true')
    parser.add_argument('--fill_zeroes', action='store_true')
    parser.add_argument('--checkpoint', type=str, help='path to a model checkpoint')
    return parser
--------------------------------------------------------------------------------