├── LICENSE ├── data ├── lib ├── TrainInits.py ├── add_window.py ├── dataloader.py ├── load_dataset.py ├── logger.py ├── metrics.py └── normalization.py ├── model ├── AGCN.py ├── AGCRN.py ├── AGCRNCell.py ├── BasicTrainer.py ├── PEMSD4_AGCRN.conf ├── PEMSD8_AGCRN.conf └── Run.py └── readme.md /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 LeiBAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /lib/TrainInits.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | def init_seed(seed): 6 | ''' 7 | Disable cudnn autotuning to maximize reproducibility 8 | ''' 9 | torch.backends.cudnn.enabled = False 10 | torch.backends.cudnn.deterministic = True 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | 16 | def init_device(opt): 17 | if torch.cuda.is_available(): 18 | opt.cuda = True 19 | torch.cuda.set_device(int(opt.device[5])) 20 | else: 21 | opt.cuda = False 22 | opt.device = 'cpu' 23 | return opt 24 | 25 | def init_optim(model, opt): 26 | ''' 27 | Initialize optimizer 28 | ''' 29 | return torch.optim.Adam(params=model.parameters(), lr=opt.lr_init) 30 | 31 | def init_lr_scheduler(optim, opt): 32 | ''' 33 | Initialize the learning rate scheduler 34 | ''' 35 | #return torch.optim.lr_scheduler.StepLR(optimizer=optim,gamma=opt.lr_scheduler_rate,step_size=opt.lr_scheduler_step) 36 | return torch.optim.lr_scheduler.MultiStepLR(optimizer=optim, milestones=opt.lr_decay_steps, 37 | gamma = opt.lr_scheduler_rate) 38 | 39 | def print_model_parameters(model, only_num = True): 40 | print('*****************Model Parameter*****************') 41 | if not only_num: 42 | for name, param in model.named_parameters(): 43 | print(name, param.shape, param.requires_grad) 44 | total_num = sum([param.nelement() for param in model.parameters()]) 45 | print('Total params num: {}'.format(total_num)) 46 | print('*****************Finish Parameter****************') 47 | 48 | def get_memory_usage(device): 49 | allocated_memory = torch.cuda.memory_allocated(device) / (1024*1024.) 50 | cached_memory = torch.cuda.memory_cached(device) / (1024*1024.) 51 | return allocated_memory, cached_memory 52 | #print('Allocated Memory: {:.2f} MB, Cached Memory: {:.2f} MB'.format(allocated_memory, cached_memory)) -------------------------------------------------------------------------------- /lib/add_window.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def Add_Window_Horizon(data, window=3, horizon=1, single=False): 4 | ''' 5 | :param data: shape [B, ...] 6 | :param window: history window length W 7 | :param horizon: prediction horizon length H 8 | :return: X is [B, W, ...], Y is [B, H, ...]
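Example: with window=12, horizon=12, single=False, sample i is X = data[i:i+12] and Y = data[i+12:i+24]; with single=True only the last step is kept, Y = data[i+23:i+24].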
9 | ''' 10 | length = len(data) 11 | end_index = length - horizon - window + 1 12 | X = [] #windows 13 | Y = [] #horizon 14 | index = 0 15 | if single: 16 | while index < end_index: 17 | X.append(data[index:index+window]) 18 | Y.append(data[index+window+horizon-1:index+window+horizon]) 19 | index = index + 1 20 | else: 21 | while index < end_index: 22 | X.append(data[index:index+window]) 23 | Y.append(data[index+window:index+window+horizon]) 24 | index = index + 1 25 | X = np.array(X) 26 | Y = np.array(Y) 27 | return X, Y 28 | 29 | if __name__ == '__main__': 30 | from data.load_raw_data import Load_Sydney_Demand_Data 31 | path = '../data/1h_data_new3.csv' 32 | data = Load_Sydney_Demand_Data(path) 33 | print(data.shape) 34 | X, Y = Add_Window_Horizon(data, horizon=2) 35 | print(X.shape, Y.shape) 36 | 37 | 38 | -------------------------------------------------------------------------------- /lib/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.utils.data 4 | from lib.add_window import Add_Window_Horizon 5 | from lib.load_dataset import load_st_dataset 6 | from lib.normalization import NScaler, MinMax01Scaler, MinMax11Scaler, StandardScaler, ColumnMinMaxScaler 7 | 8 | def normalize_dataset(data, normalizer, column_wise=False): 9 | if normalizer == 'max01': 10 | if column_wise: 11 | minimum = data.min(axis=0, keepdims=True) 12 | maximum = data.max(axis=0, keepdims=True) 13 | else: 14 | minimum = data.min() 15 | maximum = data.max() 16 | scaler = MinMax01Scaler(minimum, maximum) 17 | data = scaler.transform(data) 18 | print('Normalize the dataset by MinMax01 Normalization') 19 | elif normalizer == 'max11': 20 | if column_wise: 21 | minimum = data.min(axis=0, keepdims=True) 22 | maximum = data.max(axis=0, keepdims=True) 23 | else: 24 | minimum = data.min() 25 | maximum = data.max() 26 | scaler = MinMax11Scaler(minimum, maximum) 27 | data = scaler.transform(data) 28 | print('Normalize the dataset by MinMax11 Normalization') 29 | elif normalizer == 'std': 30 | if column_wise: 31 | mean = data.mean(axis=0, keepdims=True) 32 | std = data.std(axis=0, keepdims=True) 33 | else: 34 | mean = data.mean() 35 | std = data.std() 36 | scaler = StandardScaler(mean, std) 37 | data = scaler.transform(data) 38 | print('Normalize the dataset by Standard Normalization') 39 | elif normalizer == 'None': 40 | scaler = NScaler() 41 | data = scaler.transform(data) 42 | print('Does not normalize the dataset') 43 | elif normalizer == 'cmax': 44 | #column min-max, to be deprecated 45 | #note: axis must be the spatial dimension, please check!
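#data is [T, N, D] at this point, so axis=0 reduces over time and yields one min/max per node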
46 | scaler = ColumnMinMaxScaler(data.min(axis=0), data.max(axis=0)) 47 | data = scaler.transform(data) 48 | print('Normalize the dataset by Column Min-Max Normalization') 49 | else: 50 | raise ValueError('Unknown normalizer: {}'.format(normalizer)) 51 | return data, scaler 52 | 53 | def split_data_by_days(data, val_days, test_days, interval=60): 54 | ''' 55 | :param data: [B, *] 56 | :param val_days: number of days for validation 57 | :param test_days: number of days for test 58 | :param interval: interval (15, 30, 60) minutes 59 | :return: 60 | ''' 61 | T = int((24*60)/interval) 62 | test_data = data[-T*test_days:] 63 | val_data = data[-T*(test_days + val_days): -T*test_days] 64 | train_data = data[:-T*(test_days + val_days)] 65 | return train_data, val_data, test_data 66 | 67 | def split_data_by_ratio(data, val_ratio, test_ratio): 68 | data_len = data.shape[0] 69 | test_data = data[-int(data_len*test_ratio):] 70 | val_data = data[-int(data_len*(test_ratio+val_ratio)):-int(data_len*test_ratio)] 71 | train_data = data[:-int(data_len*(test_ratio+val_ratio))] 72 | return train_data, val_data, test_data 73 | 74 | def data_loader(X, Y, batch_size, shuffle=True, drop_last=True): 75 | cuda = True if torch.cuda.is_available() else False 76 | TensorFloat = torch.cuda.FloatTensor if cuda else torch.FloatTensor 77 | X, Y = TensorFloat(X), TensorFloat(Y) 78 | data = torch.utils.data.TensorDataset(X, Y) 79 | dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, 80 | shuffle=shuffle, drop_last=drop_last) 81 | return dataloader 82 | 83 | 84 | def get_dataloader(args, normalizer = 'std', tod=False, dow=False, weather=False, single=True): 85 | #load raw st dataset 86 | data = load_st_dataset(args.dataset) # B, N, D 87 | #normalize st data 88 | data, scaler = normalize_dataset(data, normalizer, args.column_wise) 89 | #split dataset by days or by ratio 90 | if args.test_ratio > 1: 91 | data_train, data_val, data_test = split_data_by_days(data, args.val_ratio, args.test_ratio) 92 | else: 93 | data_train, data_val, data_test = split_data_by_ratio(data, args.val_ratio, args.test_ratio) 94 | #add time window 95 | x_tra, y_tra = Add_Window_Horizon(data_train, args.lag, args.horizon, single) 96 | x_val, y_val = Add_Window_Horizon(data_val, args.lag, args.horizon, single) 97 | x_test, y_test = Add_Window_Horizon(data_test, args.lag, args.horizon, single) 98 | print('Train: ', x_tra.shape, y_tra.shape) 99 | print('Val: ', x_val.shape, y_val.shape) 100 | print('Test: ', x_test.shape, y_test.shape) 101 | ##############get dataloader###################### 102 | train_dataloader = data_loader(x_tra, y_tra, args.batch_size, shuffle=True, drop_last=True) 103 | if len(x_val) == 0: 104 | val_dataloader = None 105 | else: 106 | val_dataloader = data_loader(x_val, y_val, args.batch_size, shuffle=False, drop_last=True) 107 | test_dataloader = data_loader(x_test, y_test, args.batch_size, shuffle=False, drop_last=False) 108 | return train_dataloader, val_dataloader, test_dataloader, scaler 109 | 110 | 111 | if __name__ == '__main__': 112 | import argparse 113 | #MetrLA 207; BikeNYC 128; SIGIR_solar 137; SIGIR_electric 321 114 | DATASET = 'SIGIR_electric' 115 | if DATASET == 'MetrLA': 116 | NODE_NUM = 207 117 | elif DATASET == 'BikeNYC': 118 | NODE_NUM = 128 119 | elif DATASET == 'SIGIR_solar': 120 | NODE_NUM = 137 121 | elif DATASET == 'SIGIR_electric': 122 | NODE_NUM = 321 123 | parser = argparse.ArgumentParser(description='PyTorch dataloader') 124 | parser.add_argument('--dataset', default=DATASET, type=str) 125 | parser.add_argument('--num_nodes', default=NODE_NUM, type=int) 126 |
parser.add_argument('--val_ratio', default=0.1, type=float) 127 | parser.add_argument('--test_ratio', default=0.2, type=float) 128 | parser.add_argument('--lag', default=12, type=int) 129 | parser.add_argument('--horizon', default=12, type=int) 130 | parser.add_argument('--batch_size', default=64, type=int) 131 | args = parser.parse_args() 132 | train_dataloader, val_dataloader, test_dataloader, scaler = get_dataloader(args, normalizer = 'std', tod=False, dow=False, weather=False, single=True) -------------------------------------------------------------------------------- /lib/load_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def load_st_dataset(dataset): 5 | #output B, N, D 6 | if dataset == 'PEMSD4': 7 | data_path = os.path.join('../data/PeMSD4/pems04.npz') 8 | data = np.load(data_path)['data'][:, :, 0] #only the first dimension, traffic flow data 9 | elif dataset == 'PEMSD8': 10 | data_path = os.path.join('../data/PeMSD8/pems08.npz') 11 | data = np.load(data_path)['data'][:, :, 0] #only the first dimension, traffic flow data 12 | else: 13 | raise ValueError('Unknown dataset: {}'.format(dataset)) 14 | if len(data.shape) == 2: 15 | data = np.expand_dims(data, axis=-1) 16 | print('Load %s Dataset shaped: ' % dataset, data.shape, data.max(), data.min(), data.mean(), np.median(data)) 17 | return data 18 | -------------------------------------------------------------------------------- /lib/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import datetime 4 | 5 | def get_logger(root, name=None, debug=True): 6 | #when debug is True, show DEBUG and INFO on the console only (no log file) 7 | #when debug is False, write DEBUG to the log file and INFO to both console and file 8 | #INFO always appears on the console 9 | # create a logger 10 | logger = logging.getLogger(name) 11 | #critical > error > warning > info > debug > notset 12 | logger.setLevel(logging.DEBUG) 13 | 14 | # define the format 15 | formatter = logging.Formatter('%(asctime)s: %(message)s', "%Y-%m-%d %H:%M") 16 | # create another handler for output log to console 17 | console_handler = logging.StreamHandler() 18 | if debug: 19 | console_handler.setLevel(logging.DEBUG) 20 | else: 21 | console_handler.setLevel(logging.INFO) 22 | # create a handler for writing the log to file 23 | logfile = os.path.join(root, 'run.log') 24 | print('Create log file in: ', logfile) 25 | file_handler = logging.FileHandler(logfile, mode='w') 26 | file_handler.setLevel(logging.DEBUG) 27 | file_handler.setFormatter(formatter) 28 | console_handler.setFormatter(formatter) 29 | # add Handler to logger 30 | logger.addHandler(console_handler) 31 | if not debug: 32 | logger.addHandler(file_handler) 33 | return logger 34 | 35 | 36 | if __name__ == '__main__': 37 | time = datetime.now().strftime('%Y%m%d%H%M%S') 38 | print(time) 39 | logger = get_logger('./log.txt', debug=True) 40 | logger.debug('this is a {} debug message'.format(1)) 41 | logger.info('this is an info message') 42 | logger.debug('this is a debug message') 43 | logger.info('this is an info message') 44 | logger.debug('this is a debug message') 45 | logger.info('this is an info message') -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Always evaluate the model with MAE, RMSE, MAPE, RRSE, PNBI, and oPNBI. 3 | Why add mask to MAE and RMSE?
4 | Filter out zeros that may be caused by sensor errors (such as faulty loop sensors) 5 | Why add mask to MAPE and MARE? 6 | Ignore very small true values that blow up the ratio (e.g., 0.5/0.5=100%) 7 | ''' 8 | import numpy as np 9 | import torch 10 | 11 | def MAE_torch(pred, true, mask_value=None): 12 | if mask_value is not None: 13 | mask = torch.gt(true, mask_value) 14 | pred = torch.masked_select(pred, mask) 15 | true = torch.masked_select(true, mask) 16 | return torch.mean(torch.abs(true-pred)) 17 | 18 | def MSE_torch(pred, true, mask_value=None): 19 | if mask_value is not None: 20 | mask = torch.gt(true, mask_value) 21 | pred = torch.masked_select(pred, mask) 22 | true = torch.masked_select(true, mask) 23 | return torch.mean((pred - true) ** 2) 24 | 25 | def RMSE_torch(pred, true, mask_value=None): 26 | if mask_value is not None: 27 | mask = torch.gt(true, mask_value) 28 | pred = torch.masked_select(pred, mask) 29 | true = torch.masked_select(true, mask) 30 | return torch.sqrt(torch.mean((pred - true) ** 2)) 31 | 32 | def RRSE_torch(pred, true, mask_value=None): 33 | if mask_value is not None: 34 | mask = torch.gt(true, mask_value) 35 | pred = torch.masked_select(pred, mask) 36 | true = torch.masked_select(true, mask) 37 | return torch.sqrt(torch.sum((pred - true) ** 2)) / torch.sqrt(torch.sum((true - true.mean()) ** 2)) 38 | 39 | def CORR_torch(pred, true, mask_value=None): 40 | #input B, T, N, D or B, N, D or B, N 41 | if len(pred.shape) == 2: 42 | pred = pred.unsqueeze(dim=1).unsqueeze(dim=1) 43 | true = true.unsqueeze(dim=1).unsqueeze(dim=1) 44 | elif len(pred.shape) == 3: 45 | pred = pred.transpose(1, 2).unsqueeze(dim=1) 46 | true = true.transpose(1, 2).unsqueeze(dim=1) 47 | elif len(pred.shape) == 4: 48 | #B, T, N, D -> B, T, D, N 49 | pred = pred.transpose(2, 3) 50 | true = true.transpose(2, 3) 51 | else: 52 | raise ValueError 53 | dims = (0, 1, 2) 54 | pred_mean = pred.mean(dim=dims) 55 | true_mean = true.mean(dim=dims) 56 | pred_std = pred.std(dim=dims) 57 | true_std = true.std(dim=dims) 58 | correlation = ((pred - pred_mean)*(true - true_mean)).mean(dim=dims) / (pred_std*true_std) 59 | index = (true_std != 0) 60 | correlation = (correlation[index]).mean() 61 | return correlation 62 | 63 | 64 | def MAPE_torch(pred, true, mask_value=None): 65 | if mask_value is not None: 66 | mask = torch.gt(true, mask_value) 67 | pred = torch.masked_select(pred, mask) 68 | true = torch.masked_select(true, mask) 69 | return torch.mean(torch.abs(torch.div((true - pred), true))) 70 | 71 | def PNBI_torch(pred, true, mask_value=None): 72 | if mask_value is not None: 73 | mask = torch.gt(true, mask_value) 74 | pred = torch.masked_select(pred, mask) 75 | true = torch.masked_select(true, mask) 76 | indicator = torch.gt(pred - true, 0).float() 77 | return indicator.mean() 78 | 79 | def oPNBI_torch(pred, true, mask_value=None): 80 | if mask_value is not None: 81 | mask = torch.gt(true, mask_value) 82 | pred = torch.masked_select(pred, mask) 83 | true = torch.masked_select(true, mask) 84 | bias = (true+pred) / (2*true) 85 | return bias.mean() 86 | 87 | def MARE_torch(pred, true, mask_value=None): 88 | if mask_value is not None: 89 | mask = torch.gt(true, mask_value) 90 | pred = torch.masked_select(pred, mask) 91 | true = torch.masked_select(true, mask) 92 | return torch.div(torch.sum(torch.abs((true - pred))), torch.sum(true)) 93 | 94 | def SMAPE_torch(pred, true, mask_value=None): 95 | if mask_value is not None: 96 | mask = torch.gt(true, mask_value) 97 | pred = torch.masked_select(pred, mask) 98 | true = torch.masked_select(true, mask) 99 | return
torch.mean(torch.abs(true-pred)/(torch.abs(true)+torch.abs(pred))) 100 | 101 | 102 | def MAE_np(pred, true, mask_value=None): 103 | if mask_value is not None: 104 | mask = np.where(true > (mask_value), True, False) 105 | true = true[mask] 106 | pred = pred[mask] 107 | MAE = np.mean(np.absolute(pred-true)) 108 | return MAE 109 | 110 | def RMSE_np(pred, true, mask_value=None): 111 | if mask_value is not None: 112 | mask = np.where(true > (mask_value), True, False) 113 | true = true[mask] 114 | pred = pred[mask] 115 | RMSE = np.sqrt(np.mean(np.square(pred-true))) 116 | return RMSE 117 | 118 | #Root Relative Squared Error 119 | def RRSE_np(pred, true, mask_value=None): 120 | if mask_value is not None: 121 | mask = np.where(true > (mask_value), True, False) 122 | true = true[mask] 123 | pred = pred[mask] 124 | mean = true.mean() 125 | return np.divide(np.sqrt(np.sum((pred-true) ** 2)), np.sqrt(np.sum((true-mean) ** 2))) 126 | 127 | def MAPE_np(pred, true, mask_value=None): 128 | if mask_value is not None: 129 | mask = np.where(true > (mask_value), True, False) 130 | true = true[mask] 131 | pred = pred[mask] 132 | return np.mean(np.absolute(np.divide((true - pred), true))) 133 | 134 | def PNBI_np(pred, true, mask_value=None): 135 | #if PNBI=0, all pred are smaller than true 136 | #if PNBI=1, all pred are bigger than true 137 | if mask_value is not None: 138 | mask = np.where(true > (mask_value), True, False) 139 | true = true[mask] 140 | pred = pred[mask] 141 | bias = pred-true 142 | indicator = np.where(bias>0, True, False) 143 | return indicator.mean() 144 | 145 | def oPNBI_np(pred, true, mask_value=None): 146 | #if oPNBI>1, pred are bigger than true 147 | #if oPNBI<1, pred are smaller than true 148 | #however, this metric is too sensitive to small values. Not good! 149 | if mask_value is not None: 150 | mask = np.where(true > (mask_value), True, False) 151 | true = true[mask] 152 | pred = pred[mask] 153 | bias = (true + pred) / (2 * true) 154 | return bias.mean() 155 | 156 | def MARE_np(pred, true, mask_value=None): 157 | if mask_value is not None: 158 | mask = np.where(true > (mask_value), True, False) 159 | true = true[mask] 160 | pred = pred[mask] 161 | return np.divide(np.sum(np.absolute((true - pred))), np.sum(true)) 162 | 163 | def CORR_np(pred, true, mask_value=None): 164 | #input B, T, N, D or B, N, D or B, N 165 | if len(pred.shape) == 2: 166 | #B, N -> B, 1, 1, N 167 | pred = np.expand_dims(np.expand_dims(pred, axis=1), axis=1) 168 | true = np.expand_dims(np.expand_dims(true, axis=1), axis=1) 169 | elif len(pred.shape) == 3: 170 | #np.transpose acts like permute; B, T, N -> B, 1, N, T 171 | pred = np.expand_dims(pred.transpose(0, 2, 1), axis=1) 172 | true = np.expand_dims(true.transpose(0, 2, 1), axis=1) 173 | elif len(pred.shape) == 4: 174 | #B, T, N, D -> B, T, D, N 175 | pred = pred.transpose(0, 1, 3, 2) 176 | true = true.transpose(0, 1, 3, 2) 177 | else: 178 | raise ValueError 179 | dims = (0, 1, 2) 180 | pred_mean = pred.mean(axis=dims) 181 | true_mean = true.mean(axis=dims) 182 | pred_std = pred.std(axis=dims) 183 | true_std = true.std(axis=dims) 184 | correlation = ((pred - pred_mean)*(true - true_mean)).mean(axis=dims) / (pred_std*true_std) 185 | index = (true_std != 0) 186 | correlation = (correlation[index]).mean() 187 | return correlation 188 | 189 | def All_Metrics(pred, true, mask1, mask2): 190 | #mask1 (mae_thresh) masks values for MAE/RMSE/RRSE; mask2 (mape_thresh) masks small values for MAPE 191 | assert type(pred) == type(true) 192 | if type(pred) == np.ndarray: 193 | mae = MAE_np(pred, true, mask1) 194 | rmse = RMSE_np(pred, true, mask1) 195 | mape =
MAPE_np(pred, true, mask2) 196 | rrse = RRSE_np(pred, true, mask1) 197 | corr = 0 198 | #corr = CORR_np(pred, true, mask1) 199 | #pnbi = PNBI_np(pred, true, mask1) 200 | #opnbi = oPNBI_np(pred, true, mask2) 201 | elif type(pred) == torch.Tensor: 202 | mae = MAE_torch(pred, true, mask1) 203 | rmse = RMSE_torch(pred, true, mask1) 204 | mape = MAPE_torch(pred, true, mask2) 205 | rrse = RRSE_torch(pred, true, mask1) 206 | corr = CORR_torch(pred, true, mask1) 207 | #pnbi = PNBI_torch(pred, true, mask1) 208 | #opnbi = oPNBI_torch(pred, true, mask2) 209 | else: 210 | raise TypeError 211 | return mae, rmse, mape, rrse, corr 212 | 213 | def SIGIR_Metrics(pred, true, mask1, mask2): 214 | rrse = RRSE_torch(pred, true, mask1) 215 | corr = CORR_torch(pred, true, 0) 216 | return rrse, corr 217 | 218 | if __name__ == '__main__': 219 | pred = torch.Tensor([1, 2, 3, 4]) 220 | true = torch.Tensor([2, 1, 4, 5]) 221 | print(All_Metrics(pred, true, None, None)) 222 | 223 | -------------------------------------------------------------------------------- /lib/normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class NScaler(object): 6 | def transform(self, data): 7 | return data 8 | def inverse_transform(self, data): 9 | return data 10 | 11 | class StandardScaler: 12 | """ 13 | Standardize the input 14 | """ 15 | 16 | def __init__(self, mean, std): 17 | self.mean = mean 18 | self.std = std 19 | 20 | def transform(self, data): 21 | return (data - self.mean) / self.std 22 | 23 | def inverse_transform(self, data): 24 | if type(data) == torch.Tensor and type(self.mean) == np.ndarray: 25 | self.std = torch.from_numpy(self.std).to(data.device).type(data.dtype) 26 | self.mean = torch.from_numpy(self.mean).to(data.device).type(data.dtype) 27 | return (data * self.std) + self.mean 28 | 29 | class MinMax01Scaler: 30 | """ 31 | Min-max normalize the input to [0, 1] 32 | """ 33 | 34 | def __init__(self, min, max): 35 | self.min = min 36 | self.max = max 37 | 38 | def transform(self, data): 39 | return (data - self.min) / (self.max - self.min) 40 | 41 | def inverse_transform(self, data): 42 | if type(data) == torch.Tensor and type(self.min) == np.ndarray: 43 | self.min = torch.from_numpy(self.min).to(data.device).type(data.dtype) 44 | self.max = torch.from_numpy(self.max).to(data.device).type(data.dtype) 45 | return (data * (self.max - self.min) + self.min) 46 | 47 | class MinMax11Scaler: 48 | """ 49 | Min-max normalize the input to [-1, 1] 50 | """ 51 | 52 | def __init__(self, min, max): 53 | self.min = min 54 | self.max = max 55 | 56 | def transform(self, data): 57 | return ((data - self.min) / (self.max - self.min)) * 2. - 1. 58 | 59 | def inverse_transform(self, data): 60 | if type(data) == torch.Tensor and type(self.min) == np.ndarray: 61 | self.min = torch.from_numpy(self.min).to(data.device).type(data.dtype) 62 | self.max = torch.from_numpy(self.max).to(data.device).type(data.dtype) 63 | return ((data + 1.) / 2.)
* (self.max - self.min) + self.min 64 | 65 | class ColumnMinMaxScaler(): 66 | #Note: to use this scaler, min and max must be initialized with the column-wise minima and maxima 67 | def __init__(self, min, max): 68 | self.min = min 69 | self.min_max = max - self.min 70 | self.min_max[self.min_max==0] = 1 71 | def transform(self, data): 72 | print(data.shape, self.min_max.shape) 73 | return (data - self.min) / self.min_max 74 | 75 | def inverse_transform(self, data): 76 | if type(data) == torch.Tensor and type(self.min) == np.ndarray: 77 | self.min_max = torch.from_numpy(self.min_max).to(data.device).type(torch.float32) 78 | self.min = torch.from_numpy(self.min).to(data.device).type(torch.float32) 79 | #print(data.dtype, self.min_max.dtype, self.min.dtype) 80 | return (data * self.min_max + self.min) 81 | 82 | def one_hot_by_column(data): 83 | #data is a 2D numpy array 84 | length = data.shape[0] 85 | for i in range(data.shape[1]): 86 | column = data[:, i] 87 | col_max = column.max() 88 | col_min = column.min() 89 | #print(length, col_max, col_min) 90 | zero_matrix = np.zeros((length, col_max-col_min+1)) 91 | zero_matrix[np.arange(length), column-col_min] = 1 92 | if i == 0: 93 | encoded = zero_matrix 94 | else: 95 | encoded = np.hstack((encoded, zero_matrix)) 96 | return encoded 97 | 98 | 99 | def minmax_by_column(data): 100 | # data is a 2D numpy array 101 | for i in range(data.shape[1]): 102 | column = data[:, i] 103 | col_max = column.max() 104 | col_min = column.min() 105 | column = (column - col_min) / (col_max - col_min) 106 | column = column[:, np.newaxis] 107 | if i == 0: 108 | _normalized = column 109 | else: 110 | _normalized = np.hstack((_normalized, column)) 111 | return _normalized 112 | 113 | 114 | if __name__ == '__main__': 115 | 116 | 117 | test_data = np.array([[0,0,0, 1], [0, 1, 3, 2], [0, 2, 1, 3]]) 118 | print(test_data) 119 | minimum = test_data.min(axis=1, keepdims=True) 120 | print(minimum, minimum.shape, test_data.shape) 121 | maximum = test_data.max(axis=1, keepdims=True) 122 | print(maximum) 123 | print(test_data-minimum) 124 | test_data = (test_data-minimum) / (maximum-minimum) 125 | print(test_data) 126 | print(0 == 0) 127 | print(0.00 == 0) 128 | print(0 == 0.00) 129 | #print(one_hot_by_column(test_data)) 130 | #print(minmax_by_column(test_data)) -------------------------------------------------------------------------------- /model/AGCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class AVWGCN(nn.Module): 6 | def __init__(self, dim_in, dim_out, cheb_k, embed_dim): 7 | super(AVWGCN, self).__init__() 8 | self.cheb_k = cheb_k 9 | self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) 10 | self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) 11 | def forward(self, x, node_embeddings): 12 | #x shaped [B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] 13 | #output shape [B, N, C] 14 | node_num = node_embeddings.shape[0] 15 | supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1) 16 | support_set = [torch.eye(node_num).to(supports.device), supports] 17 | #default cheb_k = 3 18 | for k in range(2, self.cheb_k): 19 | support_set.append(torch.matmul(2 * supports,
support_set[-1]) - support_set[-2]) 20 | supports = torch.stack(support_set, dim=0) 21 | weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool) #N, cheb_k, dim_in, dim_out 22 | bias = torch.matmul(node_embeddings, self.bias_pool) #N, dim_out 23 | x_g = torch.einsum("knm,bmc->bknc", supports, x) #B, cheb_k, N, dim_in 24 | x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in 25 | x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias #b, N, dim_out 26 | return x_gconv -------------------------------------------------------------------------------- /model/AGCRN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.AGCRNCell import AGCRNCell 4 | 5 | class AVWDCRNN(nn.Module): 6 | def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): 7 | super(AVWDCRNN, self).__init__() 8 | assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.' 9 | self.node_num = node_num 10 | self.input_dim = dim_in 11 | self.num_layers = num_layers 12 | self.dcrnn_cells = nn.ModuleList() 13 | self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)) 14 | for _ in range(1, num_layers): 15 | self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)) 16 | 17 | def forward(self, x, init_state, node_embeddings): 18 | #shape of x: (B, T, N, D) 19 | #shape of init_state: (num_layers, B, N, hidden_dim) 20 | assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim 21 | seq_length = x.shape[1] 22 | current_inputs = x 23 | output_hidden = [] 24 | for i in range(self.num_layers): 25 | state = init_state[i] 26 | inner_states = [] 27 | for t in range(seq_length): 28 | state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings) 29 | inner_states.append(state) 30 | output_hidden.append(state) 31 | current_inputs = torch.stack(inner_states, dim=1) 32 | #current_inputs: the outputs of last layer: (B, T, N, hidden_dim) 33 | #output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) 34 | #last_state: (B, N, hidden_dim) 35 | return current_inputs, output_hidden 36 | 37 | def init_hidden(self, batch_size): 38 | init_states = [] 39 | for i in range(self.num_layers): 40 | init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size)) 41 | return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim) 42 | 43 | class AGCRN(nn.Module): 44 | def __init__(self, args): 45 | super(AGCRN, self).__init__() 46 | self.num_node = args.num_nodes 47 | self.input_dim = args.input_dim 48 | self.hidden_dim = args.rnn_units 49 | self.output_dim = args.output_dim 50 | self.horizon = args.horizon 51 | self.num_layers = args.num_layers 52 | 53 | self.default_graph = args.default_graph 54 | self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True) 55 | 56 | self.encoder = AVWDCRNN(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_k, 57 | args.embed_dim, args.num_layers) 58 | 59 | #predictor 60 | self.end_conv = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) 61 | 62 | def forward(self, source, targets, teacher_forcing_ratio=0.5): 63 | #source: B, T_1, N, D 64 | #target: B, T_2, N, D 65 | #supports = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec1.transpose(0,1))), dim=1) 66 | 67 | init_state = self.encoder.init_hidden(source.shape[0]) 68 | output, _ = self.encoder(source, init_state, self.node_embeddings) #B, 
T, N, hidden 69 | output = output[:, -1:, :, :] #B, 1, N, hidden 70 | 71 | #CNN based predictor 72 | output = self.end_conv((output)) #B, T*C, N, 1 73 | output = output.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node) 74 | output = output.permute(0, 1, 3, 2) #B, T, N, C 75 | 76 | return output -------------------------------------------------------------------------------- /model/AGCRNCell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.AGCN import AVWGCN 4 | 5 | class AGCRNCell(nn.Module): 6 | def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): 7 | super(AGCRNCell, self).__init__() 8 | self.node_num = node_num 9 | self.hidden_dim = dim_out 10 | self.gate = AVWGCN(dim_in+self.hidden_dim, 2*dim_out, cheb_k, embed_dim) 11 | self.update = AVWGCN(dim_in+self.hidden_dim, dim_out, cheb_k, embed_dim) 12 | 13 | def forward(self, x, state, node_embeddings): 14 | #x: B, num_nodes, input_dim 15 | #state: B, num_nodes, hidden_dim 16 | state = state.to(x.device) 17 | input_and_state = torch.cat((x, state), dim=-1) 18 | z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) 19 | z, r = torch.split(z_r, self.hidden_dim, dim=-1) 20 | candidate = torch.cat((x, z*state), dim=-1) 21 | hc = torch.tanh(self.update(candidate, node_embeddings)) 22 | h = r*state + (1-r)*hc 23 | return h 24 | 25 | def init_hidden_state(self, batch_size): 26 | return torch.zeros(batch_size, self.node_num, self.hidden_dim) -------------------------------------------------------------------------------- /model/BasicTrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import os 4 | import time 5 | import copy 6 | import numpy as np 7 | from lib.logger import get_logger 8 | from lib.metrics import All_Metrics 9 | 10 | class Trainer(object): 11 | def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, 12 | scaler, args, lr_scheduler=None): 13 | super(Trainer, self).__init__() 14 | self.model = model 15 | self.loss = loss 16 | self.optimizer = optimizer 17 | self.train_loader = train_loader 18 | self.val_loader = val_loader 19 | self.test_loader = test_loader 20 | self.scaler = scaler 21 | self.args = args 22 | self.lr_scheduler = lr_scheduler 23 | self.train_per_epoch = len(train_loader) 24 | if val_loader != None: 25 | self.val_per_epoch = len(val_loader) 26 | self.best_path = os.path.join(self.args.log_dir, 'best_model.pth') 27 | self.loss_figure_path = os.path.join(self.args.log_dir, 'loss.png') 28 | #log 29 | if os.path.isdir(args.log_dir) == False and not args.debug: 30 | os.makedirs(args.log_dir, exist_ok=True) 31 | self.logger = get_logger(args.log_dir, name=args.model, debug=args.debug) 32 | self.logger.info('Experiment log path in: {}'.format(args.log_dir)) 33 | #if not args.debug: 34 | #self.logger.info("Argument: %r", args) 35 | # for arg, value in sorted(vars(args).items()): 36 | # self.logger.info("Argument %s: %r", arg, value) 37 | 38 | def val_epoch(self, epoch, val_dataloader): 39 | self.model.eval() 40 | total_val_loss = 0 41 | 42 | with torch.no_grad(): 43 | for batch_idx, (data, target) in enumerate(val_dataloader): 44 | data = data[..., :self.args.input_dim] 45 | label = target[..., :self.args.output_dim] 46 | output = self.model(data, target, teacher_forcing_ratio=0.) 
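#real_value=True means the model predicts in the original (unnormalized) scale, so the labels are inverse-transformed to match before computing the loss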
47 | if self.args.real_value: 48 | label = self.scaler.inverse_transform(label) 49 | loss = self.loss(output.cuda(), label) 50 | #a whole batch of Metr_LA is filtered 51 | if not torch.isnan(loss): 52 | total_val_loss += loss.item() 53 | val_loss = total_val_loss / len(val_dataloader) 54 | self.logger.info('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss)) 55 | return val_loss 56 | 57 | def train_epoch(self, epoch): 58 | self.model.train() 59 | total_loss = 0 60 | for batch_idx, (data, target) in enumerate(self.train_loader): 61 | data = data[..., :self.args.input_dim] 62 | label = target[..., :self.args.output_dim] # (..., 1) 63 | self.optimizer.zero_grad() 64 | 65 | #teacher_forcing for RNN encoder-decoder model 66 | #if teacher_forcing_ratio = 1: use label as input in the decoder for all steps 67 | if self.args.teacher_forcing: 68 | global_step = (epoch - 1) * self.train_per_epoch + batch_idx 69 | teacher_forcing_ratio = self._compute_sampling_threshold(global_step, self.args.tf_decay_steps) 70 | else: 71 | teacher_forcing_ratio = 1. 72 | #data and target shape: B, T, N, F; output shape: B, T, N, F 73 | output = self.model(data, target, teacher_forcing_ratio=teacher_forcing_ratio) 74 | if self.args.real_value: 75 | label = self.scaler.inverse_transform(label) 76 | loss = self.loss(output.cuda(), label) 77 | loss.backward() 78 | 79 | # add max grad clipping 80 | if self.args.grad_norm: 81 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 82 | self.optimizer.step() 83 | total_loss += loss.item() 84 | 85 | #log information 86 | if batch_idx % self.args.log_step == 0: 87 | self.logger.info('Train Epoch {}: {}/{} Loss: {:.6f}'.format( 88 | epoch, batch_idx, self.train_per_epoch, loss.item())) 89 | train_epoch_loss = total_loss/self.train_per_epoch 90 | self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}, tf_ratio: {:.6f}'.format(epoch, train_epoch_loss, teacher_forcing_ratio)) 91 | 92 | #learning rate decay 93 | if self.args.lr_decay: 94 | self.lr_scheduler.step() 95 | return train_epoch_loss 96 | 97 | def train(self): 98 | best_model = None 99 | best_loss = float('inf') 100 | not_improved_count = 0 101 | train_loss_list = [] 102 | val_loss_list = [] 103 | start_time = time.time() 104 | for epoch in range(1, self.args.epochs + 1): 105 | #epoch_time = time.time() 106 | train_epoch_loss = self.train_epoch(epoch) 107 | #print(time.time()-epoch_time) 108 | #exit() 109 | if self.val_loader == None: 110 | val_dataloader = self.test_loader 111 | else: 112 | val_dataloader = self.val_loader 113 | val_epoch_loss = self.val_epoch(epoch, val_dataloader) 114 | 115 | #print('LR:', self.optimizer.param_groups[0]['lr']) 116 | train_loss_list.append(train_epoch_loss) 117 | val_loss_list.append(val_epoch_loss) 118 | if train_epoch_loss > 1e6: 119 | self.logger.warning('Gradient explosion detected. Ending...') 120 | break 121 | #if self.val_loader == None: 122 | #val_epoch_loss = train_epoch_loss 123 | if val_epoch_loss < best_loss: 124 | best_loss = val_epoch_loss 125 | not_improved_count = 0 126 | best_state = True 127 | else: 128 | not_improved_count += 1 129 | best_state = False 130 | # early stop 131 | if self.args.early_stop: 132 | if not_improved_count == self.args.early_stop_patience: 133 | self.logger.info("Validation performance didn\'t improve for {} epochs. 
" 134 | "Training stops.".format(self.args.early_stop_patience)) 135 | break 136 | # save the best state 137 | if best_state == True: 138 | self.logger.info('*********************************Current best model saved!') 139 | best_model = copy.deepcopy(self.model.state_dict()) 140 | 141 | training_time = time.time() - start_time 142 | self.logger.info("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss)) 143 | 144 | #save the best model to file 145 | if not self.args.debug: 146 | torch.save(best_model, self.best_path) 147 | self.logger.info("Saving current best model to " + self.best_path) 148 | 149 | #test 150 | self.model.load_state_dict(best_model) 151 | #self.val_epoch(self.args.epochs, self.test_loader) 152 | self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) 153 | 154 | def save_checkpoint(self): 155 | state = { 156 | 'state_dict': self.model.state_dict(), 157 | 'optimizer': self.optimizer.state_dict(), 158 | 'config': self.args 159 | } 160 | torch.save(state, self.best_path) 161 | self.logger.info("Saving current best model to " + self.best_path) 162 | 163 | @staticmethod 164 | def test(model, args, data_loader, scaler, logger, path=None): 165 | if path != None: 166 | check_point = torch.load(path) 167 | state_dict = check_point['state_dict'] 168 | args = check_point['config'] 169 | model.load_state_dict(state_dict) 170 | model.to(args.device) 171 | model.eval() 172 | y_pred = [] 173 | y_true = [] 174 | with torch.no_grad(): 175 | for batch_idx, (data, target) in enumerate(data_loader): 176 | data = data[..., :args.input_dim] 177 | label = target[..., :args.output_dim] 178 | output = model(data, target, teacher_forcing_ratio=0) 179 | y_true.append(label) 180 | y_pred.append(output) 181 | y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) 182 | if args.real_value: 183 | y_pred = torch.cat(y_pred, dim=0) 184 | else: 185 | y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) 186 | np.save('./{}_true.npy'.format(args.dataset), y_true.cpu().numpy()) 187 | np.save('./{}_pred.npy'.format(args.dataset), y_pred.cpu().numpy()) 188 | for t in range(y_true.shape[1]): 189 | mae, rmse, mape, _, _ = All_Metrics(y_pred[:, t, ...], y_true[:, t, ...], 190 | args.mae_thresh, args.mape_thresh) 191 | logger.info("Horizon {:02d}, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format( 192 | t + 1, mae, rmse, mape*100)) 193 | mae, rmse, mape, _, _ = All_Metrics(y_pred, y_true, args.mae_thresh, args.mape_thresh) 194 | logger.info("Average Horizon, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format( 195 | mae, rmse, mape*100)) 196 | 197 | @staticmethod 198 | def _compute_sampling_threshold(global_step, k): 199 | """ 200 | Computes the sampling probability for scheduled sampling using inverse sigmoid. 
201 | :param global_step: 202 | :param k: 203 | :return: 204 | """ 205 | return k / (k + math.exp(global_step / k)) -------------------------------------------------------------------------------- /model/PEMSD4_AGCRN.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 307 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.2 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 10 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | 20 | [train] 21 | loss_func = mae 22 | seed = 10 23 | batch_size = 64 24 | epochs = 100 25 | lr_init = 0.003 26 | lr_decay = False 27 | lr_decay_rate = 0.3 28 | lr_decay_step = 5,20,40,70 29 | early_stop = True 30 | early_stop_patience = 15 31 | grad_norm = False 32 | max_grad_norm = 5 33 | real_value = True 34 | 35 | [test] 36 | mae_thresh = None 37 | mape_thresh = 0. 38 | 39 | [log] 40 | log_step = 20 41 | plot = False -------------------------------------------------------------------------------- /model/PEMSD8_AGCRN.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 170 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.2 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 2 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | 20 | [train] 21 | loss_func = mae 22 | seed = 12 23 | batch_size = 64 24 | epochs = 100 25 | lr_init = 0.003 26 | lr_decay = False 27 | lr_decay_rate = 0.3 28 | lr_decay_step = 5,20,40,70 29 | early_stop = True 30 | early_stop_patience = 15 31 | grad_norm = False 32 | max_grad_norm = 5 33 | real_value = True 34 | 35 | [test] 36 | mae_thresh = None 37 | mape_thresh = 0. 
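# mae_thresh = None disables masking for MAE/RMSE; mape_thresh = 0. masks zero ground-truth values when computing MAPE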
38 | 39 | [log] 40 | log_step = 20 41 | plot = False -------------------------------------------------------------------------------- /model/Run.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | file_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | print(file_dir) 6 | sys.path.append(file_dir) 7 | 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import argparse 12 | import configparser 13 | from datetime import datetime 14 | from model.AGCRN import AGCRN as Network 15 | from model.BasicTrainer import Trainer 16 | from lib.TrainInits import init_seed 17 | from lib.dataloader import get_dataloader 18 | from lib.TrainInits import print_model_parameters 19 | 20 | 21 | #*************************************************************************# 22 | Mode = 'train' #'train' or 'test'; the dispatch at the bottom checks lowercase strings 23 | DEBUG = 'True' 24 | DATASET = 'PEMSD4' #PEMSD4 or PEMSD8 25 | DEVICE = 'cuda:0' 26 | MODEL = 'AGCRN' 27 | 28 | #get configuration 29 | config_file = './{}_{}.conf'.format(DATASET, MODEL) 30 | #print('Read configuration file: %s' % (config_file)) 31 | config = configparser.ConfigParser() 32 | config.read(config_file) 33 | 34 | from lib.metrics import MAE_torch 35 | def masked_mae_loss(scaler, mask_value): 36 | def loss(preds, labels): 37 | if scaler: 38 | preds = scaler.inverse_transform(preds) 39 | labels = scaler.inverse_transform(labels) 40 | mae = MAE_torch(pred=preds, true=labels, mask_value=mask_value) 41 | return mae 42 | return loss 43 | 44 | #parser 45 | args = argparse.ArgumentParser(description='arguments') 46 | args.add_argument('--dataset', default=DATASET, type=str) 47 | args.add_argument('--mode', default=Mode, type=str) 48 | args.add_argument('--device', default=DEVICE, type=str, help='indices of GPUs') 49 | args.add_argument('--debug', default=DEBUG, type=eval) 50 | args.add_argument('--model', default=MODEL, type=str) 51 | args.add_argument('--cuda', default=True, type=bool) 52 | #data 53 | args.add_argument('--val_ratio', default=config['data']['val_ratio'], type=float) 54 | args.add_argument('--test_ratio', default=config['data']['test_ratio'], type=float) 55 | args.add_argument('--lag', default=config['data']['lag'], type=int) 56 | args.add_argument('--horizon', default=config['data']['horizon'], type=int) 57 | args.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int) 58 | args.add_argument('--tod', default=config['data']['tod'], type=eval) 59 | args.add_argument('--normalizer', default=config['data']['normalizer'], type=str) 60 | args.add_argument('--column_wise', default=config['data']['column_wise'], type=eval) 61 | args.add_argument('--default_graph', default=config['data']['default_graph'], type=eval) 62 | #model 63 | args.add_argument('--input_dim', default=config['model']['input_dim'], type=int) 64 | args.add_argument('--output_dim', default=config['model']['output_dim'], type=int) 65 | args.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int) 66 | args.add_argument('--rnn_units', default=config['model']['rnn_units'], type=int) 67 | args.add_argument('--num_layers', default=config['model']['num_layers'], type=int) 68 | args.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int) 69 | #train 70 | args.add_argument('--loss_func', default=config['train']['loss_func'], type=str) 71 | args.add_argument('--seed', default=config['train']['seed'], type=int) 72 | args.add_argument('--batch_size', default=config['train']['batch_size'],
type=int) 73 | args.add_argument('--epochs', default=config['train']['epochs'], type=int) 74 | args.add_argument('--lr_init', default=config['train']['lr_init'], type=float) 75 | args.add_argument('--lr_decay', default=config['train']['lr_decay'], type=eval) 76 | args.add_argument('--lr_decay_rate', default=config['train']['lr_decay_rate'], type=float) 77 | args.add_argument('--lr_decay_step', default=config['train']['lr_decay_step'], type=str) 78 | args.add_argument('--early_stop', default=config['train']['early_stop'], type=eval) 79 | args.add_argument('--early_stop_patience', default=config['train']['early_stop_patience'], type=int) 80 | args.add_argument('--grad_norm', default=config['train']['grad_norm'], type=eval) 81 | args.add_argument('--max_grad_norm', default=config['train']['max_grad_norm'], type=int) 82 | args.add_argument('--teacher_forcing', default=False, type=bool) 83 | #args.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps') 84 | args.add_argument('--real_value', default=config['train']['real_value'], type=eval, help = 'use real value for loss calculation') 85 | #test 86 | args.add_argument('--mae_thresh', default=config['test']['mae_thresh'], type=eval) 87 | args.add_argument('--mape_thresh', default=config['test']['mape_thresh'], type=float) 88 | #log 89 | args.add_argument('--log_dir', default='./', type=str) 90 | args.add_argument('--log_step', default=config['log']['log_step'], type=int) 91 | args.add_argument('--plot', default=config['log']['plot'], type=eval) 92 | args = args.parse_args() 93 | init_seed(args.seed) 94 | if torch.cuda.is_available(): 95 | torch.cuda.set_device(int(args.device[5])) 96 | else: 97 | args.device = 'cpu' 98 | 99 | #init model 100 | model = Network(args) 101 | model = model.to(args.device) 102 | for p in model.parameters(): 103 | if p.dim() > 1: 104 | nn.init.xavier_uniform_(p) 105 | else: 106 | nn.init.uniform_(p) 107 | print_model_parameters(model, only_num=False) 108 | 109 | #load dataset 110 | train_loader, val_loader, test_loader, scaler = get_dataloader(args, 111 | normalizer=args.normalizer, 112 | tod=args.tod, dow=False, 113 | weather=False, single=False) 114 | 115 | #init loss function, optimizer 116 | if args.loss_func == 'mask_mae': 117 | loss = masked_mae_loss(scaler, mask_value=0.0) 118 | elif args.loss_func == 'mae': 119 | loss = torch.nn.L1Loss().to(args.device) 120 | elif args.loss_func == 'mse': 121 | loss = torch.nn.MSELoss().to(args.device) 122 | else: 123 | raise ValueError 124 | 125 | optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8, 126 | weight_decay=0, amsgrad=False) 127 | #learning rate decay 128 | lr_scheduler = None 129 | if args.lr_decay: 130 | print('Applying learning rate decay.') 131 | lr_decay_steps = [int(i) for i in list(args.lr_decay_step.split(','))] 132 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, 133 | milestones=lr_decay_steps, 134 | gamma=args.lr_decay_rate) 135 | #lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=64) 136 | 137 | #config log path 138 | current_time = datetime.now().strftime('%Y%m%d%H%M%S') 139 | current_dir = os.path.dirname(os.path.realpath(__file__)) 140 | log_dir = os.path.join(current_dir,'experiments', args.dataset, current_time) 141 | args.log_dir = log_dir 142 | 143 | #start training 144 | trainer = Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, 145 | args, lr_scheduler=lr_scheduler) 146 | if args.mode == 
'train': 147 | trainer.train() 148 | elif args.mode == 'test': 149 | model.load_state_dict(torch.load('../pre-trained/{}.pth'.format(args.dataset))) 150 | print("Load saved model") 151 | trainer.test(model, trainer.args, test_loader, scaler, trainer.logger) 152 | else: 153 | raise ValueError("mode must be 'train' or 'test', got {}".format(args.mode)) 154 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting 2 | 3 | This folder contains the code and data for our AGCRN model: [Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting](https://arxiv.org/pdf/2007.02842.pdf), which has been accepted at NeurIPS 2020. 4 | 5 | ## Structure: 6 | 7 | * data: includes the PEMSD4 and PEMSD8 datasets used in our experiments, which were released by and are available at [ASTGCN](https://github.com/Davidham3/ASTGCN/tree/master/data). 8 | 9 | * lib: contains custom modules for our work, such as data loading, data pre-processing, normalization, and evaluation metrics. 10 | 11 | * model: implementation of our AGCRN model 12 | 13 | 14 | ## Requirements 15 | 16 | Python 3.6.5, PyTorch 1.1.0, NumPy 1.16.3, argparse, and configparser 17 | 18 | 19 | 20 | To replicate the results on the PEMSD4 and PEMSD8 datasets, you can run the code in the "model" folder directly (see the usage sketch below). If you want to use the model on your own datasets, load them by revising "load_dataset" in the "lib" folder, and remember to tune the learning rate (gradient clipping can be used to stabilize training). 21 | 22 | Please cite our work if you find it useful. 23 | 24 | 25 | 26 |
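## Usage

A minimal sketch of a training run, assuming the PEMS data is placed under `data/` as described above: set `DATASET` in `model/Run.py` to `PEMSD4` or `PEMSD8` (the matching `.conf` file is read relative to the `model` folder), then run `python Run.py` from inside `model/`.

To plug in your own data, add a branch to `load_st_dataset` in `lib/load_dataset.py`. Below is a hedged sketch, not part of this repo: the dataset key `MYDATA` and the file path are placeholders, and your array is assumed to be shaped (time_steps, num_nodes) or (time_steps, num_nodes, features):

```python
import os
import numpy as np

def load_st_dataset(dataset):
    #existing PEMSD4/PEMSD8 branches omitted for brevity
    if dataset == 'MYDATA':  #hypothetical dataset key
        data_path = os.path.join('../data/MyData/mydata.npy')  #hypothetical path
        data = np.load(data_path)  #expected shape: (T, N) or (T, N, D)
    else:
        raise ValueError('Unknown dataset: {}'.format(dataset))
    if len(data.shape) == 2:
        #the rest of the pipeline expects a trailing feature dimension: (T, N) -> (T, N, 1)
        data = np.expand_dims(data, axis=-1)
    return data
```

Remember to also create a matching `MYDATA_AGCRN.conf` (with the correct `num_nodes`) next to the existing `.conf` files.

--------------------------------------------------------------------------------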