# =============================================================================
# Repository layout:
#   README.md
#   lib/       TrainInits.py, add_window.py, dataloader.py, load_dataset.py,
#              metrics.py, normalization.py, utils.py
#   main.py
#   model/     AGCN.py, AGCRNCell.py, AGCRN_UQ.py, BasicTrainer.py,
#              PEMS03_AGCRN.conf, PEMS04_AGCRN.conf, PEMS07_AGCRN.conf,
#              PEMS08_AGCRN.conf, readme.md, test_methods.py, train_methods.py
#   notebooks/ main_v2.ipynb, readme
#   slide/     uq_slide.pdf
#
# /README.md:
#   # Code for DeepSTUQ
#   ## Related Papers:
#   [Uncertainty Quantification for Traffic Forecasting: A Unified Approach,
#    ICDE 2023](https://arxiv.org/pdf/2208.05875.pdf)
#   [Towards a Unified Understanding of Uncertainty Quantification in Traffic
#    Flow Forecasting, TKDE 2023](https://ieeexplore.ieee.org/abstract/document/10242138)
#   ## Datasets:
#   [PEMS datasets](https://www.kaggle.com/datasets/elmahy/pems-dataset?resource=download)
#   ## Acknowledgement:
#   The implementation is based on [this repository](https://github.com/LeiBAI/AGCRN)
# =============================================================================

# -----------------------------------------------------------------------------
# /lib/TrainInits.py
# -----------------------------------------------------------------------------
import torch
import random
import numpy as np

def init_seed(seed):
    '''
    Seed every RNG (python, numpy, torch CPU/CUDA) and disable cudnn
    to maximize reproducibility.
    '''
    # BUG FIX: the original wrote `torch.cuda.cudnn_enabled = False`, which is
    # not a real torch attribute -- the assignment silently did nothing.
    # The actual switch lives under torch.backends.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def init_device(opt):
    '''Resolve opt.cuda / opt.device from hardware availability.'''
    if torch.cuda.is_available():
        opt.cuda = True
        # Parse the index out of strings like 'cuda:0'. The original used
        # `opt.device[5]`, which breaks for device indices >= 10.
        torch.cuda.set_device(int(opt.device.split(':')[1]))
    else:
        opt.cuda = False
        opt.device = 'cpu'
    return opt

def init_optim(model, opt):
    '''
    Initialize optimizer
    '''
    return torch.optim.Adam(params=model.parameters(), lr=opt.lr_init)

def init_lr_scheduler(optim, opt):
    '''
    Initialize the learning rate scheduler
    '''
    #return torch.optim.lr_scheduler.StepLR(optimizer=optim, gamma=opt.lr_scheduler_rate, step_size=opt.lr_scheduler_step)
    return torch.optim.lr_scheduler.MultiStepLR(optimizer=optim, milestones=opt.lr_decay_steps,
                                                gamma=opt.lr_scheduler_rate)

def print_model_parameters(model, only_num=True):
    '''Print the total parameter count (and per-tensor shapes if requested).'''
    print('*****************Model Parameter*****************')
    if not only_num:
        for name, param in model.named_parameters():
            print(name, param.shape, param.requires_grad)
    total_num = sum(param.nelement() for param in model.parameters())
    print('Total params num: {}'.format(total_num))
    print('*****************Finish Parameter****************')

def get_memory_usage(device):
    '''Return (allocated, reserved) CUDA memory on `device`, in MB.'''
    allocated_memory = torch.cuda.memory_allocated(device) / (1024*1024.)
    # torch.cuda.memory_cached was deprecated in favor of memory_reserved.
    cached_memory = torch.cuda.memory_reserved(device) / (1024*1024.)
    return allocated_memory, cached_memory
    #print('Allocated Memory: {:.2f} MB, Cached Memory: {:.2f} MB'.format(allocated_memory, cached_memory))

# -----------------------------------------------------------------------------
# /lib/add_window.py
# -----------------------------------------------------------------------------
import numpy as np

def Add_Window_Horizon(data, window=3, horizon=1, single=False):
    '''
    Slice a time series into (input window, prediction horizon) sample pairs.
    :param data: shape [B, ...]
    :param window: number of past steps per input sample
    :param horizon: number of future steps per target sample
    :param single: if True, Y holds only the single step `horizon` ahead
    :return: X is [B, W, ...], Y is [B, H, ...]
    '''
    length = len(data)
    end_index = length - horizon - window + 1
    X = []      #windows
    Y = []      #horizon
    index = 0
    if single:
        while index < end_index:
            X.append(data[index:index+window])
            Y.append(data[index+window+horizon-1:index+window+horizon])
            index = index + 1
    else:
        while index < end_index:
            X.append(data[index:index+window])
            Y.append(data[index+window:index+window+horizon])
            index = index + 1
    X = np.array(X)
    Y = np.array(Y)
    return X, Y

if __name__ == '__main__':
    from data.load_raw_data import Load_Sydney_Demand_Data
    path = '../data/1h_data_new3.csv'
    data = Load_Sydney_Demand_Data(path)
    print(data.shape)
    X, Y = Add_Window_Horizon(data, horizon=2)
    print(X.shape, Y.shape)

# -----------------------------------------------------------------------------
# /lib/dataloader.py
# -----------------------------------------------------------------------------
import torch
import numpy as np
import torch.utils.data
from lib.add_window import Add_Window_Horizon
from lib.load_dataset import load_st_dataset
from lib.normalization import NScaler, MinMax01Scaler, MinMax11Scaler, StandardScaler, ColumnMinMaxScaler

def normalize_dataset(data, normalizer, column_wise=False):
    '''Scale `data` with the requested normalizer; return (scaled data, scaler).'''
    if normalizer == 'max01':
        if column_wise:
            minimum = data.min(axis=0, keepdims=True)
            maximum = data.max(axis=0, keepdims=True)
        else:
            minimum = data.min()
            maximum = data.max()
        scaler = MinMax01Scaler(minimum, maximum)
        data = scaler.transform(data)
        print('Normalize the dataset by MinMax01 Normalization')
    elif normalizer == 'max11':
        if column_wise:
            minimum = data.min(axis=0, keepdims=True)
            maximum = data.max(axis=0, keepdims=True)
        else:
            minimum = data.min()
            maximum = data.max()
        scaler = MinMax11Scaler(minimum, maximum)
        data = scaler.transform(data)
        print('Normalize the dataset by MinMax11 Normalization')
    elif normalizer == 'std':
        if column_wise:
            mean = data.mean(axis=0, keepdims=True)
            std = data.std(axis=0, keepdims=True)
        else:
            mean = data.mean()
            std = data.std()
        scaler = StandardScaler(mean, std)
        data = scaler.transform(data)
        print('Normalize the dataset by Standard Normalization')
    elif normalizer == 'None':
        scaler = NScaler()
        data = scaler.transform(data)
        print('Does not normalize the dataset')
    elif normalizer == 'cmax':
        #column min max, to be deprecated
        #note: axis must be the spatial dimension, please check !
        scaler = ColumnMinMaxScaler(data.min(axis=0), data.max(axis=0))
        data = scaler.transform(data)
        print('Normalize the dataset by Column Min-Max Normalization')
    else:
        raise ValueError
    return data, scaler

def split_data_by_days(data, val_days, test_days, interval=60):
    '''
    :param data: [B, *]
    :param val_days:
    :param test_days:
    :param interval: interval (15, 30, 60) minutes
    :return: train/val/test splits taken from the tail of the series
    '''
    T = int((24*60)/interval)  # samples per day
    test_data = data[-T*test_days:]
    val_data = data[-T*(test_days + val_days): -T*test_days]
    train_data = data[:-T*(test_days + val_days)]
    return train_data, val_data, test_data

def split_data_by_ratio(data, val_ratio, test_ratio):
    '''Chronological split: train | val | test by fractional lengths.'''
    data_len = data.shape[0]
    test_data = data[-int(data_len*test_ratio):]
    val_data = data[-int(data_len*(test_ratio+val_ratio)):-int(data_len*test_ratio)]
    train_data = data[:-int(data_len*(test_ratio+val_ratio))]
    return train_data, val_data, test_data

def data_loader(X, Y, batch_size, shuffle=True, drop_last=True):
    '''Wrap numpy arrays X, Y in a (CUDA if available) DataLoader.'''
    cuda = True if torch.cuda.is_available() else False
    TensorFloat = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    X, Y = TensorFloat(X), TensorFloat(Y)
    data = torch.utils.data.TensorDataset(X, Y)
    dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                             shuffle=shuffle, drop_last=drop_last)
    return dataloader


def get_dataloader(args, normalizer='std', tod=False, dow=False, weather=False, single=True):
    '''Build train/val/test DataLoaders plus the fitted scaler from args.'''
    #load raw st dataset
    data = load_st_dataset(args.dataset)        # B, N, D
    #normalize st data
    data, scaler = normalize_dataset(data, normalizer, args.column_wise)
    #split dataset by days or by ratio (a test_ratio > 1 is a day count)
    if args.test_ratio > 1:
        data_train, data_val, data_test = split_data_by_days(data, args.val_ratio, args.test_ratio)
    else:
        data_train, data_val, data_test = split_data_by_ratio(data, args.val_ratio, args.test_ratio)
    #add time window
    x_tra, y_tra = Add_Window_Horizon(data_train, args.lag, args.horizon, single)
    x_val, y_val = Add_Window_Horizon(data_val, args.lag, args.horizon, single)
    x_test, y_test = Add_Window_Horizon(data_test, args.lag, args.horizon, single)
    print('Train: ', x_tra.shape, y_tra.shape)
    print('Val: ', x_val.shape, y_val.shape)
    print('Test: ', x_test.shape, y_test.shape)
    ##############get dataloader######################
    train_dataloader = data_loader(x_tra, y_tra, args.batch_size, shuffle=True, drop_last=True)
    if len(x_val) == 0:
        val_dataloader = None
    else:
        val_dataloader = data_loader(x_val, y_val, args.batch_size, shuffle=False, drop_last=True)
    test_dataloader = data_loader(x_test, y_test, args.batch_size, shuffle=False, drop_last=False)
    return train_dataloader, val_dataloader, test_dataloader, scaler


if __name__ == '__main__':
    import argparse
    #MetrLA 207; BikeNYC 128; SIGIR_solar 137; SIGIR_electric 321
    DATASET = 'SIGIR_electric'
    if DATASET == 'MetrLA':
        NODE_NUM = 207
    elif DATASET == 'BikeNYC':
        NODE_NUM = 128
    elif DATASET == 'SIGIR_solar':
        NODE_NUM = 137
    elif DATASET == 'SIGIR_electric':
        NODE_NUM = 321
    parser = argparse.ArgumentParser(description='PyTorch dataloader')
    parser.add_argument('--dataset', default=DATASET, type=str)
    parser.add_argument('--num_nodes', default=NODE_NUM, type=int)
    parser.add_argument('--val_ratio', default=0.1, type=float)
    parser.add_argument('--test_ratio', default=0.2, type=float)
    parser.add_argument('--lag', default=12, type=int)
    parser.add_argument('--horizon', default=12, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    args = parser.parse_args()
    train_dataloader, val_dataloader, test_dataloader, scaler = get_dataloader(args, normalizer='std', tod=False, dow=False, weather=False, single=True)

# -----------------------------------------------------------------------------
# /lib/load_dataset.py
# -----------------------------------------------------------------------------
import os
import numpy as np

# dataset name -> npz path; only the first channel (traffic flow) is used
_DATA_PATHS = {
    'PEMS03': './data/PEMS03/pems03.npz',
    'PEMS04': './data/PEMS04/pems04.npz',
    'PEMS07': './data/PEMS07/pems07.npz',
    'PEMS08': './data/PEMS08/pems08.npz',
}

def load_st_dataset(dataset):
    '''Load a PEMS dataset as an array shaped [T, N, 1] (flow channel only).'''
    #output B, N, D
    if dataset not in _DATA_PATHS:
        raise ValueError
    data_path = os.path.join(_DATA_PATHS[dataset])
    data = np.load(data_path)['data'][:, :, 0]  #only the first dimension, traffic flow data
    if len(data.shape) == 2:
        data = np.expand_dims(data, axis=-1)
    print('Load %s Dataset shaped: ' % dataset, data.shape, data.max(), data.min(), data.mean(), np.median(data))
    return data
# -----------------------------------------------------------------------------
# /lib/metrics.py
# -----------------------------------------------------------------------------
'''
Always evaluate the model with MAE, RMSE, MAPE, RRSE, PNBI, and oPNBI.
Why add mask to MAE and RMSE?
    Filter the 0 that may be caused by error (such as loop sensor)
Why add mask to MAPE and MARE?
    Ignore very small values (e.g., 0.5/0.5=100%)
'''
import numpy as np
import torch

def _mask_torch(pred, true, mask_value):
    '''Keep only entries whose ground truth is > mask_value (None = no masking).'''
    if mask_value is not None:
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    return pred, true

def _mask_np(pred, true, mask_value):
    '''numpy twin of _mask_torch.'''
    if mask_value is not None:
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    return pred, true

def MAE_torch(pred, true, mask_value=None):
    pred, true = _mask_torch(pred, true, mask_value)
    return torch.mean(torch.abs(true - pred))

def MSE_torch(pred, true, mask_value=None):
    pred, true = _mask_torch(pred, true, mask_value)
    return torch.mean((pred - true) ** 2)

def RMSE_torch(pred, true, mask_value=None):
    pred, true = _mask_torch(pred, true, mask_value)
    return torch.sqrt(torch.mean((pred - true) ** 2))

def RRSE_torch(pred, true, mask_value=None):
    '''Root Relative Squared Error.'''
    pred, true = _mask_torch(pred, true, mask_value)
    # BUG FIX: the denominator previously used (pred - true.mean()); RRSE is
    # defined against the deviation of the *truth* from its mean, which is
    # also what RRSE_np below computes.
    return torch.sqrt(torch.sum((pred - true) ** 2)) / torch.sqrt(torch.sum((true - true.mean()) ** 2))

def CORR_torch(pred, true, mask_value=None):
    '''Mean per-node Pearson correlation (mask_value kept for signature
    symmetry with the other metrics but unused here).'''
    #input B, T, N, D or B, N, D or B, N
    if len(pred.shape) == 2:
        pred = pred.unsqueeze(dim=1).unsqueeze(dim=1)
        true = true.unsqueeze(dim=1).unsqueeze(dim=1)
    elif len(pred.shape) == 3:
        pred = pred.transpose(1, 2).unsqueeze(dim=1)
        true = true.transpose(1, 2).unsqueeze(dim=1)
    elif len(pred.shape) == 4:
        #B, T, N, D -> B, T, D, N
        pred = pred.transpose(2, 3)
        true = true.transpose(2, 3)
    else:
        raise ValueError
    dims = (0, 1, 2)
    pred_mean = pred.mean(dim=dims)
    true_mean = true.mean(dim=dims)
    # NOTE(review): torch.std is unbiased (ddof=1) while np.std below is
    # population std -- the torch and numpy CORR can differ slightly.
    pred_std = pred.std(dim=dims)
    true_std = true.std(dim=dims)
    correlation = ((pred - pred_mean)*(true - true_mean)).mean(dim=dims) / (pred_std*true_std)
    index = (true_std != 0)    # skip constant nodes to avoid division by zero
    correlation = (correlation[index]).mean()
    return correlation


def MAPE_torch(pred, true, mask_value=None):
    pred, true = _mask_torch(pred, true, mask_value)
    return torch.mean(torch.abs(torch.div((true - pred), true)))

def PNBI_torch(pred, true, mask_value=None):
    '''Fraction of over-predictions (pred > true).'''
    pred, true = _mask_torch(pred, true, mask_value)
    indicator = torch.gt(pred - true, 0).float()
    return indicator.mean()

def oPNBI_torch(pred, true, mask_value=None):
    pred, true = _mask_torch(pred, true, mask_value)
    bias = (true + pred) / (2 * true)
    return bias.mean()

def MARE_torch(pred, true, mask_value=None):
    pred, true = _mask_torch(pred, true, mask_value)
    return torch.div(torch.sum(torch.abs((true - pred))), torch.sum(true))

def SMAPE_torch(pred, true, mask_value=None):
    pred, true = _mask_torch(pred, true, mask_value)
    return torch.mean(torch.abs(true - pred) / (torch.abs(true) + torch.abs(pred)))


def MAE_np(pred, true, mask_value=None):
    pred, true = _mask_np(pred, true, mask_value)
    MAE = np.mean(np.absolute(pred - true))
    return MAE

def RMSE_np(pred, true, mask_value=None):
    pred, true = _mask_np(pred, true, mask_value)
    RMSE = np.sqrt(np.mean(np.square(pred - true)))
    return RMSE

#Root Relative Squared Error
def RRSE_np(pred, true, mask_value=None):
    pred, true = _mask_np(pred, true, mask_value)
    mean = true.mean()
    return np.divide(np.sqrt(np.sum((pred - true) ** 2)), np.sqrt(np.sum((true - mean) ** 2)))

def MAPE_np(pred, true, mask_value=None):
    pred, true = _mask_np(pred, true, mask_value)
    return np.mean(np.absolute(np.divide((true - pred), true)))

def PNBI_np(pred, true, mask_value=None):
    #if PNBI=0, all pred are smaller than true
    #if PNBI=1, all pred are bigger than true
    pred, true = _mask_np(pred, true, mask_value)
    bias = pred - true
    indicator = np.where(bias > 0, True, False)
    return indicator.mean()

def oPNBI_np(pred, true, mask_value=None):
    #if oPNBI>1, pred are bigger than true
    #if oPNBI<1, pred are smaller than true
    #however, this metric is too sensitive to small values. Not good!
    pred, true = _mask_np(pred, true, mask_value)
    bias = (true + pred) / (2 * true)
    return bias.mean()

def MARE_np(pred, true, mask_value=None):
    pred, true = _mask_np(pred, true, mask_value)
    return np.divide(np.sum(np.absolute((true - pred))), np.sum(true))

def CORR_np(pred, true, mask_value=None):
    '''numpy twin of CORR_torch (mask_value unused, kept for symmetry).'''
    #input B, T, N, D or B, N, D or B, N
    if len(pred.shape) == 2:
        #B, N
        # BUG FIX: the original called torch's .unsqueeze on numpy arrays,
        # which raised AttributeError.
        pred = np.expand_dims(np.expand_dims(pred, axis=1), axis=1)
        true = np.expand_dims(np.expand_dims(true, axis=1), axis=1)
    elif len(pred.shape) == 3:
        #np.transpose includes permute, B, T, N
        pred = np.expand_dims(pred.transpose(0, 2, 1), axis=1)
        true = np.expand_dims(true.transpose(0, 2, 1), axis=1)
    elif len(pred.shape) == 4:
        #B, T, N, D -> B, T, D, N
        # BUG FIX: transpose(0, 1, 2, 3) was a no-op; swap the last two axes
        # as the comment (and CORR_torch) intend.
        pred = pred.transpose(0, 1, 3, 2)
        true = true.transpose(0, 1, 3, 2)
    else:
        raise ValueError
    dims = (0, 1, 2)
    pred_mean = pred.mean(axis=dims)
    true_mean = true.mean(axis=dims)
    pred_std = pred.std(axis=dims)
    true_std = true.std(axis=dims)
    correlation = ((pred - pred_mean)*(true - true_mean)).mean(axis=dims) / (pred_std*true_std)
    index = (true_std != 0)
    correlation = (correlation[index]).mean()
    return correlation

def All_Metrics(pred, true, mask1, mask2):
    '''Return (mae, rmse, mape, rrse, corr); mask1 filters very small values,
    mask2 filters values lower than a defined threshold (for MAPE).'''
    assert type(pred) == type(true)
    if type(pred) == np.ndarray:
        mae = MAE_np(pred, true, mask1)
        rmse = RMSE_np(pred, true, mask1)
        mape = MAPE_np(pred, true, mask2)
        rrse = RRSE_np(pred, true, mask1)
        corr = 0
        #corr = CORR_np(pred, true, mask1)
        #pnbi = PNBI_np(pred, true, mask1)
        #opnbi = oPNBI_np(pred, true, mask2)
    elif type(pred) == torch.Tensor:
        mae = MAE_torch(pred, true, mask1)
        rmse = RMSE_torch(pred, true, mask1)
        mape = MAPE_torch(pred, true, mask2)
        rrse = RRSE_torch(pred, true, mask1)
        corr = CORR_torch(pred, true, mask1)
        #pnbi = PNBI_torch(pred, true, mask1)
        #opnbi = oPNBI_torch(pred, true, mask2)
    else:
        raise TypeError
    return mae, rmse, mape, rrse, corr

def SIGIR_Metrics(pred, true, mask1, mask2):
    rrse = RRSE_torch(pred, true, mask1)
    corr = CORR_torch(pred, true, 0)
    return rrse, corr

if __name__ == '__main__':
    # BUG FIX: the original demo used 1-D tensors, which made CORR_torch
    # raise ValueError (it only accepts 2-D/3-D/4-D input).
    pred = torch.Tensor([[1, 2], [3, 4]])
    true = torch.Tensor([[2, 1], [4, 5]])
    print(All_Metrics(pred, true, None, None))

# -----------------------------------------------------------------------------
# /lib/normalization.py
# -----------------------------------------------------------------------------
import numpy as np
import torch


class NScaler(object):
    '''Identity scaler (no normalization).'''
    def transform(self, data):
        return data
    def inverse_transform(self, data):
        return data

class StandardScaler:
    """
    Standardize the input: (x - mean) / std
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        # Lazily convert numpy statistics to tensors on the first tensor input.
        if type(data) == torch.Tensor and type(self.mean) == np.ndarray:
            self.std = torch.from_numpy(self.std).to(data.device).type(data.dtype)
            self.mean = torch.from_numpy(self.mean).to(data.device).type(data.dtype)
        return (data * self.std) + self.mean

class MinMax01Scaler:
    """
    Scale the input to [0, 1]
    """

    def __init__(self, min, max):
        self.min = min
        self.max = max

    def transform(self, data):
        return (data - self.min) / (self.max - self.min)

    def inverse_transform(self, data):
        if type(data) == torch.Tensor and type(self.min) == np.ndarray:
            self.min = torch.from_numpy(self.min).to(data.device).type(data.dtype)
            self.max = torch.from_numpy(self.max).to(data.device).type(data.dtype)
        return (data * (self.max - self.min) + self.min)

class MinMax11Scaler:
    """
    Scale the input to [-1, 1]
    """

    def __init__(self, min, max):
        self.min = min
        self.max = max

    def transform(self, data):
        return ((data - self.min) / (self.max - self.min)) * 2. - 1.

    def inverse_transform(self, data):
        if type(data) == torch.Tensor and type(self.min) == np.ndarray:
            self.min = torch.from_numpy(self.min).to(data.device).type(data.dtype)
            self.max = torch.from_numpy(self.max).to(data.device).type(data.dtype)
        return ((data + 1.) / 2.) * (self.max - self.min) + self.min

class ColumnMinMaxScaler():
    #Note: to use this scale, must init the min and max with column min and column max
    def __init__(self, min, max):
        self.min = min
        self.min_max = max - self.min
        self.min_max[self.min_max == 0] = 1   # guard constant columns against /0
    def transform(self, data):
        print(data.shape, self.min_max.shape)
        return (data - self.min) / self.min_max

    def inverse_transform(self, data):
        if type(data) == torch.Tensor and type(self.min) == np.ndarray:
            self.min_max = torch.from_numpy(self.min_max).to(data.device).type(torch.float32)
            self.min = torch.from_numpy(self.min).to(data.device).type(torch.float32)
        #print(data.dtype, self.min_max.dtype, self.min.dtype)
        return (data * self.min_max + self.min)

def one_hot_by_column(data):
    '''One-hot encode each integer column of a 2-D array, hstacked together.'''
    n_rows = data.shape[0]
    for i in range(data.shape[1]):
        column = data[:, i]
        col_max = column.max()
        col_min = column.min()
        zero_matrix = np.zeros((n_rows, col_max - col_min + 1))
        zero_matrix[np.arange(n_rows), column - col_min] = 1
        if i == 0:
            encoded = zero_matrix
        else:
            encoded = np.hstack((encoded, zero_matrix))
    return encoded


def minmax_by_column(data):
    '''Min-max scale each column of a 2-D array independently.'''
    for i in range(data.shape[1]):
        column = data[:, i]
        col_max = column.max()
        col_min = column.min()
        column = (column - col_min) / (col_max - col_min)
        column = column[:, np.newaxis]
        if i == 0:
            _normalized = column
        else:
            _normalized = np.hstack((_normalized, column))
    return _normalized


if __name__ == '__main__':

    test_data = np.array([[0, 0, 0, 1], [0, 1, 3, 2], [0, 2, 1, 3]])
    print(test_data)
    minimum = test_data.min(axis=1)
    print(minimum, minimum.shape, test_data.shape)
    maximum = test_data.max(axis=1)
    print(maximum)
    print(test_data - minimum)
    test_data = (test_data - minimum) / (maximum - minimum)
    print(test_data)
    print(0 == 0)
    print(0.00 == 0)
    print(0 == 0.00)
    #print(one_hot_by_column(test_data))
    #print(minmax_by_column(test_data))

# -----------------------------------------------------------------------------
# /lib/utils.py
# -----------------------------------------------------------------------------
import torch

###---Enable Dropout---###
def enable_dropout(model):
    """ Function to enable the dropout layers during test-time
    (for MC-dropout sampling); other layers keep their eval() state. """
    for m in model.modules():
        if m.__class__.__name__.startswith('Dropout'):
            m.train()

###---Save model---###
def save_model_(model, model_name, dataset, pre_len):
    '''Save model weights to check_points/{model_name}_{dataset}_{pre_len}.pth.'''
    torch.save({'model_state_dict': model.state_dict(),
                }, f"check_points/{model_name}_{dataset}_{str(pre_len)}.pth")
    print("Model saved!")

###---Load model---###
def load_model_(model, model_name, dataset, pre_len):
    '''Load weights saved by save_model_ into `model` and return it.'''
    PATH1 = f"check_points/{model_name}_{dataset}_{str(pre_len)}.pth"
    checkpoint = torch.load(PATH1)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded!")
    return model
# -----------------------------------------------------------------------------
# /main.py
# -----------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8

import os
import sys

import torch
import numpy as np
import torch.nn as nn
import argparse
import configparser
from datetime import datetime
from model.AGCRN_UQ import AGCRN_UQ as Network
from model.BasicTrainer import Trainer
from lib.TrainInits import init_seed
from lib.dataloader import get_dataloader
from lib.TrainInits import print_model_parameters
from tqdm import tqdm
from copy import deepcopy
import torch.nn.functional as F
import torchcontrib
from model.train_methods import awa_train_combined, train_cali_mcx
from model.test_methods import combined_test


#*************************************************************************#
Mode = 'train'
DEBUG = 'True'
DATASET = 'PEMS08'          #PEMS04/8/3/7
DEVICE = 'cuda:0'
MODEL = 'AGCRN'
MODEL_NAME = "combined"     #"combined/basic/dropout/heter"

#get configuration
config_file = 'model/{}_{}.conf'.format(DATASET, MODEL)
#print('Read configuration file: %s' % (config_file))
config = configparser.ConfigParser()
config.read(config_file)

from lib.utils import enable_dropout, save_model_, load_model_
from lib.metrics import All_Metrics


from lib.metrics import MAE_torch
def masked_mae_loss(scaler, mask_value):
    '''Build an MAE loss that inverse-transforms predictions/labels first
    and masks out entries whose ground truth is <= mask_value.'''
    def loss(preds, labels):
        if scaler:
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        mae = MAE_torch(pred=preds, true=labels, mask_value=mask_value)
        return mae
    return loss

#parser
args = argparse.ArgumentParser(description='arguments')
args.add_argument('--dataset', default=DATASET, type=str)
args.add_argument('--mode', default=Mode, type=str)
args.add_argument('--device', default=DEVICE, type=str, help='indices of GPUs')
args.add_argument('--debug', default=DEBUG, type=eval)
args.add_argument('--model', default=MODEL, type=str)
args.add_argument('--cuda', default=True, type=bool)
#data
args.add_argument('--val_ratio', default=config['data']['val_ratio'], type=float)
args.add_argument('--test_ratio', default=config['data']['test_ratio'], type=float)
args.add_argument('--lag', default=config['data']['lag'], type=int)
args.add_argument('--horizon', default=config['data']['horizon'], type=int)
args.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int)
args.add_argument('--tod', default=config['data']['tod'], type=eval)
args.add_argument('--normalizer', default=config['data']['normalizer'], type=str)
args.add_argument('--column_wise', default=config['data']['column_wise'], type=eval)
args.add_argument('--default_graph', default=config['data']['default_graph'], type=eval)
#model
args.add_argument('--input_dim', default=config['model']['input_dim'], type=int)
args.add_argument('--output_dim', default=config['model']['output_dim'], type=int)
args.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int)
args.add_argument('--rnn_units', default=config['model']['rnn_units'], type=int)
args.add_argument('--num_layers', default=config['model']['num_layers'], type=int)
args.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int)
args.add_argument('--p1', default=config['model']['p1'], type=float)

#train
args.add_argument('--loss_func', default=config['train']['loss_func'], type=str)
#args.add_argument('--loss_func', default='mse', type=str)
args.add_argument('--seed', default=config['train']['seed'], type=int)
args.add_argument('--batch_size', default=config['train']['batch_size'], type=int)
args.add_argument('--epochs', default=config['train']['epochs'], type=int)
#args.add_argument('--epochs', default=500, type=int)
args.add_argument('--lr_init', default=config['train']['lr_init'], type=float)
#args.add_argument('--lr_init', default=1e-2, type=float)
args.add_argument('--lr_decay', default=config['train']['lr_decay'], type=eval)
args.add_argument('--lr_decay_rate', default=config['train']['lr_decay_rate'], type=float)
args.add_argument('--lr_decay_step', default=config['train']['lr_decay_step'], type=str)
args.add_argument('--early_stop', default=config['train']['early_stop'], type=eval)
args.add_argument('--early_stop_patience', default=config['train']['early_stop_patience'], type=int)
args.add_argument('--grad_norm', default=config['train']['grad_norm'], type=eval)
args.add_argument('--max_grad_norm', default=config['train']['max_grad_norm'], type=int)
args.add_argument('--teacher_forcing', default=False, type=bool)
args.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps')
args.add_argument('--real_value', default=config['train']['real_value'], type=eval, help='use real value for loss calculation')
#test
args.add_argument('--mae_thresh', default=config['test']['mae_thresh'], type=eval)
args.add_argument('--mape_thresh', default=config['test']['mape_thresh'], type=float)
#log
args.add_argument('--log_dir', default='./', type=str)
args.add_argument('--log_step', default=config['log']['log_step'], type=int)
args.add_argument('--plot', default=config['log']['plot'], type=eval)
args.add_argument('--model_name', default=MODEL_NAME, type=str)


# parse_args([]) deliberately ignores the CLI (notebook-style usage): every
# value comes from the defaults / config file above.
args = args.parse_args([])
init_seed(args.seed)
if torch.cuda.is_available():
    # Parse the index out of strings like 'cuda:0' (robust to indices >= 10,
    # unlike the original `args.device[5]`).
    torch.cuda.set_device(int(args.device.split(':')[1]))
else:
    args.device = 'cpu'

#init model
model = Network(args).to(args.device)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    else:
        nn.init.uniform_(p)
print_model_parameters(model, only_num=False)

#load dataset
train_loader, val_loader, test_loader, scaler = get_dataloader(args,
                                                               normalizer=args.normalizer,
                                                               tod=args.tod, dow=False,
                                                               weather=False, single=False)
#init loss function, optimizer
if args.loss_func == 'mask_mae':
    loss = masked_mae_loss(scaler, mask_value=0.0)
elif args.loss_func == 'mae':
    loss = torch.nn.L1Loss().to(args.device)
elif args.loss_func == 'mse':
    loss = torch.nn.MSELoss().to(args.device)
else:
    raise ValueError

optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8,
                             weight_decay=1e-6, amsgrad=False)

#learning rate decay
lr_scheduler = None
if args.lr_decay:
    print('Applying learning rate decay.')
    lr_decay_steps = [int(i) for i in list(args.lr_decay_step.split(','))]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                        milestones=lr_decay_steps,
                                                        gamma=args.lr_decay_rate)
#start training
trainer = Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
                  args, lr_scheduler=lr_scheduler)


### Pre-training
trainer.train()


### AWA Re-training

#trainer.model = awa_train_combined(trainer, epochs=20)


### Model Calibration

# BUG FIX: this previously called `train_cali_mc`, which is never imported or
# defined anywhere in this file and raised NameError at runtime; the import at
# the top brings in `train_cali_mcx`.
# NOTE(review): confirm against model/train_methods.py that `train_cali_mcx`
# is the intended calibration entry point (and not a typo in the import).
T = train_cali_mcx(trainer.model, 10, args, val_loader, scaler)
combined_test(trainer.model, 10, trainer.args, trainer.test_loader, scaler, T)

# -----------------------------------------------------------------------------
# /model/AGCN.py
# -----------------------------------------------------------------------------
import torch
# --- model/AGCN.py: AVWGCN, adaptive graph convolution ---
# The support (adjacency) matrix is learned from node embeddings,
# softmax(relu(E @ E^T)), and expanded with a Chebyshev-style recursion up to
# cheb_k terms; per-node conv weights/bias are generated from the embedding
# pools via einsum.
# NOTE(review): F.dropout below is called without training=self.training, so
# dropout stays active at eval time — presumably intentional, to allow
# MC-dropout sampling in "dropout"/"combined" modes; confirm.
2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | class AVWGCN(nn.Module): 6 | def __init__(self, model_name,p1,dim_in, dim_out, cheb_k, embed_dim): 7 | super(AVWGCN, self).__init__() 8 | self.cheb_k = cheb_k 9 | self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) 10 | self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) 11 | 12 | self.model_name = model_name 13 | self.p1= p1 14 | 15 | def forward(self, x, node_embeddings): 16 | #x shaped[B, N, C], node_embeddings shaped [N, D] -> supports shaped [N, N] 17 | #output shape [B, N, C] 18 | node_num = node_embeddings.shape[0] 19 | supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1) 20 | support_set = [torch.eye(node_num).to(supports.device), supports] 21 | #default cheb_k = 3 22 | for k in range(2, self.cheb_k): 23 | support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2]) 24 | supports = torch.stack(support_set, dim=0) 25 | 26 | 27 | weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool) #N, cheb_k, dim_in, dim_out 28 | bias = torch.matmul(node_embeddings, self.bias_pool) #N, dim_out 29 | x_g = torch.einsum("knm,bmc->bknc", supports, x) #B, cheb_k, N, dim_in 30 | x_g = x_g.permute(0, 2, 1, 3) # B, N, cheb_k, dim_in 31 | 32 | 33 | x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias #b, N, dim_out 34 | ### add dropout 35 | 36 | if self.model_name == "dropout" or self.model_name == "combined": 37 | x_gconv = F.dropout(x_gconv, self.p1) #pems04=0.1/pems08=0.05 38 | 39 | return x_gconv 40 | 41 | 42 | 43 | 
# --- model/AGCRNCell.py: GRU-style recurrent cell whose gate and candidate
# transforms are AVWGCN graph convolutions. ---
-------------------------------------------------------------------------------- /model/AGCRNCell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from model.AGCN import AVWGCN 4 | 5 | class AGCRNCell(nn.Module): 6 | def __init__(self, model_name, p1, node_num, dim_in,
dim_out, cheb_k, embed_dim): 7 | super(AGCRNCell, self).__init__() 8 | self.node_num = node_num 9 | self.hidden_dim = dim_out 10 | self.gate = AVWGCN(model_name, p1, dim_in+self.hidden_dim, 2*dim_out, cheb_k, embed_dim) 11 | self.update = AVWGCN(model_name, p1, dim_in+self.hidden_dim, dim_out, cheb_k, embed_dim) 12 | 13 | def forward(self, x, state, node_embeddings): 14 | #x: B, num_nodes, input_dim 15 | #state: B, num_nodes, hidden_dim 16 | state = state.to(x.device) 17 | input_and_state = torch.cat((x, state), dim=-1) 18 | z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) 19 | z, r = torch.split(z_r, self.hidden_dim, dim=-1) 20 | candidate = torch.cat((x, z*state), dim=-1) 21 | hc = torch.tanh(self.update(candidate, node_embeddings)) 22 | h = r*state + (1-r)*hc 23 | 24 | #print(x.shape, state.shape, input_and_state.shape) 25 | return h 26 | 27 | def init_hidden_state(self, batch_size): 28 | return torch.zeros(batch_size, self.node_num, self.hidden_dim) 
# --- model/AGCRN_UQ.py (head): AVWDCRNN, stacked-AGCRNCell encoder; the
# constructor continues in the next chunk of this file. ---
-------------------------------------------------------------------------------- /model/AGCRN_UQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as td 5 | from model.AGCRNCell import AGCRNCell 6 | 7 | class AVWDCRNN(nn.Module): 8 | def __init__(self, model_name, p1,node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): 9 | super(AVWDCRNN, self).__init__() 10 | assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
# AVWDCRNN (cont.): layer 0 maps dim_in->dim_out, deeper layers
# dim_out->dim_out. forward() unrolls each layer over the time axis and
# returns (all-step outputs of the last layer, final state per layer).
11 | self.node_num = node_num 12 | self.input_dim = dim_in 13 | self.num_layers = num_layers 14 | self.dcrnn_cells = nn.ModuleList() 15 | self.dcrnn_cells.append(AGCRNCell(model_name, p1,node_num, dim_in, dim_out, cheb_k, embed_dim)) 16 | for _ in range(1, num_layers): 17 | self.dcrnn_cells.append(AGCRNCell(model_name, p1,node_num, dim_out, dim_out, cheb_k, embed_dim)) 18 | 19 | def forward(self, x, init_state, node_embeddings): 20 | #shape of x: (B, T, N, D) 21 | #shape of init_state: (num_layers, B, N, hidden_dim) 22 | assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim 23 | seq_length = x.shape[1] 24 | current_inputs = x 25 | output_hidden = [] 26 | for i in range(self.num_layers): 27 | state = init_state[i] 28 | inner_states = [] 29 | for t in range(seq_length): 30 | state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings) 31 | inner_states.append(state) 32 | output_hidden.append(state) 33 | current_inputs = torch.stack(inner_states, dim=1) 34 | #current_inputs: the outputs of last layer: (B, T, N, hidden_dim) 35 | #output_hidden: the last state for each layer: (num_layers, B, N, hidden_dim) 36 | #last_state: (B, N, hidden_dim) 37 | return current_inputs, output_hidden 38 | 39 | def init_hidden(self, batch_size): 40 | init_states = [] 41 | for i in range(self.num_layers): 42 | init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size)) 43 | return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim) 44 | 45 | 
# AGCRN_UQ: main UQ model — copies hyper-parameters off the parsed `args`
# namespace; model_name selects the uncertainty variant, p1 the MC-dropout
# rate used inside AVWGCN.
46 | ###========Main model========= 47 | class AGCRN_UQ(nn.Module): 48 | def __init__(self, args): 49 | super(AGCRN_UQ, self).__init__() 50 | self.num_node = args.num_nodes 51 | self.input_dim = args.input_dim 52 | self.hidden_dim = args.rnn_units 53 | self.output_dim = args.output_dim 54 | self.horizon = args.horizon 55 | self.num_layers = args.num_layers 56 | 57 | ### 58 | self.model_name = args.model_name 59 | self.p1= args.p1 60 | 61 | 62 | self.default_graph = args.default_graph 63 |
self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True) 64 | 65 | 66 | self.encoder = AVWDCRNN(self.model_name, self.p1, args.num_nodes, args.input_dim, args.rnn_units, args.cheb_k, 67 | args.embed_dim, args.num_layers) 68 | 69 | 
# "combined" mode attaches two CNN heads over the last encoder step: one for
# the predictive mean (mu) and one for the log-variance.
70 | if self.model_name == "combined": 71 | self.get_mu = nn.Sequential( 72 | nn.Conv2d(1, 32, kernel_size=(1,1), bias=True), 73 | nn.Dropout(0.2), 74 | nn.ReLU(), 75 | nn.Conv2d(32, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True), 76 | ) 77 | 78 | self.get_log_var = nn.Sequential( 79 | nn.Conv2d(1, 32, kernel_size=(1,1), bias=True), 80 | nn.Dropout(0.2), 81 | nn.ReLU(), 82 | nn.Conv2d(32, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True), 83 | ) 84 | 85 | 
# forward(): encodes `source` with AVWDCRNN, keeps only the last time step
# and, in "combined" mode, returns (mu, log_var), each reshaped/permuted to
# (B, horizon, N, C).
# NOTE(review): when model_name != "combined" this forward falls through and
# implicitly returns None — callers must only instantiate this class in
# combined mode; confirm.
86 | def forward(self, source, targets, teacher_forcing_ratio=0.5): 87 | #source: B, T_1, N, D 88 | #target: B, T_2, N, D 89 | #supports = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec1.transpose(0,1))), dim=1) 90 | 91 | emb = self.node_embeddings 92 | init_state = self.encoder.init_hidden(source.shape[0]) 93 | output, _ = self.encoder(source, init_state, emb) #B, T, N, hidden 94 | #print(output.shape) 95 | output = output[:, -1:, :, :] 96 | 97 | #CNN based predictor 98 | 99 | if self.model_name == "combined": 100 | mu = self.get_mu((output)) #B, T*C, N, 1 101 | mu = mu.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node) 102 | mu = mu.permute(0, 1, 3, 2) #B, T, N, C 103 | 104 | log_var = self.get_log_var((output)) #B, T*C, N, 1 105 | log_var = log_var.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node) 106 | log_var = log_var.permute(0, 1, 3, 2) 107 | return mu, log_var 108 | 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /model/BasicTrainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import os 4 |
import time 5 | import copy 6 | import numpy as np 7 | #from lib.logger import get_logger 8 | from lib.metrics import All_Metrics 9 | from lib.utils import enable_dropout 10 | import torch.nn.functional as F 11 | from torch.autograd import grad 12 | import torch.nn as nn 13 | 14 | from torch.optim.swa_utils import AveragedModel, SWALR 15 | from torch.optim.lr_scheduler import CosineAnnealingLR 16 | #from model.train_methods import QuantileLoss 17 | 18 | 
# Trainer: pre-training loop for AGCRN_UQ — holds model/loss/optimizer and
# the three loaders, runs train_epoch + val_epoch with early stopping.
# NOTE(review): the "quantile" branches reference self.quantile_loss, but its
# assignment in __init__ is commented out below, so model_name == "quantile"
# would raise AttributeError; confirm quantile mode is unused.
19 | class Trainer(object): 20 | def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, 21 | scaler, args, lr_scheduler=None): 22 | super().__init__() 23 | self.model = model 24 | self.loss = loss 25 | self.optimizer = optimizer 26 | self.train_loader = train_loader 27 | self.val_loader = val_loader 28 | self.test_loader = test_loader 29 | self.scaler = scaler 30 | self.args = args 31 | self.lr_scheduler = lr_scheduler 32 | self.train_per_epoch = len(train_loader) 33 | #self.quantile_loss = QuantileLoss() 34 | 
# val_epoch: no-grad evaluation; dispatches on model_name (heter/combined
# models return (mu, log_var) — only mu is scored against the label).
35 | def val_epoch(self, epoch, val_dataloader): 36 | self.model.eval() 37 | total_val_loss = 0 38 | 39 | with torch.no_grad(): 40 | for batch_idx, (data, target) in enumerate(val_dataloader): 41 | data = data[..., :self.args.input_dim] 42 | label = target[..., :self.args.output_dim] 43 | if self.args.real_value: 44 | label = self.scaler.inverse_transform(label) 45 | 46 | if self.args.model_name == "basic" or self.args.model_name=="dropout": 47 | output = self.model(data, target, teacher_forcing_ratio=0.) 48 | loss = self.loss(output.cuda(), label) 49 | 50 | elif self.args.model_name == "heter" or self.args.model_name=="combined": 51 | output, _ = self.model(data, target, teacher_forcing_ratio=0.) 52 | loss = self.loss(output.cuda(), label) 53 | 54 | elif self.args.model_name == "quantile": 55 | output = self.model(data, target, teacher_forcing_ratio=0.)
56 | loss = self.quantile_loss(output,label) 57 | 58 | #a whole batch of Metr_LA is filtered 59 | if not torch.isnan(loss): 60 | total_val_loss += loss.item() 61 | val_loss = total_val_loss / len(val_dataloader) 62 | print('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss)) 63 | return val_loss 64 | 
# train_epoch: one optimisation pass over the training loader.
# NOTE(review): `self._compute_sampling_threshold` is referenced below but is
# not defined anywhere in this class as shown — only reachable when
# args.teacher_forcing is truthy; confirm that flag is never set.
# NOTE(review): for basic/dropout modes the *time* axis of the input is
# shuffled (data[:, randperm(T)]) — presumably a deliberate augmentation,
# but it destroys temporal order; confirm.
65 | def train_epoch(self, epoch): 66 | self.model.train() 67 | total_loss = 0 68 | for batch_idx, (data, target) in enumerate(self.train_loader): 69 | data = data[..., :self.args.input_dim] 70 | label = target[..., :self.args.output_dim] # (..., 1) 71 | 72 | #label = torch.log(label) 73 | 74 | self.optimizer.zero_grad() 75 | 76 | #teacher_forcing for RNN encoder-decoder model 77 | #if teacher_forcing_ratio = 1: use label as input in the decoder for all steps 78 | if self.args.teacher_forcing: 79 | global_step = (epoch - 1) * self.train_per_epoch + batch_idx 80 | teacher_forcing_ratio = self._compute_sampling_threshold(global_step, self.args.tf_decay_steps) 81 | else: 82 | teacher_forcing_ratio = 1. 83 | #data and target shape: B, T, N, F; output shape: B, T, N, F 84 | 85 | if self.args.real_value: 86 | label = self.scaler.inverse_transform(label) 87 | 88 | if self.args.model_name == "basic" or self.args.model_name=="dropout": 89 | index = torch.randperm(data.shape[1]) 90 | data = data[:,index,:,:] 91 | 92 | output = self.model(data, target, teacher_forcing_ratio=0.)
93 | loss = self.loss(output.cuda(), label) 94 | 
# heter/combined: weighted sum of the heteroscedastic Gaussian NLL,
# mean(exp(-log_var)*(y-mu)^2 + log_var), and the base loss on mu (0.1/0.9).
95 | elif self.args.model_name == "heter" or self.args.model_name=="combined": 96 | mu, log_var = self.model(data, target, teacher_forcing_ratio=teacher_forcing_ratio) 97 | loss = torch.mean(torch.exp(-log_var)*(label-mu)**2 + log_var) 98 | loss = 0.1*loss + 0.9*self.loss(mu, label) 99 | 100 | elif self.args.model_name == "quantile": 101 | output = self.model(data, target, teacher_forcing_ratio=teacher_forcing_ratio) 102 | loss = self.quantile_loss(output,label) 103 | 104 | loss.backward() 105 | 106 | # add max grad clipping 107 | if self.args.grad_norm: 108 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 109 | self.optimizer.step() 110 | total_loss += loss.item() 111 | 112 | #log information 113 | if batch_idx % self.args.log_step == 0: 114 | print('Train Epoch {}: {}/{} Loss: {:.6f}'.format( 115 | epoch, batch_idx, self.train_per_epoch, loss.item())) 116 | train_epoch_loss = total_loss/self.train_per_epoch 117 | #self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}, tf_ratio: {:.6f}'.format(epoch, train_epoch_loss, teacher_forcing_ratio)) 118 | 119 | #learning rate decay 120 | if self.args.lr_decay: 121 | self.lr_scheduler.step() 122 | return train_epoch_loss 123 | 
# train(): epoch loop with best-on-validation checkpointing, early stopping,
# and a divergence guard; falls back to the test loader for validation when
# no val loader exists.
124 | def train(self): 125 | 126 | best_model = None 127 | best_loss = float('inf') 128 | not_improved_count = 0 129 | train_loss_list = [] 130 | val_loss_list = [] 131 | start_time = time.time() 132 | for epoch in range(1, self.args.epochs + 1): 133 | #epoch_time = time.time() 134 | train_epoch_loss = self.train_epoch(epoch) 135 | #print(time.time()-epoch_time) 136 | #exit() 137 | if self.val_loader == None: 138 | val_dataloader = self.test_loader 139 | else: 140 | val_dataloader = self.val_loader 141 | val_epoch_loss = self.val_epoch(epoch, val_dataloader) 142 | 143 | #print('LR:', self.optimizer.param_groups[0]['lr']) 144 | train_loss_list.append(train_epoch_loss) 145 |
val_loss_list.append(val_epoch_loss) 146 | if train_epoch_loss > 1e6: 147 | print('Gradient explosion detected. Ending...') 148 | break 149 | #if self.val_loader == None: 150 | #val_epoch_loss = train_epoch_loss 151 | if val_epoch_loss < best_loss: 152 | best_loss = val_epoch_loss 153 | not_improved_count = 0 154 | best_state = True 155 | else: 156 | not_improved_count += 1 157 | best_state = False 158 | # early stop 159 | if self.args.early_stop: 160 | if not_improved_count == self.args.early_stop_patience: 161 | print("Validation performance didn\'t improve for {} epochs. " 162 | "Training stops.".format(self.args.early_stop_patience)) 163 | break 164 | # save the best state 165 | if best_state == True: 166 | print('*********************************Current best model saved!') 167 | best_model = copy.deepcopy(self.model.state_dict()) 168 | 169 | 170 | training_time = time.time() - start_time 171 | print("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss)) 172 | 173 | #save the best model to file 174 | if not self.args.debug: 175 | torch.save(best_model, self.best_path) 176 | print("Saving current best model to " + self.best_path) 177 | 178 | #test 179 | self.model.load_state_dict(best_model) 180 | #self.val_epoch(self.args.epochs, self.test_loader) 181 | #self.test(self.model, self.args, self.test_loader, self.scaler, self.logger=None) 182 | 183 | @staticmethod 184 | def test(model, args, data_loader, scaler, logger=None, path=None): 185 | if path != None: 186 | check_point = torch.load(path) 187 | state_dict = check_point['state_dict'] 188 | args = check_point['config'] 189 | model.load_state_dict(state_dict) 190 | model.to(args.device) 191 | 192 | model.eval() 193 | #enable_dropout(model) 194 | 195 | y_pred = [] 196 | y_true = [] 197 | with torch.no_grad(): 198 | for batch_idx, (data, target) in enumerate(data_loader): 199 | data = data[..., :args.input_dim] 200 | label = target[..., :args.output_dim] 201 | #output, z, mu, 
log_sigma = model.sample_code_(data) 202 | 203 | #label = torch.log(label) 204 | 
# NOTE(review): for "combined" models model(...) returns a (mu, log_var)
# tuple; this generic test() treats the output as a single tensor (the
# tuple-unpacking variant is commented out on the next line) — only valid
# for single-output models; confirm.
205 | output = model(data, target, teacher_forcing_ratio=0) 206 | #output, *_ = model(data, target, teacher_forcing_ratio=0) 207 | y_true.append(label) 208 | y_pred.append(output) 209 | y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) 210 | if args.real_value: 211 | y_pred = torch.cat(y_pred, dim=0) 212 | else: 213 | y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) 214 | #np.save('./{}_true.npy'.format(args.dataset), y_true.cpu().numpy()) 215 | #np.save('./{}_pred.npy'.format(args.dataset), y_pred.cpu().numpy()) 216 | for t in range(y_true.shape[1]): 217 | mae, rmse, mape, _, _ = All_Metrics(y_pred[:, t, ...], y_true[:, t, ...], 218 | args.mae_thresh, args.mape_thresh) 219 | print("Horizon {:02d}, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format( 220 | t + 1, mae, rmse, mape*100)) 221 | mae, rmse, mape, _, _ = All_Metrics(y_pred, y_true, args.mae_thresh, args.mape_thresh) 222 | print("Average Horizon, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format( 223 | mae, rmse, mape*100)) 224 | -------------------------------------------------------------------------------- /model/PEMS03_AGCRN.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 358 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.2 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 10 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | p1 = 0.1 20 | 21 | [train] 22 | loss_func = mae 23 | seed = 10 24 | batch_size = 64 25 | epochs = 100 26 | lr_init = 0.003 27 | lr_decay = False 28 | lr_decay_rate = 0.3 29 | lr_decay_step = 5,20,40,70 30 | early_stop = True 31 | early_stop_patience = 15 32 | grad_norm = False 33 | max_grad_norm = 5 34 | real_value = True 35 | 36 | [test] 37 | mae_thresh = None 38 |
mape_thresh = 0.001 39 | 40 | [log] 41 | log_step = 20 42 | plot = False -------------------------------------------------------------------------------- /model/PEMS04_AGCRN.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 307 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.2 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 10 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | p1 = 0.1 20 | 21 | [train] 22 | loss_func = mae 23 | seed = 10 24 | batch_size = 64 25 | epochs = 100 26 | lr_init = 0.003 27 | lr_decay = False 28 | lr_decay_rate = 0.3 29 | lr_decay_step = 5,20,40,70 30 | early_stop = True 31 | early_stop_patience = 15 32 | grad_norm = False 33 | max_grad_norm = 5 34 | real_value = True 35 | 36 | [test] 37 | mae_thresh = None 38 | mape_thresh = 0. 39 | 40 | [log] 41 | log_step = 20 42 | plot = False -------------------------------------------------------------------------------- /model/PEMS07_AGCRN.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 883 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.2 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 10 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | p1 = 0.1 20 | 21 | [train] 22 | loss_func = mae 23 | seed = 10 24 | batch_size = 32 25 | epochs = 100 26 | lr_init = 0.003 27 | lr_decay = False 28 | lr_decay_rate = 0.3 29 | lr_decay_step = 5,20,40,70 30 | early_stop = True 31 | early_stop_patience = 15 32 | grad_norm = False 33 | max_grad_norm = 5 34 | real_value = True 35 | 36 | [test] 37 | mae_thresh = None 38 | mape_thresh = 0. 
39 | 40 | [log] 41 | log_step = 20 42 | plot = False -------------------------------------------------------------------------------- /model/PEMS08_AGCRN.conf: -------------------------------------------------------------------------------- 1 | [data] 2 | num_nodes = 170 3 | lag = 12 4 | horizon = 12 5 | val_ratio = 0.2 6 | test_ratio = 0.2 7 | tod = False 8 | normalizer = std 9 | column_wise = False 10 | default_graph = True 11 | 12 | [model] 13 | input_dim = 1 14 | output_dim = 1 15 | embed_dim = 2 16 | rnn_units = 64 17 | num_layers = 2 18 | cheb_order = 2 19 | p1 = 0.05 20 | 21 | [train] 22 | loss_func = mae 23 | seed = 12 24 | batch_size = 64 25 | epochs = 100 26 | lr_init = 0.003 27 | lr_decay = False 28 | lr_decay_rate = 0.3 29 | lr_decay_step = 5,20,40,70 30 | early_stop = True 31 | early_stop_patience = 15 32 | grad_norm = False 33 | max_grad_norm = 5 34 | real_value = True 35 | 36 | [test] 37 | mae_thresh = None 38 | mape_thresh = 0. 39 | 40 | [log] 41 | log_step = 20 42 | plot = False -------------------------------------------------------------------------------- /model/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/test_methods.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.utils import enable_dropout 4 | from lib.metrics import All_Metrics 5 | from tqdm import tqdm 6 | 7 | ####======MC+Heter========#### 8 | def combined_test(model,num_samples,args, data_loader, scaler, T=torch.zeros(1).cuda(), logger=None, path=None): 9 | model.eval() 10 | enable_dropout(model) 11 | nll_fun = nn.GaussianNLLLoss() 12 | y_true = [] 13 | with torch.no_grad(): 14 | for batch_idx, (_, target) in enumerate(data_loader): 15 | label = target[..., :args.output_dim] 16 | y_true.append(label) 17 | y_true = scaler.inverse_transform(torch.cat(y_true, 
dim=0)).squeeze(3) 18 | 19 | mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda() 20 | mc_log_vars = torch.empty(0, y_true.size(0),y_true.size(1), y_true.size(2)).cuda() 21 | 22 | with torch.no_grad(): 23 | for i in tqdm(range(num_samples)): 24 | mu_pred = [] 25 | log_var_pred = [] 26 | for batch_idx, (data, _) in enumerate(data_loader): 27 | data = data[..., :args.input_dim] 28 | mu, log_var = model.forward(data, target, teacher_forcing_ratio=0) 29 | #print(mu.size()) 30 | mu_pred.append(mu.squeeze(3)) 31 | log_var_pred.append(log_var.squeeze(3)) 32 | 33 | if args.real_value: 34 | mu_pred = torch.cat(mu_pred, dim=0) 35 | else: 36 | mu_pred = scaler.inverse_transform(torch.cat(mu_pred, dim=0)) 37 | log_var_pred = torch.cat(log_var_pred, dim=0) 38 | 39 | #print(mc_mus.size(),mu_pred.size()) 40 | mc_mus = torch.vstack((mc_mus,mu_pred.unsqueeze(0))) 41 | mc_log_vars = torch.vstack((mc_log_vars,log_var_pred.unsqueeze(0))) 42 | 43 | temperature = torch.exp(T) 44 | y_pred = torch.mean(mc_mus, axis=0) 45 | total_var = torch.var(mc_mus, axis=0)+torch.exp(torch.mean(mc_log_vars, axis=0))/temperature 46 | total_std = total_var**0.5 47 | 48 | mpiw = 2*1.96*torch.mean(total_std) 49 | nll = nll_fun(y_pred.ravel(), y_true.ravel(), total_var.ravel()) 50 | lower_bound = y_pred-1.96*total_std 51 | upper_bound = y_pred+1.96*total_std 52 | in_num = torch.sum((y_true >= lower_bound)&(y_true <= upper_bound )) 53 | picp = in_num/(y_true.size(0)*y_true.size(1)*y_true.size(2)) 54 | 55 | 56 | print("Average Horizon, MAE: {:.4f}, RMSE: {:.4f}, MAPE: {:.4f}%, NLL: {:.4f}, \ 57 | PICP: {:.4f}%, MPIW: {:.4f}".format(mae, rmse, mape*100, nll, picp*100, mpiw)) 58 | -------------------------------------------------------------------------------- /model/train_methods.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.utils import enable_dropout 4 | from torch.optim.swa_utils import 
AveragedModel, SWALR 5 | from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts 6 | from tqdm import tqdm 7 | 8 | 9 | ###====Adaptive Weight Averaging====### 10 | def awa_train_combined(trainer,epoch_swa, regularizer=None, lr_schedule=None): 11 | #total_loss = 0 12 | #criterion = torch.nn.L1Loss()#torch.nn.MSELoss() 13 | #loss_sum = 0.0 14 | num_iters = len(trainer.train_loader) 15 | 16 | lr1 = 0.003 17 | lr2 = lr1*0.01 18 | 19 | optimizer_swa = torch.optim.Adam(params=model.parameters(), lr=lr1, betas=(0.9,0.99), 20 | weight_decay=1e-4, amsgrad=False) 21 | 22 | cycle = num_iters 23 | swa_c = 1 24 | swa_model = AveragedModel(trainer.model) 25 | #scheduler = CosineAnnealingLR(optimizer_swa, T_max=cycle,eta_min=0.0) 26 | scheduler_swa = CosineAnnealingWarmRestarts(optimizer_swa,T_0=num_iters, 27 | T_mult=1, eta_min=lr2, last_epoch=-1) 28 | 29 | lr_ls = [] 30 | for epoch in tqdm(range(epoch_swa)): 31 | trainer.model.train() 32 | for iter, (data, target) in enumerate(trainer.train_loader): 33 | input = data[..., :trainer.args.input_dim] 34 | label = target[..., :trainer.args.output_dim] # (..., 1) 35 | optimizer_swa.zero_grad() 36 | 37 | input = input.cuda(non_blocking=True) 38 | label = label.cuda(non_blocking=True) 39 | 40 | if trainer.args.teacher_forcing: 41 | global_step = (epoch - 1) * trainer.train_per_epoch + batch_idx 42 | teacher_forcing_ratio = trainer._compute_sampling_threshold(global_step, trainer.args.tf_decay_steps) 43 | else: 44 | teacher_forcing_ratio = 1. 
45 | #output,log_var = trainer.model.forward_heter(data, target, teacher_forcing_ratio=teacher_forcing_ratio) 46 | mu,log_var = trainer.model.forward(data, target, teacher_forcing_ratio=0.5) 47 | if trainer.args.real_value: 48 | label = trainer.scaler.inverse_transform(label) 49 | loss = torch.mean(torch.exp(-log_var)*(label-mu)**2 + log_var) 50 | loss = 0.1*loss + 0.9*trainer.loss(mu, label) 51 | #loss = trainer.loss(mu, label) 52 | loss.backward() 53 | optimizer_swa.step() 54 | if (epoch % 2 ==0) & (iter != num_iters-1): 55 | scheduler_swa.step() 56 | else: 57 | optimizer_swa.param_groups[0]["lr"]=lr2 58 | #scheduler.step() 59 | if (epoch+1) % 2 ==0:#) & (epoch !=0): 60 | swa_model.update_parameters(trainer.model) 61 | torch.optim.swa_utils.update_bn(trainer.train_loader, swa_model) 62 | 63 | #swa_scheduler.step() 64 | #scheduler.step() 65 | return swa_model 66 | 67 | 68 | def swa_train(trainer,epoch_swa, regularizer=None, lr_schedule=None): 69 | 70 | num_iters = len(trainer.train_loader) 71 | 72 | lr1 = 0.003#0.1 73 | lr2 = lr1*0.01#0.001 74 | 75 | optimizer_swa = torch.optim.Adam(params=model.parameters(), lr=lr1, betas=(0.9,0.99), 76 | weight_decay=1e-4, amsgrad=False) 77 | 78 | cycle = num_iters 79 | swa_c = 1 80 | swa_model = AveragedModel(trainer.model) 81 | #scheduler = CosineAnnealingLR(optimizer_swa, T_max=cycle,eta_min=0.0) 82 | scheduler_swa = CosineAnnealingWarmRestarts(optimizer_swa,T_0=num_iters, 83 | T_mult=1, eta_min=lr2, last_epoch=-1) 84 | lr_ls = [] 85 | for epoch in tqdm(range(epoch_swa)): 86 | trainer.model.train() 87 | for iter, (data, target) in enumerate(trainer.train_loader): 88 | input = data[..., :trainer.args.input_dim] 89 | label = target[..., :trainer.args.output_dim] # (..., 1) 90 | optimizer_swa.zero_grad() 91 | 92 | input = input.cuda(non_blocking=True) 93 | label = label.cuda(non_blocking=True) 94 | 95 | if trainer.args.teacher_forcing: 96 | global_step = (epoch - 1) * trainer.train_per_epoch + batch_idx 97 | 
teacher_forcing_ratio = trainer._compute_sampling_threshold(global_step, trainer.args.tf_decay_steps) 98 | else: 99 | teacher_forcing_ratio = 1. 100 | #output,log_var = trainer.model.forward_heter(data, target, teacher_forcing_ratio=teacher_forcing_ratio) 101 | output = trainer.model.forward(data, target, teacher_forcing_ratio=0.5) 102 | if trainer.args.real_value: 103 | label = trainer.scaler.inverse_transform(label) 104 | loss = trainer.loss(output, label) 105 | loss.backward() 106 | optimizer_swa.step() 107 | if (epoch % 2 ==0) & (iter != num_iters-1): 108 | scheduler_swa.step() 109 | else: 110 | optimizer_swa.param_groups[0]["lr"]=lr2 111 | #scheduler.step() 112 | if (epoch+1) % 2 ==0: 113 | swa_model.update_parameters(trainer.model) 114 | torch.optim.swa_utils.update_bn(trainer.train_loader, swa_model) 115 | 116 | return swa_model 117 | 
# ModelCali: holds a single learnable scalar T (log-temperature),
# initialised to 1.5. train_cali: post-hoc temperature scaling for the
# heteroscedastic head — LBFGS minimises
# mean(exp(T)*exp(-log_var)*(y-mu)^2 + log_var - T) over held-out data.
118 | ###====Calibration====### 119 | class ModelCali(nn.Module): 120 | def __init__(self,args): 121 | super(ModelCali, self).__init__() 122 | #self.model = model 123 | #self.T = nn.Parameter(torch.ones(args.num_nodes)*1.5, requires_grad=True) 124 | self.T = nn.Parameter(torch.ones(1)*1.5, requires_grad=True) 125 | 126 | 127 | def train_cali(model, args, data_loader, scaler, logger=None, path=None): 128 | model_cali = ModelCali(args).cuda() 129 | optimizer_cali = torch.optim.LBFGS(list(model_cali.parameters()), lr=0.02, max_iter=500) 130 | model.eval() 131 | #nll_fun = nn.GaussianNLLLoss() 132 | y_true = [] 133 | mu_pred = [] 134 | log_var_pred = [] 135 | 136 | with torch.no_grad(): 137 | for batch_idx, (data, target) in enumerate(data_loader): 138 | data = data[..., :args.input_dim] 139 | label = target[..., :args.output_dim] 140 | mu, log_var = model.forward(data, target, teacher_forcing_ratio=0) 141 | mu_pred.append(mu) 142 | log_var_pred.append(log_var) 143 | y_true.append(label) 144 | if args.real_value: 145 | mu_pred = torch.cat(mu_pred, dim=0) 146 | else: 147 | mu_pred = scaler.inverse_transform(torch.cat(mu_pred,
dim=0)) 148 | log_var_pred = torch.cat(log_var_pred, dim=0) 149 | y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) 150 | 151 | y_pred = mu_pred 152 | precision = torch.exp(-log_var_pred) 
# LBFGS closure: recomputes and backprops the temperature-scaled NLL
# surrogate each step; only model_cali.T is optimised.
153 | def eval_(): 154 | optimizer_cali.zero_grad() 155 | #loss(input, target, var) 156 | temperature = torch.exp(model_cali.T) 157 | #loss = nll_fun(y_pred,y_true,torch.exp(log_var_pred)/temperature) 158 | loss = torch.mean(temperature*precision*(y_true-y_pred)**2 + log_var_pred-model_cali.T) 159 | #print(loss.item()) 160 | loss.backward() 161 | return loss 162 | optimizer_cali.step(eval_) 163 | print("Calibration finished!") 164 | return model_cali.T 165 | 166 | 
# train_cali_mc: temperature calibration with MC-dropout kept active
# (enable_dropout); draws num_samples stochastic passes, averages mu and
# log_var across samples, then fits T by LBFGS (continues in next chunk).
167 | def train_cali_mc(model,num_samples, args, data_loader, scaler, logger=None, path=None): 168 | model_cali = ModelCali(args).cuda() 169 | optimizer_cali = torch.optim.LBFGS(list(model_cali.parameters()), lr=0.02, max_iter=500) 170 | model.eval() 171 | enable_dropout(model) 172 | nll_fun = nn.GaussianNLLLoss() 173 | y_true = [] 174 | with torch.no_grad(): 175 | for batch_idx, (_, target) in enumerate(data_loader): 176 | label = target[..., :args.output_dim] 177 | y_true.append(label) 178 | y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)).squeeze(3) 179 | 180 | mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda() 181 | mc_log_vars = torch.empty(0, y_true.size(0),y_true.size(1), y_true.size(2)).cuda() 182 | 183 | with torch.no_grad(): 184 | for i in tqdm(range(num_samples)): 185 | mu_pred = [] 186 | log_var_pred = [] 187 | for batch_idx, (data, target) in enumerate(data_loader): 188 | data = data[..., :args.input_dim] 189 | label = target[..., :args.output_dim] 190 | mu, log_var = model.forward(data, target, teacher_forcing_ratio=0) 191 | #print(mu.size()) 192 | mu_pred.append(mu.squeeze(3)) 193 | log_var_pred.append(log_var.squeeze(3)) 194 | 195 | if args.real_value: 196 | mu_pred = torch.cat(mu_pred, dim=0) 197 | else: 198 | mu_pred =
scaler.inverse_transform(torch.cat(mu_pred, dim=0)) 199 | log_var_pred = torch.cat(log_var_pred, dim=0) 200 | 201 | #print(mc_mus.size(),mu_pred.size()) 202 | mc_mus = torch.vstack((mc_mus,mu_pred.unsqueeze(0))) 203 | mc_log_vars = torch.vstack((mc_log_vars,log_var_pred.unsqueeze(0))) 204 | 205 | y_pred = torch.mean(mc_mus, axis=0) 206 | #pred_std = torch.sqrt(torch.exp(torch.mean(mc_log_vars, axis=0))) 207 | #mc_std = torch.std(mc_mus, axis=0) 208 | #total_std = mc_std+pred_std 209 | #total_var = total_std**2 210 | #log_var_total = 2*torch.log(mc_std+pred_std) 211 | log_var_total = torch.exp(torch.mean(mc_log_vars, axis=0)) 212 | #precision = (mc_std+pred_std)**2 213 | precision = torch.exp(-torch.mean(mc_log_vars, axis=0)) 214 | 
# NOTE(review): `log_var_total` above actually holds exp(mean(log_var)) — a
# variance, not a log-variance; it is only referenced from commented-out
# lines, while the active objective uses `precision` = exp(-mean(log_var)).
# LBFGS closure: minimises mean(exp(T)*precision*(y-mu)^2 - T) w.r.t. the
# scalar log-temperature T.
215 | def eval_(): 216 | optimizer_cali.zero_grad() 217 | #loss(input, target, var) 218 | temperature = torch.exp(model_cali.T) 219 | #loss = nll_fun(y_pred.ravel(),y_true.ravel(),torch.exp(log_var_total.ravel())/temperature) 220 | #loss = torch.mean(temperature*precision*(y_true-y_pred)**2 + log_var_total-model_cali.T) 221 | loss = torch.mean(temperature*precision*(y_true-y_pred)**2-model_cali.T) 222 | #print(loss.item()) 223 | loss.backward() 224 | return loss 225 | optimizer_cali.step(eval_) 226 | print("Calibration finished!") 227 | return model_cali.T 228 | 229 | -------------------------------------------------------------------------------- /notebooks/main_v2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "1af4aa5e", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/html": [ 12 | "" 3266 | ], 3267 | "text/plain": [ 3268 | "" 3269 | ] 3270 | }, 3271 | "execution_count": 3, 3272 | "metadata": {}, 3273 | "output_type": "execute_result" 3274 | } 3275 | ], 3276 | "source": [ 3277 | "%reset -f\n", 3278 | "from jupyterthemes import get_themes\n", 3279 | "import jupyterthemes as jt\n", 3280 |
"from jupyterthemes.stylefx import set_nb_theme\n", 3281 | "set_nb_theme('onedork')" 3282 | ] 3283 | }, 3284 | { 3285 | "cell_type": "code", 3286 | "execution_count": 68, 3287 | "id": "d024eac8", 3288 | "metadata": {}, 3289 | "outputs": [], 3290 | "source": [ 3291 | "import torch\n", 3292 | "torch.cuda.empty_cache()" 3293 | ] 3294 | }, 3295 | { 3296 | "cell_type": "code", 3297 | "execution_count": 69, 3298 | "id": "fa419dbc", 3299 | "metadata": {}, 3300 | "outputs": [ 3301 | { 3302 | "data": { 3303 | "text/html": [ 3304 | "" 3305 | ], 3306 | "text/plain": [ 3307 | "" 3308 | ] 3309 | }, 3310 | "metadata": {}, 3311 | "output_type": "display_data" 3312 | } 3313 | ], 3314 | "source": [ 3315 | "from IPython.display import display, HTML\n", 3316 | "display(HTML(\"\"))" 3317 | ] 3318 | }, 3319 | { 3320 | "cell_type": "code", 3321 | "execution_count": 70, 3322 | "id": "8ea4d3cc", 3323 | "metadata": {}, 3324 | "outputs": [], 3325 | "source": [ 3326 | "import os\n", 3327 | "import sys\n", 3328 | "# file_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))\n", 3329 | "# print(file_dir)\n", 3330 | "# sys.path.append(file_dir)\n", 3331 | "\n", 3332 | "import torch\n", 3333 | "import numpy as np\n", 3334 | "import torch.nn as nn\n", 3335 | "import argparse\n", 3336 | "import configparser\n", 3337 | "from datetime import datetime\n", 3338 | "from model.AGCRN import AGCRN_UQ as Network\n", 3339 | "from model.BasicTrainer import Trainer\n", 3340 | "from lib.TrainInits import init_seed\n", 3341 | "from lib.dataloader import get_dataloader\n", 3342 | "from lib.TrainInits import print_model_parameters\n", 3343 | "from tqdm import tqdm\n", 3344 | "from copy import deepcopy\n", 3345 | "import torch.nn.functional as F\n", 3346 | "import torchcontrib\n", 3347 | "from model.train_methods import swa_train_combined, swa_train, train_cali, train_cali_mc, train_fge\n", 3348 | "from model.test_methods import regular_test, heter_test, mc_test, combined_test, 
ensemble_test,quantile_test" 3349 | ] 3350 | }, 3351 | { 3352 | "cell_type": "code", 3353 | "execution_count": null, 3354 | "id": "7d85451d", 3355 | "metadata": {}, 3356 | "outputs": [], 3357 | "source": [] 3358 | }, 3359 | { 3360 | "cell_type": "code", 3361 | "execution_count": 71, 3362 | "id": "e1c79495", 3363 | "metadata": {}, 3364 | "outputs": [], 3365 | "source": [ 3366 | "#*************************************************************************#\n", 3367 | "Mode = 'train'\n", 3368 | "DEBUG = 'True'\n", 3369 | "DATASET = 'PEMS04' #PEMS04/8/3/7\n", 3370 | "DEVICE = 'cuda:0'\n", 3371 | "MODEL = 'AGCRN'\n", 3372 | "MODEL_NAME = \"combined\"#\"combined\" #\"combined\"#\"basic/dropout/heter/combined_swa\"\n", 3373 | "P1= 0.1 #04/03/07: 0.1; 08: 0.05\n", 3374 | "\n", 3375 | "#get configuration\n", 3376 | "config_file = 'model/{}_{}.conf'.format(DATASET, MODEL)\n", 3377 | "#print('Read configuration file: %s' % (config_file))\n", 3378 | "config = configparser.ConfigParser()\n", 3379 | "config.read(config_file)\n", 3380 | "\n", 3381 | "#config[\"data\"]\n", 3382 | "\n", 3383 | "from lib.utils import enable_dropout,save_model_,load_model_\n", 3384 | "from lib.metrics import All_Metrics\n" 3385 | ] 3386 | }, 3387 | { 3388 | "cell_type": "code", 3389 | "execution_count": 72, 3390 | "id": "35cc4723", 3391 | "metadata": {}, 3392 | "outputs": [ 3393 | { 3394 | "name": "stdout", 3395 | "output_type": "stream", 3396 | "text": [ 3397 | "*****************Model Parameter*****************\n", 3398 | "T torch.Size([1]) True\n", 3399 | "node_embeddings torch.Size([307, 10]) True\n", 3400 | "encoder.dcrnn_cells.0.gate.weights_pool torch.Size([10, 2, 65, 128]) True\n", 3401 | "encoder.dcrnn_cells.0.gate.bias_pool torch.Size([10, 128]) True\n", 3402 | "encoder.dcrnn_cells.0.update.weights_pool torch.Size([10, 2, 65, 64]) True\n", 3403 | "encoder.dcrnn_cells.0.update.bias_pool torch.Size([10, 64]) True\n", 3404 | "encoder.dcrnn_cells.1.gate.weights_pool torch.Size([10, 2, 128, 128]) 
True\n", 3405 | "encoder.dcrnn_cells.1.gate.bias_pool torch.Size([10, 128]) True\n", 3406 | "encoder.dcrnn_cells.1.update.weights_pool torch.Size([10, 2, 128, 64]) True\n", 3407 | "encoder.dcrnn_cells.1.update.bias_pool torch.Size([10, 64]) True\n", 3408 | "get_mu.0.weight torch.Size([32, 1, 1, 1]) True\n", 3409 | "get_mu.0.bias torch.Size([32]) True\n", 3410 | "get_mu.3.weight torch.Size([12, 32, 1, 64]) True\n", 3411 | "get_mu.3.bias torch.Size([12]) True\n", 3412 | "get_log_var.0.weight torch.Size([32, 1, 1, 1]) True\n", 3413 | "get_log_var.0.bias torch.Size([32]) True\n", 3414 | "get_log_var.3.weight torch.Size([12, 32, 1, 64]) True\n", 3415 | "get_log_var.3.bias torch.Size([12]) True\n", 3416 | "Total params num: 797335\n", 3417 | "*****************Finish Parameter****************\n", 3418 | "Load PEMS04 Dataset shaped: (16992, 307, 1) 919.0 0.0 211.7007794815878 180.0\n", 3419 | "Normalize the dataset by Standard Normalization\n", 3420 | "Train: (10173, 12, 307, 1) (10173, 12, 307, 1)\n", 3421 | "Val: (3375, 12, 307, 1) (3375, 12, 307, 1)\n", 3422 | "Test: (3375, 12, 307, 1) (3375, 12, 307, 1)\n" 3423 | ] 3424 | } 3425 | ], 3426 | "source": [ 3427 | "from lib.metrics import MAE_torch\n", 3428 | "def masked_mae_loss(scaler, mask_value):\n", 3429 | " def loss(preds, labels):\n", 3430 | " if scaler:\n", 3431 | " preds = scaler.inverse_transform(preds)\n", 3432 | " labels = scaler.inverse_transform(labels)\n", 3433 | " mae = MAE_torch(pred=preds, true=labels, mask_value=mask_value)\n", 3434 | " return mae\n", 3435 | " return loss\n", 3436 | "\n", 3437 | "#parser\n", 3438 | "args = argparse.ArgumentParser(description='arguments')\n", 3439 | "args.add_argument('--dataset', default=DATASET, type=str)\n", 3440 | "args.add_argument('--mode', default=Mode, type=str)\n", 3441 | "args.add_argument('--device', default=DEVICE, type=str, help='indices of GPUs')\n", 3442 | "args.add_argument('--debug', default=DEBUG, type=eval)\n", 3443 | "args.add_argument('--model', 
default=MODEL, type=str)\n", 3444 | "args.add_argument('--cuda', default=True, type=bool)\n", 3445 | "#data\n", 3446 | "args.add_argument('--val_ratio', default=config['data']['val_ratio'], type=float)\n", 3447 | "args.add_argument('--test_ratio', default=config['data']['test_ratio'], type=float)\n", 3448 | "#args.add_argument('--val_ratio', default=0.1, type=float)\n", 3449 | "#args.add_argument('--test_ratio', default=0.85, type=float)\n", 3450 | "\n", 3451 | "args.add_argument('--lag', default=config['data']['lag'], type=int)\n", 3452 | "args.add_argument('--horizon', default=config['data']['horizon'], type=int)\n", 3453 | "args.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int)\n", 3454 | "args.add_argument('--tod', default=config['data']['tod'], type=eval)\n", 3455 | "args.add_argument('--normalizer', default=config['data']['normalizer'], type=str)\n", 3456 | "args.add_argument('--column_wise', default=config['data']['column_wise'], type=eval)\n", 3457 | "args.add_argument('--default_graph', default=config['data']['default_graph'], type=eval)\n", 3458 | "#model\n", 3459 | "args.add_argument('--input_dim', default=config['model']['input_dim'], type=int)\n", 3460 | "args.add_argument('--output_dim', default=config['model']['output_dim'], type=int)\n", 3461 | "args.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int)\n", 3462 | "args.add_argument('--rnn_units', default=config['model']['rnn_units'], type=int)\n", 3463 | "args.add_argument('--num_layers', default=config['model']['num_layers'], type=int)\n", 3464 | "args.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int)\n", 3465 | "#train\n", 3466 | "args.add_argument('--loss_func', default=config['train']['loss_func'], type=str)\n", 3467 | "#args.add_argument('--loss_func', default='mse', type=str)\n", 3468 | "args.add_argument('--seed', default=config['train']['seed'], type=int)\n", 3469 | "args.add_argument('--batch_size', 
default=config['train']['batch_size'], type=int)\n", 3470 | "args.add_argument('--epochs', default=config['train']['epochs'], type=int)\n", 3471 | "#args.add_argument('--epochs', default=500, type=int)\n", 3472 | "args.add_argument('--lr_init', default=config['train']['lr_init'], type=float)\n", 3473 | "#args.add_argument('--lr_init', default=1e-2, type=float)\n", 3474 | "args.add_argument('--lr_decay', default=config['train']['lr_decay'], type=eval)\n", 3475 | "args.add_argument('--lr_decay_rate', default=config['train']['lr_decay_rate'], type=float)\n", 3476 | "args.add_argument('--lr_decay_step', default=config['train']['lr_decay_step'], type=str)\n", 3477 | "args.add_argument('--early_stop', default=config['train']['early_stop'], type=eval)\n", 3478 | "args.add_argument('--early_stop_patience', default=config['train']['early_stop_patience'], type=int)\n", 3479 | "args.add_argument('--grad_norm', default=config['train']['grad_norm'], type=eval)\n", 3480 | "args.add_argument('--max_grad_norm', default=config['train']['max_grad_norm'], type=int)\n", 3481 | "args.add_argument('--teacher_forcing', default=False, type=bool)\n", 3482 | "args.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps')\n", 3483 | "args.add_argument('--real_value', default=config['train']['real_value'], type=eval, help = 'use real value for loss calculation')\n", 3484 | "#test\n", 3485 | "args.add_argument('--mae_thresh', default=config['test']['mae_thresh'], type=eval)\n", 3486 | "args.add_argument('--mape_thresh', default=config['test']['mape_thresh'], type=float)\n", 3487 | "#log\n", 3488 | "args.add_argument('--log_dir', default='./', type=str)\n", 3489 | "args.add_argument('--log_step', default=config['log']['log_step'], type=int)\n", 3490 | "args.add_argument('--plot', default=config['log']['plot'], type=eval)\n", 3491 | "args.add_argument('--model_name', default=MODEL_NAME, type=str)\n", 3492 | "args.add_argument('--p1', default=P1, type=float)\n", 
3493 | "\n", 3494 | "\n", 3495 | "args = args.parse_args([])\n", 3496 | "init_seed(args.seed)\n", 3497 | "if torch.cuda.is_available():\n", 3498 | " torch.cuda.set_device(int(args.device[5]))\n", 3499 | "else:\n", 3500 | " args.device = 'cpu'\n", 3501 | "\n", 3502 | "#init model\n", 3503 | "model = Network(args).to(args.device)\n", 3504 | "for p in model.parameters():\n", 3505 | " if p.dim() > 1:\n", 3506 | " nn.init.xavier_uniform_(p)\n", 3507 | " else:\n", 3508 | " nn.init.uniform_(p)\n", 3509 | "print_model_parameters(model, only_num=False)\n", 3510 | "\n", 3511 | "#load dataset\n", 3512 | "train_loader, val_loader, test_loader, scaler = get_dataloader(args,\n", 3513 | " normalizer=args.normalizer,\n", 3514 | " tod=args.tod, dow=False,\n", 3515 | " weather=False, single=False)\n", 3516 | "#init loss function, optimizer\n", 3517 | "if args.loss_func == 'mask_mae':\n", 3518 | " loss = masked_mae_loss(scaler, mask_value=0.0)\n", 3519 | "elif args.loss_func == 'mae':\n", 3520 | " loss = torch.nn.L1Loss().to(args.device)\n", 3521 | "elif args.loss_func == 'mse':\n", 3522 | " loss = torch.nn.MSELoss().to(args.device)\n", 3523 | "else:\n", 3524 | " raise ValueError\n", 3525 | "\n", 3526 | "# optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8,\n", 3527 | "# weight_decay=1e-6, amsgrad=False)\n", 3528 | "\n", 3529 | "#basic\n", 3530 | "optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8,\n", 3531 | " weight_decay=0, amsgrad=False)\n", 3532 | "#learning rate decay\n", 3533 | "lr_scheduler = None\n", 3534 | "if args.lr_decay:\n", 3535 | " print('Applying learning rate decay.')\n", 3536 | " lr_decay_steps = [int(i) for i in list(args.lr_decay_step.split(','))]\n", 3537 | " lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,\n", 3538 | " milestones=lr_decay_steps,\n", 3539 | " gamma=args.lr_decay_rate)\n", 3540 | "#start training\n", 3541 | "trainer = Trainer(model, loss, optimizer, 
train_loader, val_loader, test_loader, scaler,\n", 3542 | " args, lr_scheduler=lr_scheduler)\n" 3543 | ] 3544 | }, 3545 | { 3546 | "cell_type": "markdown", 3547 | "id": "e640ac10", 3548 | "metadata": {}, 3549 | "source": [ 3550 | "# Pre-train model" 3551 | ] 3552 | }, 3553 | { 3554 | "cell_type": "code", 3555 | "execution_count": 73, 3556 | "id": "206aff0e", 3557 | "metadata": {}, 3558 | "outputs": [], 3559 | "source": [ 3560 | "#trainer.train()" 3561 | ] 3562 | }, 3563 | { 3564 | "cell_type": "markdown", 3565 | "id": "ae42fc74", 3566 | "metadata": {}, 3567 | "source": [ 3568 | "# AWA re-train model" 3569 | ] 3570 | }, 3571 | { 3572 | "cell_type": "code", 3573 | "execution_count": null, 3574 | "id": "6ee40969", 3575 | "metadata": {}, 3576 | "outputs": [], 3577 | "source": [ 3578 | "#trainer.model = swa_train_combined(trainer,epoch_swa=20)" 3579 | ] 3580 | }, 3581 | { 3582 | "cell_type": "markdown", 3583 | "id": "9bdea352", 3584 | "metadata": {}, 3585 | "source": [ 3586 | "# Save and load model" 3587 | ] 3588 | }, 3589 | { 3590 | "cell_type": "code", 3591 | "execution_count": 1, 3592 | "id": "e0c72752", 3593 | "metadata": {}, 3594 | "outputs": [], 3595 | "source": [ 3596 | "#save_model_(model,args.model_name,args.dataset,args.horizon)\n", 3597 | "#trainer.model = load_model_(model,args.model_name,args.dataset,args.horizon)" 3598 | ] 3599 | }, 3600 | { 3601 | "cell_type": "markdown", 3602 | "id": "48f3a147", 3603 | "metadata": {}, 3604 | "source": [ 3605 | "# MHCC: online calibration" 3606 | ] 3607 | }, 3608 | { 3609 | "cell_type": "code", 3610 | "execution_count": 93, 3611 | "id": "d6179d41", 3612 | "metadata": {}, 3613 | "outputs": [], 3614 | "source": [ 3615 | "\n", 3616 | "\"\"\"\n", 3617 | "Online inference as validation.\n", 3618 | "\"\"\"\n", 3619 | "\n", 3620 | "def combined_conf_val(model,num_samples,args, data_loader, scaler, q,logger=None, path=None):\n", 3621 | " model.eval()\n", 3622 | " enable_dropout(model)\n", 3623 | " nll_fun = 
nn.GaussianNLLLoss()\n", 3624 | " y_true = []\n", 3625 | " with torch.no_grad():\n", 3626 | " for batch_idx, (_, target) in enumerate(data_loader):\n", 3627 | " label = target[..., :args.output_dim]\n", 3628 | " y_true.append(label)\n", 3629 | " y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)).squeeze(3)\n", 3630 | " \n", 3631 | " mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()\n", 3632 | " mc_log_vars = torch.empty(0, y_true.size(0),y_true.size(1), y_true.size(2)).cuda()\n", 3633 | " \n", 3634 | " with torch.no_grad():\n", 3635 | " for i in tqdm(range(num_samples)):\n", 3636 | " mu_pred = []\n", 3637 | " log_var_pred = []\n", 3638 | " for batch_idx, (data, _) in enumerate(data_loader):\n", 3639 | " data = data[..., :args.input_dim]\n", 3640 | " mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)\n", 3641 | " #print(mu.size())\n", 3642 | " mu_pred.append(mu.squeeze(3))\n", 3643 | " log_var_pred.append(log_var.squeeze(3))\n", 3644 | " \n", 3645 | " if args.real_value:\n", 3646 | " mu_pred = torch.cat(mu_pred, dim=0)\n", 3647 | " else:\n", 3648 | " mu_pred = scaler.inverse_transform(torch.cat(mu_pred, dim=0)) \n", 3649 | " log_var_pred = torch.cat(log_var_pred, dim=0) \n", 3650 | "\n", 3651 | " #print(mc_mus.size(),mu_pred.size()) \n", 3652 | " mc_mus = torch.vstack((mc_mus,mu_pred.unsqueeze(0))) \n", 3653 | " mc_log_vars = torch.vstack((mc_log_vars,log_var_pred.unsqueeze(0))) \n", 3654 | " \n", 3655 | " y_pred = torch.mean(mc_mus, axis=0)\n", 3656 | " #total_var = (torch.var(mc_mus, axis=0)+torch.exp(torch.mean(mc_log_vars, axis=0)))#/temperature \n", 3657 | " total_var = torch.exp(torch.mean(mc_log_vars, axis=0))\n", 3658 | " total_std = total_var**0.5 \n", 3659 | " \n", 3660 | " mpiw = 2*torch.mean(torch.mul(total_std,q)) \n", 3661 | " nll = nll_fun(y_pred.ravel(), y_true.ravel(), total_var.ravel())\n", 3662 | " lower_bound = y_pred-torch.mul(total_std,q)\n", 3663 | " upper_bound = 
y_pred+torch.mul(total_std,q) \n", 3664 | "\n", 3665 | " #in_num = torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ))\n", 3666 | " #print(torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ),dim=0))\n", 3667 | " #picp = in_num/(y_true.size(0)*y_true.size(1)*y_true.size(2))\n", 3668 | " in_num = torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ),dim=0)\n", 3669 | " #picp = in_num/(y_true.size(0)*y_true.size(1)*y_true.size(2))\n", 3670 | " in_num = torch.sum(in_num,dim=1)\n", 3671 | " picp = in_num/(y_true.size(0)*y_true.size(2))#.shape\n", 3672 | " return y_true, y_pred, total_std, picp.detach().cpu().numpy()\n", 3673 | " " 3674 | ] 3675 | }, 3676 | { 3677 | "cell_type": "code", 3678 | "execution_count": 94, 3679 | "id": "6665b648", 3680 | "metadata": {}, 3681 | "outputs": [ 3682 | { 3683 | "name": "stderr", 3684 | "output_type": "stream", 3685 | "text": [ 3686 | "100%|██████████| 10/10 [00:55<00:00, 5.59s/it]\n" 3687 | ] 3688 | } 3689 | ], 3690 | "source": [ 3691 | "y_true_val, y_pred_val, std_val, p = combined_conf_val(model,10,args,val_loader, scaler, q=1.96) \n", 3692 | "#y_true_val, mc_mus_val, mc_log_vars_val = combined_conf_val(model,10,args,test_loader, scaler, q=1.96) #" 3693 | ] 3694 | }, 3695 | { 3696 | "cell_type": "code", 3697 | "execution_count": 96, 3698 | "id": "81e913cd", 3699 | "metadata": {}, 3700 | "outputs": [], 3701 | "source": [ 3702 | "scores = abs(y_pred_val-y_true_val)/std_val#.shape\n", 3703 | "n = y_true_val.shape[0]\n", 3704 | "def quantile_lwci(scores,n,alpha):\n", 3705 | " q = torch.empty(y_true_val.size(1)).cuda()\n", 3706 | " for i in range(y_true_val.shape[1]): \n", 3707 | " qq=np.quantile(scores[:,i,:].detach().cpu().numpy(),min((n+1.0)*(1-alpha)/n,1))\n", 3708 | " q[i]=qq\n", 3709 | " return q\n", 3710 | "def quantile_mhcc(scores,n,alpha_new):\n", 3711 | " q = torch.empty(y_true_val.size(1)).cuda()\n", 3712 | " for i in range(y_true_val.shape[1]): \n", 3713 | " 
qq=np.quantile(scores[:,i,:].detach().cpu().numpy(),min((n+1.0)*(1-alpha_new[i])/n,1))\n", 3714 | " q[i]=qq\n", 3715 | " return q" 3716 | ] 3717 | }, 3718 | { 3719 | "cell_type": "code", 3720 | "execution_count": 97, 3721 | "id": "736f4549", 3722 | "metadata": {}, 3723 | "outputs": [], 3724 | "source": [ 3725 | "def combined_conf_test(model,num_samples,args, data_loader, scaler, q,logger=None, path=None):\n", 3726 | " model.eval()\n", 3727 | " enable_dropout(model)\n", 3728 | " nll_fun = nn.GaussianNLLLoss()\n", 3729 | " y_true = []\n", 3730 | " with torch.no_grad():\n", 3731 | " for batch_idx, (_, target) in enumerate(data_loader):\n", 3732 | " label = target[..., :args.output_dim]\n", 3733 | " y_true.append(label)\n", 3734 | " y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)).squeeze(3)\n", 3735 | " \n", 3736 | " mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()\n", 3737 | " mc_log_vars = torch.empty(0, y_true.size(0),y_true.size(1), y_true.size(2)).cuda()\n", 3738 | " \n", 3739 | " with torch.no_grad():\n", 3740 | " for i in tqdm(range(num_samples)):\n", 3741 | " mu_pred = []\n", 3742 | " log_var_pred = []\n", 3743 | " for batch_idx, (data, _) in enumerate(data_loader):\n", 3744 | " data = data[..., :args.input_dim]\n", 3745 | " mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)\n", 3746 | " #print(mu.size())\n", 3747 | " mu_pred.append(mu.squeeze(3))\n", 3748 | " log_var_pred.append(log_var.squeeze(3))\n", 3749 | " \n", 3750 | " if args.real_value:\n", 3751 | " mu_pred = torch.cat(mu_pred, dim=0)\n", 3752 | " else:\n", 3753 | " mu_pred = scaler.inverse_transform(torch.cat(mu_pred, dim=0)) \n", 3754 | " log_var_pred = torch.cat(log_var_pred, dim=0) \n", 3755 | "\n", 3756 | " #print(mc_mus.size(),mu_pred.size()) \n", 3757 | " mc_mus = torch.vstack((mc_mus,mu_pred.unsqueeze(0))) \n", 3758 | " mc_log_vars = torch.vstack((mc_log_vars,log_var_pred.unsqueeze(0))) \n", 3759 | " \n", 3760 | " y_pred = 
torch.mean(mc_mus, axis=0)\n", 3761 | " total_var = (torch.var(mc_mus, axis=0)+torch.exp(torch.mean(mc_log_vars, axis=0)))#/temperature \n", 3762 | " total_std = total_var**0.5 \n", 3763 | " \n", 3764 | " mpiw = 2*torch.mean(torch.mul(total_std,q)) \n", 3765 | " nll = nll_fun(y_pred.ravel(), y_true.ravel(), total_var.ravel())\n", 3766 | " lower_bound = y_pred-torch.mul(total_std,q)\n", 3767 | " upper_bound = y_pred+torch.mul(total_std,q) \n", 3768 | " \n", 3769 | " in_num = torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ),dim=0)\n", 3770 | " in_num = torch.sum(in_num,dim=1)\n", 3771 | " picp = in_num/(y_true.size(0)*y_true.size(2))#.shape\n", 3772 | " print(picp*100, torch.mean(picp).item()*100, mpiw.item())\n", 3773 | " #return y_true, y_pred, total_std" 3774 | ] 3775 | }, 3776 | { 3777 | "cell_type": "markdown", 3778 | "id": "9988a0e6", 3779 | "metadata": {}, 3780 | "source": [ 3781 | "# Correct target signifcance level $\\alpha$" 3782 | ] 3783 | }, 3784 | { 3785 | "cell_type": "code", 3786 | "execution_count": 103, 3787 | "id": "10579167", 3788 | "metadata": {}, 3789 | "outputs": [ 3790 | { 3791 | "name": "stdout", 3792 | "output_type": "stream", 3793 | "text": [ 3794 | "[0.05340891 0.05050783 0.05023377 0.05054503 0.04958587 0.04828899\n", 3795 | " 0.04636179 0.04609852 0.0462424 0.04623555 0.04712426 0.04645186]\n", 3796 | "[0.05340891 0.05092526 0.05106862 0.0517973 0.05125556 0.05037611\n", 3797 | " 0.04886633 0.04902048 0.04958179 0.04999236 0.05129849 0.05104351]\n" 3798 | ] 3799 | } 3800 | ], 3801 | "source": [ 3802 | "h = np.arange(args.horizon)\n", 3803 | "alpha = 0.05\n", 3804 | "gamma = 0.03#04/07/08:0.03, 03:0\n", 3805 | "alpha_new = p-(1-alpha)+alpha\n", 3806 | "alpha_new = alpha_new+(p[0]-p[-1])*gamma*h*2 #04/07/08:0.03\n", 3807 | "q = quantile_mhcc(scores,n,alpha_new)" 3808 | ] 3809 | }, 3810 | { 3811 | "cell_type": "code", 3812 | "execution_count": 104, 3813 | "id": "30993a93", 3814 | "metadata": {}, 3815 | "outputs": [ 3816 | { 3817 | 
"name": "stderr", 3818 | "output_type": "stream", 3819 | "text": [ 3820 | ":2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 3821 | " q = torch.tensor(q.reshape(-1,12,1)).to(args.device)\n", 3822 | "100%|██████████| 10/10 [00:56<00:00, 5.67s/it]" 3823 | ] 3824 | }, 3825 | { 3826 | "name": "stdout", 3827 | "output_type": "stream", 3828 | "text": [ 3829 | "tensor([95.0162, 95.1789, 95.1004, 94.9924, 94.9851, 95.0227, 95.1640, 95.0996,\n", 3830 | " 95.0323, 95.0006, 94.8244, 94.8621], device='cuda:0') 95.0232207775116 103.91312408447266\n" 3831 | ] 3832 | }, 3833 | { 3834 | "name": "stderr", 3835 | "output_type": "stream", 3836 | "text": [ 3837 | "\n" 3838 | ] 3839 | } 3840 | ], 3841 | "source": [ 3842 | "q = torch.tensor(q.reshape(-1,args.horizon,1)).to(args.device)\n", 3843 | "combined_conf_test(model,10,args,test_loader,scaler,q)" 3844 | ] 3845 | }, 3846 | { 3847 | "cell_type": "markdown", 3848 | "id": "fee6a368", 3849 | "metadata": {}, 3850 | "source": [ 3851 | "# Online MHCC: online calibration\n", 3852 | "\n", 3853 | "### 1. Update nonconformity score\n", 3854 | "\n", 3855 | "### 2. Update empirical prediction interval percentage coverage\n", 3856 | "\n", 3857 | "### 3. 
Update alpha" 3858 | ] 3859 | }, 3860 | { 3861 | "cell_type": "code", 3862 | "execution_count": null, 3863 | "id": "ce5242e3", 3864 | "metadata": {}, 3865 | "outputs": [], 3866 | "source": [ 3867 | "online_size = y_true_val.shape[0]#1000\n", 3868 | "y_true_online = y_true_val[:online_size,...]\n", 3869 | "y_pred_online = y_pred_val[:online_size,...]\n", 3870 | "std_online = std_val[:online_size,...]\n", 3871 | "score_online = scores[:online_size,...]\n", 3872 | "\n", 3873 | "\n", 3874 | "gamma_online = 0.03 #04/07/08:0.03, 03:0\n", 3875 | "h = np.arange(args.horizon)\n", 3876 | "update_freq = 1000\n", 3877 | "picp_online = p" 3878 | ] 3879 | }, 3880 | { 3881 | "cell_type": "code", 3882 | "execution_count": null, 3883 | "id": "67984ca2", 3884 | "metadata": {}, 3885 | "outputs": [], 3886 | "source": [ 3887 | "alpha = 0.05\n", 3888 | "scores_online = scores\n", 3889 | "\n", 3890 | "def mhcc_online_(y_true_test, y_pred_test, std_test, q_online = q):\n", 3891 | " ii = 0\n", 3892 | " mpiw_ls = [] \n", 3893 | " picp_ls = [] \n", 3894 | " for y_t, p_t, s_t in tqdm(zip(y_true_test, y_pred_test, std_test)): \n", 3895 | " \n", 3896 | " #emperical picp\n", 3897 | " mpiw = 2*torch.mean(torch.mul(s_t,q_online))\n", 3898 | " lower_bound = p_t-torch.mul(s_t,q_online)\n", 3899 | " upper_bound = p_t+torch.mul(s_t,q_online) \n", 3900 | "\n", 3901 | " in_num = torch.sum((y_t >= lower_bound)&(y_t <= upper_bound),dim=0)\n", 3902 | " in_num = torch.sum(in_num,dim=1)\n", 3903 | " picp = in_num/y_true_test.size(2)\n", 3904 | "\n", 3905 | " mpiw_ls.append(mpiw) \n", 3906 | " picp_ls.append(picp) \n", 3907 | " picp = picp.detach().cpu().numpy()\n", 3908 | " picp_online = (picp_online*(online_size+ii) + picp)/(online_size+ii +1)\n", 3909 | " \n", 3910 | " #update alpha\n", 3911 | " score_new = abs(p_t-y_t)/s_t\n", 3912 | " scores_online = torch.cat([scores_online[1:,...],score_new.unsqueeze(0)],dim=0)\n", 3913 | "\n", 3914 | "\n", 3915 | " if (ii +1) % update_freq ==0:\n", 3916 | " 
#print(ii)\n", 3917 | " alpha_new = picp_online-(1-alpha)+alpha\n", 3918 | " alpha_new = alpha_new+(picp_online[0]-picp_online[-1])*gamma_online*h*2 #04/07/08:0.03\n", 3919 | " for i in range(y_true_test.shape[1]): \n", 3920 | " qq = np.quantile(scores_online[:,i,:].detach().cpu().numpy(),min((online_size+1.0)*(1-alpha_new[i])/online_size,1))\n", 3921 | " #print(qq.shape) \n", 3922 | " q_online[0,i,0]=torch.from_numpy(qq.reshape(-1,1))\n", 3923 | "\n", 3924 | " ii = ii+1 \n", 3925 | " \n", 3926 | " picp_test = torch.stack(picp_ls,dim=1) \n", 3927 | " mpiw_test = torch.stack(mpiw_ls,dim=0) \n", 3928 | " \n", 3929 | " ### update q\n", 3930 | " q_online = q\n", 3931 | " mpiw = 2*torch.mean(torch.mul(std_test,q_online))#.reshape(-1,y_true_test.shape[1],1)) ) \n", 3932 | " lower_bound = y_pred_test-torch.mul(std_test,q_online)\n", 3933 | " upper_bound = y_pred_test+torch.mul(std_test,q_online) \n", 3934 | "\n", 3935 | " in_num = torch.sum((y_true_test >= lower_bound)&(y_true_test <= upper_bound),dim=0)\n", 3936 | " in_num = torch.sum(in_num,dim=1)\n", 3937 | " picp = in_num/(y_true_test.size(0)*y_true_test.size(2)) \n", 3938 | "\n", 3939 | " print(\"PICP: {:.4f}, MPIW: {:.4f}\".format(picp,mpiw))\n", 3940 | " \n", 3941 | " " 3942 | ] 3943 | }, 3944 | { 3945 | "cell_type": "code", 3946 | "execution_count": null, 3947 | "id": "ca062c76", 3948 | "metadata": {}, 3949 | "outputs": [], 3950 | "source": [ 3951 | " mhcc_online_(y_true_test, y_pred_test, std_test, q_online = q)" 3952 | ] 3953 | }, 3954 | { 3955 | "cell_type": "code", 3956 | "execution_count": null, 3957 | "id": "87de181b", 3958 | "metadata": {}, 3959 | "outputs": [], 3960 | "source": [] 3961 | }, 3962 | { 3963 | "cell_type": "code", 3964 | "execution_count": null, 3965 | "id": "32915a3c", 3966 | "metadata": {}, 3967 | "outputs": [], 3968 | "source": [] 3969 | }, 3970 | { 3971 | "cell_type": "code", 3972 | "execution_count": null, 3973 | "id": "48a8c8bc", 3974 | "metadata": {}, 3975 | "outputs": [], 3976 | 
"source": [] 3977 | }, 3978 | { 3979 | "cell_type": "code", 3980 | "execution_count": null, 3981 | "id": "c15307bc", 3982 | "metadata": {}, 3983 | "outputs": [], 3984 | "source": [] 3985 | } 3986 | ], 3987 | "metadata": { 3988 | "kernelspec": { 3989 | "display_name": "Python 3", 3990 | "language": "python", 3991 | "name": "python3" 3992 | }, 3993 | "language_info": { 3994 | "codemirror_mode": { 3995 | "name": "ipython", 3996 | "version": 3 3997 | }, 3998 | "file_extension": ".py", 3999 | "mimetype": "text/x-python", 4000 | "name": "python", 4001 | "nbconvert_exporter": "python", 4002 | "pygments_lexer": "ipython3", 4003 | "version": "3.8.10" 4004 | } 4005 | }, 4006 | "nbformat": 4, 4007 | "nbformat_minor": 5 4008 | } 4009 | -------------------------------------------------------------------------------- /notebooks/readme: -------------------------------------------------------------------------------- 1 | DeepSTUQ with MHCC 2 | -------------------------------------------------------------------------------- /slide/uq_slide.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WeizhuQIAN/DeepSTUQ_Pytorch/4f2281531d07b48ba81aa1c519a90e65a71993e2/slide/uq_slide.pdf --------------------------------------------------------------------------------