├── README.md
├── lib
├── TrainInits.py
├── add_window.py
├── dataloader.py
├── load_dataset.py
├── metrics.py
├── normalization.py
└── utils.py
├── main.py
├── model
├── AGCN.py
├── AGCRNCell.py
├── AGCRN_UQ.py
├── BasicTrainer.py
├── PEMS03_AGCRN.conf
├── PEMS04_AGCRN.conf
├── PEMS07_AGCRN.conf
├── PEMS08_AGCRN.conf
├── readme.md
├── test_methods.py
└── train_methods.py
├── notebooks
├── main_v2.ipynb
└── readme
└── slide
└── uq_slide.pdf
/README.md:
--------------------------------------------------------------------------------
1 | # Code for DeepSTUQ
2 |
3 | ## Related Papers:
4 | ### [Uncertainty Quantification for Traffic Forecasting: A Unified Approach, ICDE 2023](https://arxiv.org/pdf/2208.05875.pdf)
5 | ### [Towards a Unified Understanding of Uncertainty Quantification in Traffic Flow Forecasting, TKDE 2023](https://ieeexplore.ieee.org/abstract/document/10242138)
6 |
7 | ## Datasets:
8 | [PEMS datasets](https://www.kaggle.com/datasets/elmahy/pems-dataset?resource=download)
9 |
10 |
11 | ## Acknowledgement:
The implementation is based on [this repository](https://github.com/LeiBAI/AGCRN).
13 |
14 |
15 |
--------------------------------------------------------------------------------
/lib/TrainInits.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 |
def init_seed(seed):
    '''
    Seed every RNG (python, numpy, torch CPU/GPU) and force deterministic
    cudnn behavior to maximize reproducibility.

    :param seed: integer random seed applied to all RNGs.
    '''
    # Bug fix: the original assigned ``torch.cuda.cudnn_enabled``, which is
    # not a real torch attribute and silently did nothing; the actual switch
    # lives under ``torch.backends.cudnn``.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
15 |
def init_device(opt):
    '''
    Select the compute device.

    Sets ``opt.cuda`` and, when CUDA is available, activates the GPU whose
    index is parsed from ``opt.device`` (e.g. ``'cuda:0'``); otherwise
    falls back to CPU.

    :param opt: options namespace carrying a ``device`` string.
    :return: the (mutated) opt.
    '''
    if torch.cuda.is_available():
        opt.cuda = True
        # Bug fix: parse the index after 'cuda:' instead of a single
        # character so multi-digit indices such as 'cuda:10' also work.
        torch.cuda.set_device(int(opt.device.split(':')[1]))
    else:
        opt.cuda = False
        opt.device = 'cpu'
    return opt
24 |
def init_optim(model, opt):
    '''
    Build an Adam optimizer over all model parameters.

    :param model: torch.nn.Module whose parameters are optimized.
    :param opt: options namespace providing ``lr_init``.
    :return: configured torch.optim.Adam instance.
    '''
    return torch.optim.Adam(params=model.parameters(), lr=opt.lr_init)
30 |
def init_lr_scheduler(optim, opt):
    '''
    Build a multi-step learning-rate scheduler.

    :param optim: optimizer whose learning rate is scheduled.
    :param opt: options namespace with ``lr_decay_steps`` (milestone list)
        and ``lr_scheduler_rate`` (multiplicative decay factor).
    :return: torch.optim.lr_scheduler.MultiStepLR instance.
    '''
    #return torch.optim.lr_scheduler.StepLR(optimizer=optim,gamma=opt.lr_scheduler_rate,step_size=opt.lr_scheduler_step)
    return torch.optim.lr_scheduler.MultiStepLR(
        optimizer=optim,
        milestones=opt.lr_decay_steps,
        gamma=opt.lr_scheduler_rate,
    )
38 |
def print_model_parameters(model, only_num = True):
    '''
    Print a short parameter report for *model*.

    :param model: torch.nn.Module to inspect.
    :param only_num: when False, additionally list every named parameter
        with its shape and trainability flag.
    '''
    print('*****************Model Parameter*****************')
    if not only_num:
        for pname, tensor in model.named_parameters():
            print(pname, tensor.shape, tensor.requires_grad)
    total = sum(t.nelement() for t in model.parameters())
    print('Total params num: {}'.format(total))
    print('*****************Finish Parameter****************')
47 |
def get_memory_usage(device):
    '''
    Return the current CUDA memory usage for *device* in megabytes.

    :param device: CUDA device (index, string or torch.device).
    :return: (allocated_mb, reserved_mb) tuple of floats.
    '''
    mb = 1024 * 1024.
    allocated_memory = torch.cuda.memory_allocated(device) / mb
    # ``memory_cached`` was deprecated in torch 1.4 in favor of
    # ``memory_reserved`` (same quantity, new name).
    cached_memory = torch.cuda.memory_reserved(device) / mb
    return allocated_memory, cached_memory
--------------------------------------------------------------------------------
/lib/add_window.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
def Add_Window_Horizon(data, window=3, horizon=1, single=False):
    '''
    Slice a time series into (input window, prediction horizon) pairs.

    :param data: array-like shaped [B, ...] where axis 0 is time.
    :param window: number of past steps per input sample.
    :param horizon: number of future steps per target.
    :param single: when True, each target keeps only the final horizon
        step (shape [B', 1, ...]); otherwise the full horizon ([B', H, ...]).
    :return: (X, Y) numpy arrays shaped [B', W, ...] and [B', H or 1, ...].
    '''
    length = len(data)
    end_index = length - horizon - window + 1  # number of usable samples
    # The two original while-loops built X identically; deduplicated here.
    X = [data[i:i + window] for i in range(end_index)]
    if single:
        # slice (not index) so the time axis is kept with length 1
        Y = [data[i + window + horizon - 1:i + window + horizon]
             for i in range(end_index)]
    else:
        Y = [data[i + window:i + window + horizon] for i in range(end_index)]
    return np.array(X), np.array(Y)
28 |
if __name__ == '__main__':
    # Smoke test against the Sydney demand data.
    # NOTE(review): ``data.load_raw_data`` and the CSV path are not part of
    # this repository snapshot — verify they exist before running.
    from data.load_raw_data import Load_Sydney_Demand_Data
    path = '../data/1h_data_new3.csv'
    data = Load_Sydney_Demand_Data(path)
    print(data.shape)
    X, Y = Add_Window_Horizon(data, horizon=2)
    print(X.shape, Y.shape)
36 |
37 |
--------------------------------------------------------------------------------
/lib/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.utils.data
4 | from lib.add_window import Add_Window_Horizon
5 | from lib.load_dataset import load_st_dataset
6 | from lib.normalization import NScaler, MinMax01Scaler, MinMax11Scaler, StandardScaler, ColumnMinMaxScaler
7 |
def normalize_dataset(data, normalizer, column_wise=False):
    '''
    Scale the dataset and return the fitted scaler for later inversion.

    :param data: numpy array with time on axis 0.
    :param normalizer: one of 'max01', 'max11', 'std', 'None', 'cmax'.
    :param column_wise: when True, fit statistics per column (axis 0)
        instead of globally; not used by 'None' and 'cmax'.
    :return: (normalized data, fitted scaler) tuple.
    :raises ValueError: for an unknown *normalizer* name.
    '''
    if normalizer == 'max01':
        if column_wise:
            minimum = data.min(axis=0, keepdims=True)
            maximum = data.max(axis=0, keepdims=True)
        else:
            minimum = data.min()
            maximum = data.max()
        scaler = MinMax01Scaler(minimum, maximum)
        data = scaler.transform(data)
        print('Normalize the dataset by MinMax01 Normalization')
    elif normalizer == 'max11':
        if column_wise:
            minimum = data.min(axis=0, keepdims=True)
            maximum = data.max(axis=0, keepdims=True)
        else:
            minimum = data.min()
            maximum = data.max()
        scaler = MinMax11Scaler(minimum, maximum)
        data = scaler.transform(data)
        print('Normalize the dataset by MinMax11 Normalization')
    elif normalizer == 'std':
        if column_wise:
            mean = data.mean(axis=0, keepdims=True)
            std = data.std(axis=0, keepdims=True)
        else:
            mean = data.mean()
            std = data.std()
        scaler = StandardScaler(mean, std)
        data = scaler.transform(data)
        print('Normalize the dataset by Standard Normalization')
    elif normalizer == 'None':
        scaler = NScaler()
        data = scaler.transform(data)
        print('Does not normalize the dataset')
    elif normalizer == 'cmax':
        #column min max, to be depressed
        #note: axis must be the spatial dimension, please check !
        scaler = ColumnMinMaxScaler(data.min(axis=0), data.max(axis=0))
        data = scaler.transform(data)
        print('Normalize the dataset by Column Min-Max Normalization')
    else:
        # Was a bare ``raise ValueError``; name the offending value.
        raise ValueError('Unsupported normalizer: {}'.format(normalizer))
    return data, scaler
52 |
def split_data_by_days(data, val_days, test_days, interval=60):
    '''
    Split a time series into train/val/test by whole days.

    :param data: array shaped [B, *] with time on axis 0.
    :param val_days: number of validation days (taken just before test).
    :param test_days: number of test days (taken from the end).
    :param interval: sample interval in minutes (e.g. 15, 30, 60).
    :return: (train, val, test) slices in chronological order.
    '''
    steps_per_day = int((24 * 60) / interval)
    n_test = steps_per_day * test_days
    n_val = steps_per_day * val_days
    test_data = data[-n_test:]
    val_data = data[-(n_test + n_val):-n_test]
    train_data = data[:-(n_test + n_val)]
    return train_data, val_data, test_data
66 |
def split_data_by_ratio(data, val_ratio, test_ratio):
    '''
    Split a time series into train/val/test by fractional size.

    :param data: array shaped [B, *] with time on axis 0.
    :param val_ratio: fraction of samples for validation.
    :param test_ratio: fraction of samples for test (taken from the end).
    :return: (train, val, test) slices in chronological order.
    '''
    total = data.shape[0]
    n_test = int(total * test_ratio)
    n_val_test = int(total * (test_ratio + val_ratio))
    return data[:-n_val_test], data[-n_val_test:-n_test], data[-n_test:]
73 |
def data_loader(X, Y, batch_size, shuffle=True, drop_last=True):
    '''
    Wrap (X, Y) arrays into a batched torch DataLoader.

    :param X: inputs, numpy array or tensor-like, samples on axis 0.
    :param Y: targets, aligned with X on axis 0.
    :param batch_size: samples per batch.
    :param shuffle: shuffle samples each epoch.
    :param drop_last: drop the final incomplete batch.
    :return: DataLoader yielding (x, y) float32 batches.
    '''
    # The legacy ``torch.cuda.FloatTensor(...)`` type constructors are
    # deprecated; build float32 tensors explicitly and move them once.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    X = torch.as_tensor(X, dtype=torch.float32).to(device)
    Y = torch.as_tensor(Y, dtype=torch.float32).to(device)
    data = torch.utils.data.TensorDataset(X, Y)
    dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                             shuffle=shuffle, drop_last=drop_last)
    return dataloader
82 |
83 |
def get_dataloader(args, normalizer = 'std', tod=False, dow=False, weather=False, single=True):
    '''
    Build train/val/test dataloaders for the configured dataset.

    Loads the raw spatio-temporal series, normalizes it, splits it
    chronologically (by days when ``args.test_ratio > 1``, otherwise by
    ratio), windows it into (lag, horizon) samples, and wraps each split
    in a DataLoader.  ``tod``/``dow``/``weather`` are accepted for API
    compatibility but unused here.

    :return: (train_loader, val_loader_or_None, test_loader, scaler).
    '''
    raw = load_st_dataset(args.dataset)  # shaped B, N, D
    raw, scaler = normalize_dataset(raw, normalizer, args.column_wise)
    # test_ratio > 1 is interpreted as a day count rather than a fraction
    if args.test_ratio > 1:
        part_train, part_val, part_test = split_data_by_days(raw, args.val_ratio, args.test_ratio)
    else:
        part_train, part_val, part_test = split_data_by_ratio(raw, args.val_ratio, args.test_ratio)
    x_tra, y_tra = Add_Window_Horizon(part_train, args.lag, args.horizon, single)
    x_val, y_val = Add_Window_Horizon(part_val, args.lag, args.horizon, single)
    x_test, y_test = Add_Window_Horizon(part_test, args.lag, args.horizon, single)
    print('Train: ', x_tra.shape, y_tra.shape)
    print('Val: ', x_val.shape, y_val.shape)
    print('Test: ', x_test.shape, y_test.shape)
    ##############get dataloader######################
    train_dataloader = data_loader(x_tra, y_tra, args.batch_size, shuffle=True, drop_last=True)
    val_dataloader = None if len(x_val) == 0 else data_loader(
        x_val, y_val, args.batch_size, shuffle=False, drop_last=True)
    test_dataloader = data_loader(x_test, y_test, args.batch_size, shuffle=False, drop_last=False)
    return train_dataloader, val_dataloader, test_dataloader, scaler
109 |
110 |
if __name__ == '__main__':
    import argparse
    # Manual smoke test for get_dataloader on one of the non-PEMS datasets.
    #MetrLA 207; BikeNYC 128; SIGIR_solar 137; SIGIR_electric 321
    DATASET = 'SIGIR_electric'
    if DATASET == 'MetrLA':
        NODE_NUM = 207
    elif DATASET == 'BikeNYC':
        NODE_NUM = 128
    elif DATASET == 'SIGIR_solar':
        NODE_NUM = 137
    elif DATASET == 'SIGIR_electric':
        NODE_NUM = 321
    # NOTE(review): these dataset names are not handled by load_st_dataset
    # in this snapshot (it only knows PEMS03/04/07/08) — confirm before use.
    parser = argparse.ArgumentParser(description='PyTorch dataloader')
    parser.add_argument('--dataset', default=DATASET, type=str)
    parser.add_argument('--num_nodes', default=NODE_NUM, type=int)
    parser.add_argument('--val_ratio', default=0.1, type=float)
    parser.add_argument('--test_ratio', default=0.2, type=float)
    parser.add_argument('--lag', default=12, type=int)
    parser.add_argument('--horizon', default=12, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    args = parser.parse_args()
    train_dataloader, val_dataloader, test_dataloader, scaler = get_dataloader(args, normalizer = 'std', tod=False, dow=False, weather=False, single=True)
--------------------------------------------------------------------------------
/lib/load_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
def load_st_dataset(dataset):
    '''
    Load a raw traffic dataset as a [B, N, D] numpy array.

    :param dataset: one of 'PEMS03', 'PEMS04', 'PEMS07', 'PEMS08'.
    :return: array shaped [time, nodes, 1] (traffic-flow channel only).
    :raises ValueError: for an unknown dataset name.
    '''
    # Four copy-pasted branches collapsed into a lookup table.
    paths = {
        'PEMS03': './data/PEMS03/pems03.npz',
        'PEMS04': './data/PEMS04/pems04.npz',
        'PEMS07': './data/PEMS07/pems07.npz',
        'PEMS08': './data/PEMS08/pems08.npz',
    }
    if dataset not in paths:
        raise ValueError('Unknown dataset: {}'.format(dataset))
    # only the first feature channel (traffic flow) is used
    data = np.load(os.path.join(paths[dataset]))['data'][:, :, 0]
    if len(data.shape) == 2:
        data = np.expand_dims(data, axis=-1)  # ensure a trailing feature dim
    print('Load %s Dataset shaped: ' % dataset, data.shape, data.max(), data.min(), data.mean(), np.median(data))
    return data
--------------------------------------------------------------------------------
/lib/metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Always evaluate the model with MAE, RMSE, MAPE, RRSE, PNBI, and oPNBI.
3 | Why add mask to MAE and RMSE?
4 | Filter the 0 that may be caused by error (such as loop sensor)
5 | Why add mask to MAPE and MARE?
6 | Ignore very small values (e.g., 0.5/0.5=100%)
7 | '''
8 | import numpy as np
9 | import torch
10 |
def MAE_torch(pred, true, mask_value=None):
    """Mean absolute error; entries with true <= mask_value are ignored."""
    # ``is not None`` — ``!= None`` is non-idiomatic and unsafe for
    # operands that overload ``__eq__``.
    if mask_value is not None:
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    return torch.mean(torch.abs(true - pred))
17 |
def MSE_torch(pred, true, mask_value=None):
    """Mean squared error; entries with true <= mask_value are ignored."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    return torch.mean((pred - true) ** 2)
24 |
def RMSE_torch(pred, true, mask_value=None):
    """Root mean squared error; entries with true <= mask_value are ignored."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    return torch.sqrt(torch.mean((pred - true) ** 2))
31 |
def RRSE_torch(pred, true, mask_value=None):
    """Root relative squared error: ||pred - true|| / ||true - mean(true)||."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    # Bug fix: the denominator must measure the spread of the ground truth
    # (true - mean(true)); the original used ``pred - true.mean()``, which
    # disagrees with RRSE_np and the metric's definition.
    return torch.sqrt(torch.sum((pred - true) ** 2)) / torch.sqrt(torch.sum((true - true.mean()) ** 2))
38 |
def CORR_torch(pred, true, mask_value=None):
    """
    Mean Pearson correlation between pred and true, averaged over the
    last-axis entries whose ground-truth std is non-zero.

    Inputs may be shaped [B, T, N, D], [B, N, D] or [B, N]; lower-rank
    inputs are lifted to 4D first.  Note ``mask_value`` is accepted for
    signature symmetry with the other metrics but is not applied here.
    """
    #input B, T, N, D or B, N, D or B, N
    if len(pred.shape) == 2:
        pred = pred.unsqueeze(dim=1).unsqueeze(dim=1)
        true = true.unsqueeze(dim=1).unsqueeze(dim=1)
    elif len(pred.shape) == 3:
        pred = pred.transpose(1, 2).unsqueeze(dim=1)
        true = true.transpose(1, 2).unsqueeze(dim=1)
    elif len(pred.shape) == 4:
        #B, T, N, D -> B, T, D, N
        pred = pred.transpose(2, 3)
        true = true.transpose(2, 3)
    else:
        raise ValueError
    # statistics over every axis except the last
    dims = (0, 1, 2)
    pred_mean = pred.mean(dim=dims)
    true_mean = true.mean(dim=dims)
    pred_std = pred.std(dim=dims)
    true_std = true.std(dim=dims)
    correlation = ((pred - pred_mean)*(true - true_mean)).mean(dim=dims) / (pred_std*true_std)
    # drop constant ground-truth series (zero std would divide by zero)
    index = (true_std != 0)
    correlation = (correlation[index]).mean()
    return correlation
62 |
63 |
def MAPE_torch(pred, true, mask_value=None):
    """Mean absolute percentage error; mask out true <= mask_value
    (protects against tiny denominators)."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    return torch.mean(torch.abs(torch.div((true - pred), true)))
70 |
def PNBI_torch(pred, true, mask_value=None):
    """Positive-negative bias indicator: fraction of over-predictions
    (pred > true)."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    indicator = torch.gt(pred - true, 0).float()
    return indicator.mean()
78 |
def oPNBI_torch(pred, true, mask_value=None):
    """Mean of (true + pred) / (2 * true): >1 means over-prediction,
    <1 under-prediction.  Sensitive to small true values."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    bias = (true + pred) / (2 * true)
    return bias.mean()
86 |
def MARE_torch(pred, true, mask_value=None):
    """Mean absolute relative error: sum|true - pred| / sum(true)."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    return torch.div(torch.sum(torch.abs((true - pred))), torch.sum(true))
93 |
def SMAPE_torch(pred, true, mask_value=None):
    """Symmetric MAPE: mean of |true - pred| / (|true| + |pred|)."""
    if mask_value is not None:  # was ``!= None``
        mask = torch.gt(true, mask_value)
        pred = torch.masked_select(pred, mask)
        true = torch.masked_select(true, mask)
    return torch.mean(torch.abs(true - pred) / (torch.abs(true) + torch.abs(pred)))
100 |
101 |
def MAE_np(pred, true, mask_value=None):
    """Mean absolute error (numpy); entries with true <= mask_value are ignored."""
    if mask_value is not None:  # was ``!= None``
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    MAE = np.mean(np.absolute(pred - true))
    return MAE
109 |
def RMSE_np(pred, true, mask_value=None):
    """Root mean squared error (numpy); entries with true <= mask_value are ignored."""
    if mask_value is not None:  # was ``!= None``
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    RMSE = np.sqrt(np.mean(np.square(pred - true)))
    return RMSE
117 |
#Root Relative Squared Error
def RRSE_np(pred, true, mask_value=None):
    """RRSE (numpy): ||pred - true|| / ||true - mean(true)||."""
    if mask_value is not None:  # was ``!= None``
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    mean = true.mean()
    return np.divide(np.sqrt(np.sum((pred - true) ** 2)), np.sqrt(np.sum((true - mean) ** 2)))
126 |
def MAPE_np(pred, true, mask_value=None):
    """MAPE (numpy); mask out true <= mask_value to avoid tiny denominators."""
    if mask_value is not None:  # was ``!= None``
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    return np.mean(np.absolute(np.divide((true - pred), true)))
133 |
def PNBI_np(pred, true, mask_value=None):
    """Fraction of over-predictions (numpy).

    0 means every pred is smaller than true; 1 means every pred is bigger.
    """
    if mask_value is not None:  # was ``!= None``
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    bias = pred - true
    indicator = np.where(bias > 0, True, False)
    return indicator.mean()
144 |
def oPNBI_np(pred, true, mask_value=None):
    """Mean of (true + pred) / (2 * true) (numpy).

    >1 means over-prediction, <1 under-prediction; very sensitive to
    small true values, so use with care.
    """
    if mask_value is not None:  # was ``!= None``
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    bias = (true + pred) / (2 * true)
    return bias.mean()
155 |
def MARE_np(pred, true, mask_value=None):
    """Mean absolute relative error (numpy): sum|true - pred| / sum(true)."""
    if mask_value is not None:  # was ``!= None``
        mask = np.where(true > (mask_value), True, False)
        true = true[mask]
        pred = pred[mask]
    return np.divide(np.sum(np.absolute((true - pred))), np.sum(true))
162 |
def CORR_np(pred, true, mask_value=None):
    """
    Mean Pearson correlation (numpy twin of CORR_torch), averaged over the
    last-axis entries whose ground-truth std is non-zero.

    Inputs may be shaped [B, T, N, D], [B, N, D] or [B, N]; lower-rank
    inputs are lifted to 4D first.  ``mask_value`` is accepted for
    signature symmetry but not applied.
    """
    #input B, T, N, D or B, N, D or B, N
    if len(pred.shape) == 2:
        # Bug fix: numpy arrays have no ``unsqueeze``; the original 2D
        # branch raised AttributeError.  B, N -> B, 1, 1, N.
        pred = np.expand_dims(np.expand_dims(pred, axis=1), axis=1)
        true = np.expand_dims(np.expand_dims(true, axis=1), axis=1)
    elif len(pred.shape) == 3:
        #np.transpose includes permute: B, T, N -> B, 1, N, T
        pred = np.expand_dims(pred.transpose(0, 2, 1), axis=1)
        true = np.expand_dims(true.transpose(0, 2, 1), axis=1)
    elif len(pred.shape) == 4:
        # Bug fix: B, T, N, D -> B, T, D, N requires swapping the last two
        # axes; the original transpose(0, 1, 2, 3) was the identity.
        pred = pred.transpose(0, 1, 3, 2)
        true = true.transpose(0, 1, 3, 2)
    else:
        raise ValueError
    dims = (0, 1, 2)
    pred_mean = pred.mean(axis=dims)
    true_mean = true.mean(axis=dims)
    pred_std = pred.std(axis=dims)
    true_std = true.std(axis=dims)
    correlation = ((pred - pred_mean) * (true - true_mean)).mean(axis=dims) / (pred_std * true_std)
    # drop constant ground-truth series (zero std would divide by zero)
    index = (true_std != 0)
    correlation = (correlation[index]).mean()
    return correlation
188 |
def All_Metrics(pred, true, mask1, mask2):
    """
    Compute (mae, rmse, mape, rrse, corr) for numpy or torch inputs.

    :param mask1: threshold for MAE/RMSE/RRSE masking (filters invalid 0s).
    :param mask2: threshold for MAPE masking (filters tiny denominators).
    :return: (mae, rmse, mape, rrse, corr); corr is 0 for numpy inputs
        (the CORR_np call is intentionally disabled below).
    :raises TypeError: when inputs are neither numpy arrays nor tensors.
    """
    assert type(pred) == type(true)
    # ``isinstance`` instead of ``type(x) == T`` so subclasses dispatch too.
    if isinstance(pred, np.ndarray):
        mae = MAE_np(pred, true, mask1)
        rmse = RMSE_np(pred, true, mask1)
        mape = MAPE_np(pred, true, mask2)
        rrse = RRSE_np(pred, true, mask1)
        corr = 0
        #corr = CORR_np(pred, true, mask1)
        #pnbi = PNBI_np(pred, true, mask1)
        #opnbi = oPNBI_np(pred, true, mask2)
    elif isinstance(pred, torch.Tensor):
        mae = MAE_torch(pred, true, mask1)
        rmse = RMSE_torch(pred, true, mask1)
        mape = MAPE_torch(pred, true, mask2)
        rrse = RRSE_torch(pred, true, mask1)
        corr = CORR_torch(pred, true, mask1)
        #pnbi = PNBI_torch(pred, true, mask1)
        #opnbi = oPNBI_torch(pred, true, mask2)
    else:
        raise TypeError
    return mae, rmse, mape, rrse, corr
212 |
def SIGIR_Metrics(pred, true, mask1, mask2):
    """RRSE (masked by mask1) and unmasked CORR, as used in the SIGIR setup;
    ``mask2`` is accepted for signature symmetry but unused."""
    return RRSE_torch(pred, true, mask1), CORR_torch(pred, true, 0)
217 |
if __name__ == '__main__':
    # Tiny smoke test: unmasked metrics on a 1-D tensor pair.
    pred = torch.Tensor([1, 2, 3,4])
    true = torch.Tensor([2, 1, 4,5])
    print(All_Metrics(pred, true, None, None))
--------------------------------------------------------------------------------
/lib/normalization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
class NScaler(object):
    """No-op scaler: transform and inverse_transform return data unchanged."""

    def transform(self, data):
        """Identity."""
        return data

    def inverse_transform(self, data):
        """Identity."""
        return data
10 |
class StandardScaler:
    """
    Z-score scaler: transform maps data to (data - mean) / std.

    ``mean``/``std`` may be scalars or numpy arrays; on the first tensor
    passed to inverse_transform they are converted (in place) to tensors
    on that data's device.
    """

    def __init__(self, mean, std):
        self.mean = mean  # fitted mean (scalar or per-column array)
        self.std = std    # fitted std (scalar or per-column array)

    def transform(self, data):
        """Standardize *data*."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Undo transform; lazily converts numpy stats to torch tensors."""
        # ``isinstance`` instead of ``type(x) == T`` so tensor/array
        # subclasses are handled too.
        if isinstance(data, torch.Tensor) and isinstance(self.mean, np.ndarray):
            self.std = torch.from_numpy(self.std).to(data.device).type(data.dtype)
            self.mean = torch.from_numpy(self.mean).to(data.device).type(data.dtype)
        return (data * self.std) + self.mean
28 |
class MinMax01Scaler:
    """
    Min-max scaler into [0, 1]: transform maps data to
    (data - min) / (max - min).

    ``min``/``max`` may be scalars or numpy arrays; on the first tensor
    passed to inverse_transform they are converted (in place) to tensors
    on that data's device.
    """

    def __init__(self, min, max):
        self.min = min  # fitted minimum (scalar or per-column array)
        self.max = max  # fitted maximum (scalar or per-column array)

    def transform(self, data):
        """Scale *data* into [0, 1]."""
        return (data - self.min) / (self.max - self.min)

    def inverse_transform(self, data):
        """Undo transform; lazily converts numpy stats to torch tensors."""
        # ``isinstance`` instead of ``type(x) == T`` (subclass-safe).
        if isinstance(data, torch.Tensor) and isinstance(self.min, np.ndarray):
            self.min = torch.from_numpy(self.min).to(data.device).type(data.dtype)
            self.max = torch.from_numpy(self.max).to(data.device).type(data.dtype)
        return (data * (self.max - self.min) + self.min)
46 |
class MinMax11Scaler:
    """
    Min-max scaler into [-1, 1]: transform maps data to
    2 * (data - min) / (max - min) - 1.

    ``min``/``max`` may be scalars or numpy arrays; on the first tensor
    passed to inverse_transform they are converted (in place) to tensors
    on that data's device.
    """

    def __init__(self, min, max):
        self.min = min  # fitted minimum (scalar or per-column array)
        self.max = max  # fitted maximum (scalar or per-column array)

    def transform(self, data):
        """Scale *data* into [-1, 1]."""
        return ((data - self.min) / (self.max - self.min)) * 2. - 1.

    def inverse_transform(self, data):
        """Undo transform; lazily converts numpy stats to torch tensors."""
        # ``isinstance`` instead of ``type(x) == T`` (subclass-safe).
        if isinstance(data, torch.Tensor) and isinstance(self.min, np.ndarray):
            self.min = torch.from_numpy(self.min).to(data.device).type(data.dtype)
            self.max = torch.from_numpy(self.max).to(data.device).type(data.dtype)
        return ((data + 1.) / 2.) * (self.max - self.min) + self.min
64 |
class ColumnMinMaxScaler():
    """
    Per-column min-max scaler.

    Must be initialized with column-wise minima and maxima (numpy arrays).
    Constant columns are mapped with a divisor of 1 to avoid division by
    zero.  On the first tensor passed to inverse_transform the stats are
    converted (in place) to float32 tensors on that data's device.
    """

    def __init__(self, min, max):
        self.min = min
        self.min_max = max - self.min
        self.min_max[self.min_max == 0] = 1  # constant columns: avoid /0

    def transform(self, data):
        """Scale each column into [0, 1]."""
        # (removed a leftover debug print of the operand shapes)
        return (data - self.min) / self.min_max

    def inverse_transform(self, data):
        """Undo transform; lazily converts numpy stats to torch tensors."""
        if isinstance(data, torch.Tensor) and isinstance(self.min, np.ndarray):
            self.min_max = torch.from_numpy(self.min_max).to(data.device).type(torch.float32)
            self.min = torch.from_numpy(self.min).to(data.device).type(torch.float32)
        return (data * self.min_max + self.min)
81 |
def one_hot_by_column(data):
    """
    One-hot encode each column of a 2D integer array and hstack the results.

    :param data: 2D numpy array of integers.
    :return: 2D array; each input column contributes (max - min + 1) columns.
    """
    n_rows = data.shape[0]  # original shadowed the ``len`` builtin
    pieces = []
    for col_idx in range(data.shape[1]):
        column = data[:, col_idx]
        lo = column.min()  # original shadowed ``min``/``max`` builtins
        hi = column.max()
        block = np.zeros((n_rows, hi - lo + 1))
        block[np.arange(n_rows), column - lo] = 1
        pieces.append(block)
    return np.hstack(pieces)
97 |
98 |
def minmax_by_column(data):
    """
    Min-max normalize each column of a 2D array independently into [0, 1].

    :param data: 2D numpy array.
    :return: array of the same shape with each column scaled by its range.
    """
    cols = []
    for col_idx in range(data.shape[1]):
        column = data[:, col_idx]
        lo = column.min()  # original shadowed ``min``/``max`` builtins
        hi = column.max()
        cols.append(((column - lo) / (hi - lo))[:, np.newaxis])
    return np.hstack(cols)
112 |
113 |
if __name__ == '__main__':

    # Scratch experiments with min/max broadcasting semantics.
    # NOTE(review): ``test_data - minimum`` subtracts a shape-(3,) vector
    # from a (3, 4) array, whose trailing dimensions (3 vs 4) do not
    # broadcast — this line raises a ValueError as written.  For per-row
    # scaling use ``minimum[:, np.newaxis]`` (or ``keepdims=True``).
    test_data = np.array([[0,0,0, 1], [0, 1, 3, 2], [0, 2, 1, 3]])
    print(test_data)
    minimum = test_data.min(axis=1)
    print(minimum, minimum.shape, test_data.shape)
    maximum = test_data.max(axis=1)
    print(maximum)
    print(test_data-minimum)
    test_data = (test_data-minimum) / (maximum-minimum)
    print(test_data)
    print(0 == 0)
    print(0.00 == 0)
    print(0 == 0.00)
    #print(one_hot_by_column(test_data))
    #print(minmax_by_column(test_data))
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ###---Enable Dropout---###
def enable_dropout(model):
    """Switch every Dropout-family layer of *model* back to train mode,
    e.g. for Monte-Carlo dropout at test time."""
    for module in model.modules():
        if type(module).__name__.startswith('Dropout'):
            module.train()
9 |
###---Save model---###
def save_model_(model,model_name,dataset,pre_len):
    """Save *model*'s state dict to check_points/<name>_<dataset>_<pre_len>.pth.
    Assumes the check_points/ directory already exists."""
    target = f"check_points/{model_name}_{dataset}_{str(pre_len)}.pth"
    torch.save({'model_state_dict': model.state_dict()}, target)
    print("Model saved!")
15 |
###---Load model---###
def load_model_(model,model_name,dataset,pre_len):
    """Load a state dict saved by save_model_ into *model* and return it."""
    PATH1 = f"check_points/{model_name}_{dataset}_{str(pre_len)}.pth"
    checkpoint = torch.load(PATH1)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded!")
    return model
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import os
5 | import sys
6 |
7 | import torch
8 | import numpy as np
9 | import torch.nn as nn
10 | import argparse
11 | import configparser
12 | from datetime import datetime
13 | from model.AGCRN_UQ import AGCRN_UQ as Network
14 | from model.BasicTrainer import Trainer
15 | from lib.TrainInits import init_seed
16 | from lib.dataloader import get_dataloader
17 | from lib.TrainInits import print_model_parameters
18 | from tqdm import tqdm
19 | from copy import deepcopy
20 | import torch.nn.functional as F
21 | import torchcontrib
22 | from model.train_methods import awa_train_combined, train_cali_mcx
23 | from model.test_methods import combined_test
24 |
25 |
#*************************************************************************#
# Experiment switches — edit here rather than via the CLI, since the
# parser below is invoked with parse_args([]).
Mode = 'train'
DEBUG = 'True'
DATASET = 'PEMS08' #PEMS04/8/3/7
DEVICE = 'cuda:0'
MODEL = 'AGCRN'
MODEL_NAME = "combined"#"combined/basic/dropout/heter"

#get configuration
config_file = 'model/{}_{}.conf'.format(DATASET, MODEL)
#print('Read configuration file: %s' % (config_file))
config = configparser.ConfigParser()
config.read(config_file)
39 |
40 | from lib.utils import enable_dropout,save_model_,load_model_
41 | from lib.metrics import All_Metrics
42 |
43 |
44 |
45 | from lib.metrics import MAE_torch
def masked_mae_loss(scaler, mask_value):
    """
    Build a loss closure computing masked MAE in the original data scale.

    :param scaler: scaler whose inverse_transform maps predictions back to
        real values; falsy to skip rescaling.
    :param mask_value: threshold below which targets are ignored.
    :return: loss(preds, labels) -> scalar MAE tensor.
    """
    def loss(preds, labels):
        if scaler:
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        return MAE_torch(pred=preds, true=labels, mask_value=mask_value)
    return loss
54 |
#parser
# NOTE(review): parse_args([]) at the bottom ignores the real command line
# (notebook-style usage), so the defaults below are what actually runs.
args = argparse.ArgumentParser(description='arguments')
args.add_argument('--dataset', default=DATASET, type=str)
args.add_argument('--mode', default=Mode, type=str)
args.add_argument('--device', default=DEVICE, type=str, help='indices of GPUs')
args.add_argument('--debug', default=DEBUG, type=eval)
args.add_argument('--model', default=MODEL, type=str)
args.add_argument('--cuda', default=True, type=bool)
#data
args.add_argument('--val_ratio', default=config['data']['val_ratio'], type=float)
args.add_argument('--test_ratio', default=config['data']['test_ratio'], type=float)
args.add_argument('--lag', default=config['data']['lag'], type=int)
args.add_argument('--horizon', default=config['data']['horizon'], type=int)
args.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int)
args.add_argument('--tod', default=config['data']['tod'], type=eval)
args.add_argument('--normalizer', default=config['data']['normalizer'], type=str)
args.add_argument('--column_wise', default=config['data']['column_wise'], type=eval)
args.add_argument('--default_graph', default=config['data']['default_graph'], type=eval)
#model
args.add_argument('--input_dim', default=config['model']['input_dim'], type=int)
args.add_argument('--output_dim', default=config['model']['output_dim'], type=int)
args.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int)
args.add_argument('--rnn_units', default=config['model']['rnn_units'], type=int)
args.add_argument('--num_layers', default=config['model']['num_layers'], type=int)
args.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int)
args.add_argument('--p1', default=config['model']['p1'], type=float)

#train
args.add_argument('--loss_func', default=config['train']['loss_func'], type=str)
#args.add_argument('--loss_func', default='mse', type=str)
args.add_argument('--seed', default=config['train']['seed'], type=int)
args.add_argument('--batch_size', default=config['train']['batch_size'], type=int)
args.add_argument('--epochs', default=config['train']['epochs'], type=int)
#args.add_argument('--epochs', default=500, type=int)
args.add_argument('--lr_init', default=config['train']['lr_init'], type=float)
#args.add_argument('--lr_init', default=1e-2, type=float)
args.add_argument('--lr_decay', default=config['train']['lr_decay'], type=eval)
args.add_argument('--lr_decay_rate', default=config['train']['lr_decay_rate'], type=float)
args.add_argument('--lr_decay_step', default=config['train']['lr_decay_step'], type=str)
args.add_argument('--early_stop', default=config['train']['early_stop'], type=eval)
args.add_argument('--early_stop_patience', default=config['train']['early_stop_patience'], type=int)
args.add_argument('--grad_norm', default=config['train']['grad_norm'], type=eval)
args.add_argument('--max_grad_norm', default=config['train']['max_grad_norm'], type=int)
args.add_argument('--teacher_forcing', default=False, type=bool)
args.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps')
args.add_argument('--real_value', default=config['train']['real_value'], type=eval, help = 'use real value for loss calculation')
#test
args.add_argument('--mae_thresh', default=config['test']['mae_thresh'], type=eval)
args.add_argument('--mape_thresh', default=config['test']['mape_thresh'], type=float)
#log
args.add_argument('--log_dir', default='./', type=str)
args.add_argument('--log_step', default=config['log']['log_step'], type=int)
args.add_argument('--plot', default=config['log']['plot'], type=eval)
args.add_argument('--model_name', default=MODEL_NAME, type=str)



args = args.parse_args([])
init_seed(args.seed)
# Activate the configured GPU (single-digit index parsed from 'cuda:N');
# fall back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    torch.cuda.set_device(int(args.device[5]))
else:
    args.device = 'cpu'

#init model
model = Network(args).to(args.device)
# Xavier init for weight matrices, uniform init for 1-D params (biases,
# node embeddings).
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
    else:
        nn.init.uniform_(p)
print_model_parameters(model, only_num=False)
127 |
#load dataset
train_loader, val_loader, test_loader, scaler = get_dataloader(args,
                                                               normalizer=args.normalizer,
                                                               tod=args.tod, dow=False,
                                                               weather=False, single=False)
#init loss function, optimizer
# 'mask_mae' evaluates MAE in real (inverse-transformed) units, ignoring
# zero targets; 'mae'/'mse' operate on whatever scale the model emits.
if args.loss_func == 'mask_mae':
    loss = masked_mae_loss(scaler, mask_value=0.0)
elif args.loss_func == 'mae':
    loss = torch.nn.L1Loss().to(args.device)
elif args.loss_func == 'mse':
    loss = torch.nn.MSELoss().to(args.device)
else:
    raise ValueError

optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8,
                             weight_decay=1e-6, amsgrad=False)

#learning rate decay
lr_scheduler = None
if args.lr_decay:
    print('Applying learning rate decay.')
    # e.g. the config string "5,20,40" becomes milestones [5, 20, 40]
    lr_decay_steps = [int(i) for i in list(args.lr_decay_step.split(','))]
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                        milestones=lr_decay_steps,
                                                        gamma=args.lr_decay_rate)
154 | #start training
155 | trainer = Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
156 | args, lr_scheduler=lr_scheduler)
157 |
158 |
159 | ### Pre-traning
160 | trainer.train()
161 |
162 |
163 | ### AWA Re-training
164 |
165 | #trainer.model = awa_train_combined(trainer,epochs=20)
166 |
167 |
168 | ### Model Calibration
169 |
170 | T = train_cali_mc(trainer.model,10, args, val_loader, scaler)
171 | combined_test(trainer.model,10,trainer.args, trainer.test_loader, scaler,T)
172 |
173 |
174 |
175 |
176 |
177 |
--------------------------------------------------------------------------------
/model/AGCN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
class AVWGCN(nn.Module):
    """Adaptive vertex-wise graph convolution.

    The graph and the convolution parameters are both generated from a set of
    learned node embeddings, so no predefined adjacency matrix is needed.
    """

    def __init__(self, model_name, p1, dim_in, dim_out, cheb_k, embed_dim):
        super(AVWGCN, self).__init__()
        self.cheb_k = cheb_k
        # per-embedding pools; the effective weights/bias are produced from
        # the node embeddings at forward time
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        self.model_name = model_name
        self.p1 = p1

    def forward(self, x, node_embeddings):
        """x: [B, N, C]; node_embeddings: [N, D] -> output [B, N, dim_out]."""
        node_num = node_embeddings.shape[0]
        # learned adjacency derived from the node embeddings
        adj = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)
        # Chebyshev polynomial basis of the adjacency (default cheb_k = 3)
        cheb_polys = [torch.eye(node_num).to(adj.device), adj]
        for _ in range(2, self.cheb_k):
            cheb_polys.append(torch.matmul(2 * adj, cheb_polys[-1]) - cheb_polys[-2])
        supports = torch.stack(cheb_polys, dim=0)

        weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  # N, cheb_k, dim_in, dim_out
        bias = torch.matmul(node_embeddings, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, x).permute(0, 2, 1, 3)  # B, N, cheb_k, dim_in
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out

        # MC dropout: F.dropout defaults to training=True, so dropout remains
        # active at inference for the "dropout"/"combined" variants
        if self.model_name in ("dropout", "combined"):
            x_gconv = F.dropout(x_gconv, self.p1)  # pems04=0.1 / pems08=0.05
        return x_gconv
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/model/AGCRNCell.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model.AGCN import AVWGCN
4 |
class AGCRNCell(nn.Module):
    """A GRU-style recurrent cell whose gates are adaptive graph convolutions."""

    def __init__(self, model_name, p1, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super(AGCRNCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # one gconv produces both gates at once, another the candidate state
        self.gate = AVWGCN(model_name, p1, dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
        self.update = AVWGCN(model_name, p1, dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings):
        """x: (B, num_nodes, input_dim); state: (B, num_nodes, hidden_dim)."""
        state = state.to(x.device)
        combined = torch.cat((x, state), dim=-1)
        gates = torch.sigmoid(self.gate(combined, node_embeddings))
        z, r = torch.split(gates, self.hidden_dim, dim=-1)
        candidate_in = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate_in, node_embeddings))
        # convex combination of the previous state and the candidate
        return r * state + (1 - r) * hc

    def init_hidden_state(self, batch_size):
        # fresh all-zero hidden state
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)
--------------------------------------------------------------------------------
/model/AGCRN_UQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.distributions as td
5 | from model.AGCRNCell import AGCRNCell
6 |
class AVWDCRNN(nn.Module):
    """Stack of AGCRN cells unrolled over time (the recurrent encoder)."""

    def __init__(self, model_name, p1, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super(AVWDCRNN, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.dcrnn_cells = nn.ModuleList()
        # first layer maps dim_in -> dim_out; subsequent layers are dim_out -> dim_out
        self.dcrnn_cells.append(AGCRNCell(model_name, p1, node_num, dim_in, dim_out, cheb_k, embed_dim))
        for _ in range(1, num_layers):
            self.dcrnn_cells.append(AGCRNCell(model_name, p1, node_num, dim_out, dim_out, cheb_k, embed_dim))

    def forward(self, x, init_state, node_embeddings):
        """x: (B, T, N, D); init_state: (num_layers, B, N, hidden_dim).

        Returns (last layer's outputs (B, T, N, hidden_dim),
                 list of each layer's final state (B, N, hidden_dim)).
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        seq_length = x.shape[1]
        layer_input = x
        last_states = []
        for layer_idx in range(self.num_layers):
            state = init_state[layer_idx]
            step_outputs = []
            for t in range(seq_length):
                state = self.dcrnn_cells[layer_idx](layer_input[:, t, :, :], state, node_embeddings)
                step_outputs.append(state)
            last_states.append(state)
            # this layer's per-step outputs feed the next layer
            layer_input = torch.stack(step_outputs, dim=1)
        return layer_input, last_states

    def init_hidden(self, batch_size):
        states = [cell.init_hidden_state(batch_size) for cell in self.dcrnn_cells]
        return torch.stack(states, dim=0)  # (num_layers, B, N, hidden_dim)
44 |
45 |
46 | ###========Main model=========
class AGCRN_UQ(nn.Module):
    """AGCRN with uncertainty-quantification output heads.

    For model_name == "combined" the CNN predictor emits both a mean and a
    log-variance per horizon step. NOTE: for any other model_name the heads
    are not created and forward() implicitly returns None.
    """

    def __init__(self, args):
        super(AGCRN_UQ, self).__init__()
        self.num_node = args.num_nodes
        self.input_dim = args.input_dim
        self.hidden_dim = args.rnn_units
        self.output_dim = args.output_dim
        self.horizon = args.horizon
        self.num_layers = args.num_layers

        self.model_name = args.model_name
        self.p1 = args.p1

        self.default_graph = args.default_graph
        # learned node embeddings; the adaptive graph is derived from these
        self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)

        self.encoder = AVWDCRNN(self.model_name, self.p1, args.num_nodes, args.input_dim,
                                args.rnn_units, args.cheb_k, args.embed_dim, args.num_layers)

        if self.model_name == "combined":
            # mean head and log-variance head share the same CNN topology
            self.get_mu = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(1, 1), bias=True),
                nn.Dropout(0.2),
                nn.ReLU(),
                nn.Conv2d(32, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True),
            )
            self.get_log_var = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=(1, 1), bias=True),
                nn.Dropout(0.2),
                nn.ReLU(),
                nn.Conv2d(32, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True),
            )

    def forward(self, source, targets, teacher_forcing_ratio=0.5):
        """source: (B, T_1, N, D); targets/teacher_forcing_ratio are unused here.

        Returns (mu, log_var), each (B, horizon, N, output_dim), when
        model_name == "combined".
        """
        emb = self.node_embeddings
        init_state = self.encoder.init_hidden(source.shape[0])
        encoded, _ = self.encoder(source, init_state, emb)  # B, T, N, hidden
        encoded = encoded[:, -1:, :, :]  # keep only the last time step

        # CNN-based predictor heads
        if self.model_name == "combined":
            mu = self.get_mu(encoded)  # B, T*C, N, 1
            mu = mu.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node)
            mu = mu.permute(0, 1, 3, 2)  # B, T, N, C

            log_var = self.get_log_var(encoded)  # B, T*C, N, 1
            log_var = log_var.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node)
            log_var = log_var.permute(0, 1, 3, 2)
            return mu, log_var
108 |
109 |
110 |
111 |
112 |
113 |
114 |
--------------------------------------------------------------------------------
/model/BasicTrainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import os
4 | import time
5 | import copy
6 | import numpy as np
7 | #from lib.logger import get_logger
8 | from lib.metrics import All_Metrics
9 | from lib.utils import enable_dropout
10 | import torch.nn.functional as F
11 | from torch.autograd import grad
12 | import torch.nn as nn
13 |
14 | from torch.optim.swa_utils import AveragedModel, SWALR
15 | from torch.optim.lr_scheduler import CosineAnnealingLR
16 | #from model.train_methods import QuantileLoss
17 |
18 |
class Trainer(object):
    """Training / validation / testing loop for the UQ model variants.

    NOTE(review): self.best_path is read in train() but never assigned in this
    class — presumably set on the instance elsewhere; confirm before use.
    NOTE(review): the "quantile" branches call self.quantile_loss, which is
    commented out in __init__ — those branches would raise AttributeError.
    """

    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
        super().__init__()
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scaler = scaler  # normalisation scaler providing inverse_transform
        self.args = args
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)  # number of batches per epoch
        #self.quantile_loss = QuantileLoss()

    def val_epoch(self, epoch, val_dataloader):
        """Run one validation pass; returns the average loss over batches."""
        self.model.eval()
        total_val_loss = 0

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(val_dataloader):
                data = data[..., :self.args.input_dim]
                label = target[..., :self.args.output_dim]
                if self.args.real_value:
                    # compare in the original (de-normalised) scale
                    label = self.scaler.inverse_transform(label)

                if self.args.model_name == "basic" or self.args.model_name=="dropout":
                    output = self.model(data, target, teacher_forcing_ratio=0.)
                    loss = self.loss(output.cuda(), label)

                elif self.args.model_name == "heter" or self.args.model_name=="combined":
                    # model returns (mu, log_var); validate on the mean only
                    output, _ = self.model(data, target, teacher_forcing_ratio=0.)
                    loss = self.loss(output.cuda(), label)

                elif self.args.model_name == "quantile":
                    output = self.model(data, target, teacher_forcing_ratio=0.)
                    loss = self.quantile_loss(output,label)

                #a whole batch of Metr_LA is filtered
                if not torch.isnan(loss):
                    total_val_loss += loss.item()
        val_loss = total_val_loss / len(val_dataloader)
        print('**********Val Epoch {}: average Loss: {:.6f}'.format(epoch, val_loss))
        return val_loss

    def train_epoch(self, epoch):
        """Run one training epoch; returns the average training loss."""
        self.model.train()
        total_loss = 0
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data = data[..., :self.args.input_dim]
            label = target[..., :self.args.output_dim]  # (..., 1)

            #label = torch.log(label)

            self.optimizer.zero_grad()

            #teacher_forcing for RNN encoder-decoder model
            #if teacher_forcing_ratio = 1: use label as input in the decoder for all steps
            if self.args.teacher_forcing:
                global_step = (epoch - 1) * self.train_per_epoch + batch_idx
                # NOTE(review): _compute_sampling_threshold is not defined in
                # this class as shown here — confirm it exists on Trainer.
                teacher_forcing_ratio = self._compute_sampling_threshold(global_step, self.args.tf_decay_steps)
            else:
                teacher_forcing_ratio = 1.
            #data and target shape: B, T, N, F; output shape: B, T, N, F

            if self.args.real_value:
                label = self.scaler.inverse_transform(label)

            if self.args.model_name == "basic" or self.args.model_name=="dropout":
                # shuffle the time dimension (temporal augmentation)
                index = torch.randperm(data.shape[1])
                data = data[:,index,:,:]

                output = self.model(data, target, teacher_forcing_ratio=0.)
                loss = self.loss(output.cuda(), label)

            elif self.args.model_name == "heter" or self.args.model_name=="combined":
                mu, log_var = self.model(data, target, teacher_forcing_ratio=teacher_forcing_ratio)
                # heteroscedastic Gaussian NLL blended 1:9 with the base loss
                loss = torch.mean(torch.exp(-log_var)*(label-mu)**2 + log_var)
                loss = 0.1*loss + 0.9*self.loss(mu, label)

            elif self.args.model_name == "quantile":
                output = self.model(data, target, teacher_forcing_ratio=teacher_forcing_ratio)
                loss = self.quantile_loss(output,label)

            loss.backward()

            # add max grad clipping
            if self.args.grad_norm:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
            self.optimizer.step()
            total_loss += loss.item()

            #log information
            if batch_idx % self.args.log_step == 0:
                print('Train Epoch {}: {}/{} Loss: {:.6f}'.format(
                    epoch, batch_idx, self.train_per_epoch, loss.item()))
        train_epoch_loss = total_loss/self.train_per_epoch
        #self.logger.info('**********Train Epoch {}: averaged Loss: {:.6f}, tf_ratio: {:.6f}'.format(epoch, train_epoch_loss, teacher_forcing_ratio))

        #learning rate decay
        if self.args.lr_decay:
            self.lr_scheduler.step()
        return train_epoch_loss

    def train(self):
        """Full training loop with early stopping; keeps the best state dict."""
        best_model = None
        best_loss = float('inf')
        not_improved_count = 0
        train_loss_list = []
        val_loss_list = []
        start_time = time.time()
        for epoch in range(1, self.args.epochs + 1):
            #epoch_time = time.time()
            train_epoch_loss = self.train_epoch(epoch)
            #print(time.time()-epoch_time)
            #exit()
            # fall back to the test set when there is no validation split
            if self.val_loader == None:
                val_dataloader = self.test_loader
            else:
                val_dataloader = self.val_loader
            val_epoch_loss = self.val_epoch(epoch, val_dataloader)

            #print('LR:', self.optimizer.param_groups[0]['lr'])
            train_loss_list.append(train_epoch_loss)
            val_loss_list.append(val_epoch_loss)
            if train_epoch_loss > 1e6:
                print('Gradient explosion detected. Ending...')
                break
            #if self.val_loader == None:
            #val_epoch_loss = train_epoch_loss
            if val_epoch_loss < best_loss:
                best_loss = val_epoch_loss
                not_improved_count = 0
                best_state = True
            else:
                not_improved_count += 1
                best_state = False
            # early stop
            if self.args.early_stop:
                if not_improved_count == self.args.early_stop_patience:
                    print("Validation performance didn\'t improve for {} epochs. "
                          "Training stops.".format(self.args.early_stop_patience))
                    break
            # save the best state
            if best_state == True:
                print('*********************************Current best model saved!')
                best_model = copy.deepcopy(self.model.state_dict())


        training_time = time.time() - start_time
        print("Total training time: {:.4f}min, best loss: {:.6f}".format((training_time / 60), best_loss))

        #save the best model to file
        # NOTE(review): self.best_path is never assigned in __init__ — this
        # will raise AttributeError unless it is set elsewhere; confirm.
        if not self.args.debug:
            torch.save(best_model, self.best_path)
            print("Saving current best model to " + self.best_path)

        #test
        self.model.load_state_dict(best_model)
        #self.val_epoch(self.args.epochs, self.test_loader)
        #self.test(self.model, self.args, self.test_loader, self.scaler, self.logger=None)

    @staticmethod
    def test(model, args, data_loader, scaler, logger=None, path=None):
        """Evaluate `model` on `data_loader` and print per-horizon metrics.

        When `path` is given, loads weights and config from that checkpoint
        first.
        """
        if path != None:
            check_point = torch.load(path)
            state_dict = check_point['state_dict']
            args = check_point['config']
            model.load_state_dict(state_dict)
            model.to(args.device)

        model.eval()
        #enable_dropout(model)

        y_pred = []
        y_true = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(data_loader):
                data = data[..., :args.input_dim]
                label = target[..., :args.output_dim]
                #output, z, mu, log_sigma = model.sample_code_(data)

                #label = torch.log(label)

                output = model(data, target, teacher_forcing_ratio=0)
                #output, *_ = model(data, target, teacher_forcing_ratio=0)
                y_true.append(label)
                y_pred.append(output)
        y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
        if args.real_value:
            # the model already predicts in the original scale
            y_pred = torch.cat(y_pred, dim=0)
        else:
            y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
        #np.save('./{}_true.npy'.format(args.dataset), y_true.cpu().numpy())
        #np.save('./{}_pred.npy'.format(args.dataset), y_pred.cpu().numpy())
        # per-horizon metrics, then the aggregate over all horizons
        for t in range(y_true.shape[1]):
            mae, rmse, mape, _, _ = All_Metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                                args.mae_thresh, args.mape_thresh)
            print("Horizon {:02d}, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format(
                t + 1, mae, rmse, mape*100))
        mae, rmse, mape, _, _ = All_Metrics(y_pred, y_true, args.mae_thresh, args.mape_thresh)
        print("Average Horizon, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format(
            mae, rmse, mape*100))
224 |
--------------------------------------------------------------------------------
/model/PEMS03_AGCRN.conf:
--------------------------------------------------------------------------------
1 | [data]
2 | num_nodes = 358
3 | lag = 12
4 | horizon = 12
5 | val_ratio = 0.2
6 | test_ratio = 0.2
7 | tod = False
8 | normalizer = std
9 | column_wise = False
10 | default_graph = True
11 |
12 | [model]
13 | input_dim = 1
14 | output_dim = 1
15 | embed_dim = 10
16 | rnn_units = 64
17 | num_layers = 2
18 | cheb_order = 2
19 | p1 = 0.1
20 |
21 | [train]
22 | loss_func = mae
23 | seed = 10
24 | batch_size = 64
25 | epochs = 100
26 | lr_init = 0.003
27 | lr_decay = False
28 | lr_decay_rate = 0.3
29 | lr_decay_step = 5,20,40,70
30 | early_stop = True
31 | early_stop_patience = 15
32 | grad_norm = False
33 | max_grad_norm = 5
34 | real_value = True
35 |
36 | [test]
37 | mae_thresh = None
38 | mape_thresh = 0.001
39 |
40 | [log]
41 | log_step = 20
42 | plot = False
--------------------------------------------------------------------------------
/model/PEMS04_AGCRN.conf:
--------------------------------------------------------------------------------
1 | [data]
2 | num_nodes = 307
3 | lag = 12
4 | horizon = 12
5 | val_ratio = 0.2
6 | test_ratio = 0.2
7 | tod = False
8 | normalizer = std
9 | column_wise = False
10 | default_graph = True
11 |
12 | [model]
13 | input_dim = 1
14 | output_dim = 1
15 | embed_dim = 10
16 | rnn_units = 64
17 | num_layers = 2
18 | cheb_order = 2
19 | p1 = 0.1
20 |
21 | [train]
22 | loss_func = mae
23 | seed = 10
24 | batch_size = 64
25 | epochs = 100
26 | lr_init = 0.003
27 | lr_decay = False
28 | lr_decay_rate = 0.3
29 | lr_decay_step = 5,20,40,70
30 | early_stop = True
31 | early_stop_patience = 15
32 | grad_norm = False
33 | max_grad_norm = 5
34 | real_value = True
35 |
36 | [test]
37 | mae_thresh = None
38 | mape_thresh = 0.
39 |
40 | [log]
41 | log_step = 20
42 | plot = False
--------------------------------------------------------------------------------
/model/PEMS07_AGCRN.conf:
--------------------------------------------------------------------------------
1 | [data]
2 | num_nodes = 883
3 | lag = 12
4 | horizon = 12
5 | val_ratio = 0.2
6 | test_ratio = 0.2
7 | tod = False
8 | normalizer = std
9 | column_wise = False
10 | default_graph = True
11 |
12 | [model]
13 | input_dim = 1
14 | output_dim = 1
15 | embed_dim = 10
16 | rnn_units = 64
17 | num_layers = 2
18 | cheb_order = 2
19 | p1 = 0.1
20 |
21 | [train]
22 | loss_func = mae
23 | seed = 10
24 | batch_size = 32
25 | epochs = 100
26 | lr_init = 0.003
27 | lr_decay = False
28 | lr_decay_rate = 0.3
29 | lr_decay_step = 5,20,40,70
30 | early_stop = True
31 | early_stop_patience = 15
32 | grad_norm = False
33 | max_grad_norm = 5
34 | real_value = True
35 |
36 | [test]
37 | mae_thresh = None
38 | mape_thresh = 0.
39 |
40 | [log]
41 | log_step = 20
42 | plot = False
--------------------------------------------------------------------------------
/model/PEMS08_AGCRN.conf:
--------------------------------------------------------------------------------
1 | [data]
2 | num_nodes = 170
3 | lag = 12
4 | horizon = 12
5 | val_ratio = 0.2
6 | test_ratio = 0.2
7 | tod = False
8 | normalizer = std
9 | column_wise = False
10 | default_graph = True
11 |
12 | [model]
13 | input_dim = 1
14 | output_dim = 1
15 | embed_dim = 2
16 | rnn_units = 64
17 | num_layers = 2
18 | cheb_order = 2
19 | p1 = 0.05
20 |
21 | [train]
22 | loss_func = mae
23 | seed = 12
24 | batch_size = 64
25 | epochs = 100
26 | lr_init = 0.003
27 | lr_decay = False
28 | lr_decay_rate = 0.3
29 | lr_decay_step = 5,20,40,70
30 | early_stop = True
31 | early_stop_patience = 15
32 | grad_norm = False
33 | max_grad_norm = 5
34 | real_value = True
35 |
36 | [test]
37 | mae_thresh = None
38 | mape_thresh = 0.
39 |
40 | [log]
41 | log_step = 20
42 | plot = False
--------------------------------------------------------------------------------
/model/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/test_methods.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lib.utils import enable_dropout
4 | from lib.metrics import All_Metrics
5 | from tqdm import tqdm
6 |
####======MC+Heter========####
def combined_test(model, num_samples, args, data_loader, scaler, T=None, logger=None, path=None):
    """Monte-Carlo + heteroscedastic test-time evaluation.

    Draws `num_samples` stochastic forward passes (dropout kept active),
    combines epistemic (MC variance) and aleatoric (predicted variance,
    temperature-scaled) uncertainty, and prints MAE/RMSE/MAPE/NLL/PICP/MPIW.

    BUG FIXES vs. original:
    - the point metrics (mae, rmse, mape) were printed but never computed,
      causing a NameError; they are now computed via All_Metrics.
    - the default `T=torch.zeros(1).cuda()` was evaluated at import time and
      required CUDA just to import this module; the default is now None and
      the tensor is created inside the function.
    """
    model.eval()
    enable_dropout(model)  # keep dropout layers active so MC samples differ
    nll_fun = nn.GaussianNLLLoss()
    # calibration temperature; default T=0 -> exp(T)=1, i.e. no rescaling
    if T is None:
        T = torch.zeros(1).cuda()

    # collect the ground truth once
    y_true = []
    with torch.no_grad():
        for batch_idx, (_, target) in enumerate(data_loader):
            label = target[..., :args.output_dim]
            y_true.append(label)
    y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)).squeeze(3)

    mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()
    mc_log_vars = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()

    with torch.no_grad():
        for i in tqdm(range(num_samples)):
            mu_pred = []
            log_var_pred = []
            for batch_idx, (data, _) in enumerate(data_loader):
                data = data[..., :args.input_dim]
                # NOTE(review): `target` here is the leftover last batch from
                # the loop above; it is unused by the model's forward — confirm.
                mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)
                mu_pred.append(mu.squeeze(3))
                log_var_pred.append(log_var.squeeze(3))

            if args.real_value:
                mu_pred = torch.cat(mu_pred, dim=0)
            else:
                mu_pred = scaler.inverse_transform(torch.cat(mu_pred, dim=0))
            log_var_pred = torch.cat(log_var_pred, dim=0)

            mc_mus = torch.vstack((mc_mus, mu_pred.unsqueeze(0)))
            mc_log_vars = torch.vstack((mc_log_vars, log_var_pred.unsqueeze(0)))

    temperature = torch.exp(T)
    y_pred = torch.mean(mc_mus, axis=0)
    # total variance = epistemic (MC) + temperature-scaled aleatoric
    total_var = torch.var(mc_mus, axis=0) + torch.exp(torch.mean(mc_log_vars, axis=0)) / temperature
    total_std = total_var ** 0.5

    mpiw = 2 * 1.96 * torch.mean(total_std)  # mean prediction-interval width
    nll = nll_fun(y_pred.ravel(), y_true.ravel(), total_var.ravel())
    lower_bound = y_pred - 1.96 * total_std
    upper_bound = y_pred + 1.96 * total_std
    in_num = torch.sum((y_true >= lower_bound) & (y_true <= upper_bound))
    picp = in_num / (y_true.size(0) * y_true.size(1) * y_true.size(2))

    # point-forecast metrics (previously referenced but never computed)
    mae, rmse, mape, _, _ = All_Metrics(y_pred, y_true, args.mae_thresh, args.mape_thresh)

    print("Average Horizon, MAE: {:.4f}, RMSE: {:.4f}, MAPE: {:.4f}%, NLL: {:.4f}, \
          PICP: {:.4f}%, MPIW: {:.4f}".format(mae, rmse, mape*100, nll, picp*100, mpiw))
--------------------------------------------------------------------------------
/model/train_methods.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lib.utils import enable_dropout
4 | from torch.optim.swa_utils import AveragedModel, SWALR
5 | from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts
6 | from tqdm import tqdm
7 |
8 |
###====Adaptive Weight Averaging====###
def awa_train_combined(trainer, epoch_swa, regularizer=None, lr_schedule=None):
    """AWA fine-tuning for the 'combined' (heteroscedastic) model.

    Runs `epoch_swa` additional epochs with a cyclic cosine learning rate and
    folds the weights into an AveragedModel every second epoch. Returns the
    SWA-averaged model.

    BUG FIXES vs. original:
    - the optimizer was built from `model.parameters()` with `model`
      undefined (NameError); it now uses `trainer.model`.
    - the teacher-forcing branch used undefined `batch_idx`; the batch index
      in this loop is `iter`.
    """
    num_iters = len(trainer.train_loader)

    lr1 = 0.003
    lr2 = lr1 * 0.01

    optimizer_swa = torch.optim.Adam(params=trainer.model.parameters(), lr=lr1, betas=(0.9, 0.99),
                                     weight_decay=1e-4, amsgrad=False)

    swa_model = AveragedModel(trainer.model)
    scheduler_swa = CosineAnnealingWarmRestarts(optimizer_swa, T_0=num_iters,
                                                T_mult=1, eta_min=lr2, last_epoch=-1)

    for epoch in tqdm(range(epoch_swa)):
        trainer.model.train()
        for iter, (data, target) in enumerate(trainer.train_loader):
            input = data[..., :trainer.args.input_dim]
            label = target[..., :trainer.args.output_dim]  # (..., 1)
            optimizer_swa.zero_grad()

            input = input.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)

            if trainer.args.teacher_forcing:
                global_step = (epoch - 1) * trainer.train_per_epoch + iter
                teacher_forcing_ratio = trainer._compute_sampling_threshold(global_step, trainer.args.tf_decay_steps)
            else:
                teacher_forcing_ratio = 1.
            # NOTE(review): forwards the raw `data`/`target`, not the sliced
            # `input` computed above — confirm this is intended.
            mu, log_var = trainer.model.forward(data, target, teacher_forcing_ratio=0.5)
            if trainer.args.real_value:
                label = trainer.scaler.inverse_transform(label)
            # heteroscedastic NLL blended 1:9 with the trainer's base loss
            loss = torch.mean(torch.exp(-log_var) * (label - mu) ** 2 + log_var)
            loss = 0.1 * loss + 0.9 * trainer.loss(mu, label)
            loss.backward()
            optimizer_swa.step()
            # cosine cycle on even epochs; otherwise pin the LR to the floor
            if (epoch % 2 == 0) & (iter != num_iters - 1):
                scheduler_swa.step()
            else:
                optimizer_swa.param_groups[0]["lr"] = lr2
        # accumulate into the running average every second epoch
        if (epoch + 1) % 2 == 0:
            swa_model.update_parameters(trainer.model)
            torch.optim.swa_utils.update_bn(trainer.train_loader, swa_model)

    return swa_model
66 |
67 |
def swa_train(trainer, epoch_swa, regularizer=None, lr_schedule=None):
    """SWA fine-tuning for models whose forward returns a single output.

    Same schedule as awa_train_combined but trains with the trainer's base
    loss only. Returns the SWA-averaged model.

    BUG FIXES vs. original:
    - the optimizer was built from `model.parameters()` with `model`
      undefined (NameError); it now uses `trainer.model`.
    - the teacher-forcing branch used undefined `batch_idx`; the batch index
      in this loop is `iter`.
    """
    num_iters = len(trainer.train_loader)

    lr1 = 0.003
    lr2 = lr1 * 0.01

    optimizer_swa = torch.optim.Adam(params=trainer.model.parameters(), lr=lr1, betas=(0.9, 0.99),
                                     weight_decay=1e-4, amsgrad=False)

    swa_model = AveragedModel(trainer.model)
    scheduler_swa = CosineAnnealingWarmRestarts(optimizer_swa, T_0=num_iters,
                                                T_mult=1, eta_min=lr2, last_epoch=-1)

    for epoch in tqdm(range(epoch_swa)):
        trainer.model.train()
        for iter, (data, target) in enumerate(trainer.train_loader):
            input = data[..., :trainer.args.input_dim]
            label = target[..., :trainer.args.output_dim]  # (..., 1)
            optimizer_swa.zero_grad()

            input = input.cuda(non_blocking=True)
            label = label.cuda(non_blocking=True)

            if trainer.args.teacher_forcing:
                global_step = (epoch - 1) * trainer.train_per_epoch + iter
                teacher_forcing_ratio = trainer._compute_sampling_threshold(global_step, trainer.args.tf_decay_steps)
            else:
                teacher_forcing_ratio = 1.
            # NOTE(review): forwards the raw `data`/`target`, not the sliced
            # `input` computed above — confirm this is intended.
            output = trainer.model.forward(data, target, teacher_forcing_ratio=0.5)
            if trainer.args.real_value:
                label = trainer.scaler.inverse_transform(label)
            loss = trainer.loss(output, label)
            loss.backward()
            optimizer_swa.step()
            # cosine cycle on even epochs; otherwise pin the LR to the floor
            if (epoch % 2 == 0) & (iter != num_iters - 1):
                scheduler_swa.step()
            else:
                optimizer_swa.param_groups[0]["lr"] = lr2
        # accumulate into the running average every second epoch
        if (epoch + 1) % 2 == 0:
            swa_model.update_parameters(trainer.model)
            torch.optim.swa_utils.update_bn(trainer.train_loader, swa_model)

    return swa_model
117 |
###====Calibration====###
class ModelCali(nn.Module):
    """Wraps the single learnable temperature parameter T used for calibration."""

    def __init__(self, args):
        super(ModelCali, self).__init__()
        # scalar log-temperature, initialised at 1.5 and fitted by LBFGS
        self.T = nn.Parameter(torch.ones(1) * 1.5, requires_grad=True)
125 |
126 |
def train_cali(model, args, data_loader, scaler, logger=None, path=None):
    """Fit the calibration temperature T on held-out data (heteroscedastic model).

    Collects the model's (mu, log_var) predictions once, then optimises T
    with LBFGS on a temperature-scaled Gaussian NLL. Returns the learned T.
    """
    model_cali = ModelCali(args).cuda()
    optimizer_cali = torch.optim.LBFGS(list(model_cali.parameters()), lr=0.02, max_iter=500)
    model.eval()

    y_true = []
    mu_pred = []
    log_var_pred = []
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            data = data[..., :args.input_dim]
            label = target[..., :args.output_dim]
            mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)
            mu_pred.append(mu)
            log_var_pred.append(log_var)
            y_true.append(label)

    mu_pred = torch.cat(mu_pred, dim=0)
    if not args.real_value:
        # bring predictions back to the original scale
        mu_pred = scaler.inverse_transform(mu_pred)
    log_var_pred = torch.cat(log_var_pred, dim=0)
    y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))

    y_pred = mu_pred
    precision = torch.exp(-log_var_pred)

    def closure():
        optimizer_cali.zero_grad()
        temperature = torch.exp(model_cali.T)
        # temperature-scaled Gaussian NLL (up to additive constants)
        loss = torch.mean(temperature * precision * (y_true - y_pred) ** 2 + log_var_pred - model_cali.T)
        loss.backward()
        return loss

    optimizer_cali.step(closure)
    print("Calibration finished!")
    return model_cali.T
165 |
166 |
def train_cali_mc(model, num_samples, args, data_loader, scaler, logger=None, path=None):
    """Fit the calibration temperature T for an MC-dropout heteroscedastic model.

    Draws `num_samples` stochastic forward passes (dropout kept active),
    averages the predicted means and log-variances, then optimises T with
    LBFGS. Returns the learned T.
    """
    model_cali = ModelCali(args).cuda()
    optimizer_cali = torch.optim.LBFGS(list(model_cali.parameters()), lr=0.02, max_iter=500)
    model.eval()
    enable_dropout(model)  # keep dropout active so MC samples differ
    nll_fun = nn.GaussianNLLLoss()

    # collect ground truth once
    y_true = []
    with torch.no_grad():
        for batch_idx, (_, target) in enumerate(data_loader):
            y_true.append(target[..., :args.output_dim])
    y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)).squeeze(3)

    mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()
    mc_log_vars = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()

    with torch.no_grad():
        for _ in tqdm(range(num_samples)):
            mu_batches = []
            log_var_batches = []
            for batch_idx, (data, target) in enumerate(data_loader):
                data = data[..., :args.input_dim]
                mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)
                mu_batches.append(mu.squeeze(3))
                log_var_batches.append(log_var.squeeze(3))

            mu_sample = torch.cat(mu_batches, dim=0)
            if not args.real_value:
                mu_sample = scaler.inverse_transform(mu_sample)
            log_var_sample = torch.cat(log_var_batches, dim=0)

            mc_mus = torch.vstack((mc_mus, mu_sample.unsqueeze(0)))
            mc_log_vars = torch.vstack((mc_log_vars, log_var_sample.unsqueeze(0)))

    y_pred = torch.mean(mc_mus, axis=0)
    # aleatoric precision from the averaged predicted log-variance
    precision = torch.exp(-torch.mean(mc_log_vars, axis=0))

    def closure():
        optimizer_cali.zero_grad()
        temperature = torch.exp(model_cali.T)
        loss = torch.mean(temperature * precision * (y_true - y_pred) ** 2 - model_cali.T)
        loss.backward()
        return loss

    optimizer_cali.step(closure)
    print("Calibration finished!")
    return model_cali.T
228 |
229 |
--------------------------------------------------------------------------------
/notebooks/main_v2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "id": "1af4aa5e",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "data": {
11 | "text/html": [
12 | ""
3266 | ],
3267 | "text/plain": [
3268 | ""
3269 | ]
3270 | },
3271 | "execution_count": 3,
3272 | "metadata": {},
3273 | "output_type": "execute_result"
3274 | }
3275 | ],
3276 | "source": [
3277 | "%reset -f\n",
3278 | "from jupyterthemes import get_themes\n",
3279 | "import jupyterthemes as jt\n",
3280 | "from jupyterthemes.stylefx import set_nb_theme\n",
3281 | "set_nb_theme('onedork')"
3282 | ]
3283 | },
3284 | {
3285 | "cell_type": "code",
3286 | "execution_count": 68,
3287 | "id": "d024eac8",
3288 | "metadata": {},
3289 | "outputs": [],
3290 | "source": [
3291 | "import torch\n",
3292 | "torch.cuda.empty_cache()"
3293 | ]
3294 | },
3295 | {
3296 | "cell_type": "code",
3297 | "execution_count": 69,
3298 | "id": "fa419dbc",
3299 | "metadata": {},
3300 | "outputs": [
3301 | {
3302 | "data": {
3303 | "text/html": [
3304 | ""
3305 | ],
3306 | "text/plain": [
3307 | ""
3308 | ]
3309 | },
3310 | "metadata": {},
3311 | "output_type": "display_data"
3312 | }
3313 | ],
3314 | "source": [
3315 | "from IPython.display import display, HTML\n",
3316 | "display(HTML(\"\"))"
3317 | ]
3318 | },
3319 | {
3320 | "cell_type": "code",
3321 | "execution_count": 70,
3322 | "id": "8ea4d3cc",
3323 | "metadata": {},
3324 | "outputs": [],
3325 | "source": [
3326 | "import os\n",
3327 | "import sys\n",
3328 | "# file_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))\n",
3329 | "# print(file_dir)\n",
3330 | "# sys.path.append(file_dir)\n",
3331 | "\n",
3332 | "import torch\n",
3333 | "import numpy as np\n",
3334 | "import torch.nn as nn\n",
3335 | "import argparse\n",
3336 | "import configparser\n",
3337 | "from datetime import datetime\n",
3338 | "from model.AGCRN import AGCRN_UQ as Network\n",
3339 | "from model.BasicTrainer import Trainer\n",
3340 | "from lib.TrainInits import init_seed\n",
3341 | "from lib.dataloader import get_dataloader\n",
3342 | "from lib.TrainInits import print_model_parameters\n",
3343 | "from tqdm import tqdm\n",
3344 | "from copy import deepcopy\n",
3345 | "import torch.nn.functional as F\n",
3346 | "import torchcontrib\n",
3347 | "from model.train_methods import swa_train_combined, swa_train, train_cali, train_cali_mc, train_fge\n",
3348 | "from model.test_methods import regular_test, heter_test, mc_test, combined_test, ensemble_test,quantile_test"
3349 | ]
3350 | },
3351 | {
3352 | "cell_type": "code",
3353 | "execution_count": null,
3354 | "id": "7d85451d",
3355 | "metadata": {},
3356 | "outputs": [],
3357 | "source": []
3358 | },
3359 | {
3360 | "cell_type": "code",
3361 | "execution_count": 71,
3362 | "id": "e1c79495",
3363 | "metadata": {},
3364 | "outputs": [],
3365 | "source": [
3366 | "#*************************************************************************#\n",
3367 | "Mode = 'train'\n",
3368 | "DEBUG = 'True'\n",
3369 | "DATASET = 'PEMS04' #PEMS04/8/3/7\n",
3370 | "DEVICE = 'cuda:0'\n",
3371 | "MODEL = 'AGCRN'\n",
3372 | "MODEL_NAME = \"combined\"#\"combined\" #\"combined\"#\"basic/dropout/heter/combined_swa\"\n",
3373 | "P1= 0.1 #04/03/07: 0.1; 08: 0.05\n",
3374 | "\n",
3375 | "#get configuration\n",
3376 | "config_file = 'model/{}_{}.conf'.format(DATASET, MODEL)\n",
3377 | "#print('Read configuration file: %s' % (config_file))\n",
3378 | "config = configparser.ConfigParser()\n",
3379 | "config.read(config_file)\n",
3380 | "\n",
3381 | "#config[\"data\"]\n",
3382 | "\n",
3383 | "from lib.utils import enable_dropout,save_model_,load_model_\n",
3384 | "from lib.metrics import All_Metrics\n"
3385 | ]
3386 | },
3387 | {
3388 | "cell_type": "code",
3389 | "execution_count": 72,
3390 | "id": "35cc4723",
3391 | "metadata": {},
3392 | "outputs": [
3393 | {
3394 | "name": "stdout",
3395 | "output_type": "stream",
3396 | "text": [
3397 | "*****************Model Parameter*****************\n",
3398 | "T torch.Size([1]) True\n",
3399 | "node_embeddings torch.Size([307, 10]) True\n",
3400 | "encoder.dcrnn_cells.0.gate.weights_pool torch.Size([10, 2, 65, 128]) True\n",
3401 | "encoder.dcrnn_cells.0.gate.bias_pool torch.Size([10, 128]) True\n",
3402 | "encoder.dcrnn_cells.0.update.weights_pool torch.Size([10, 2, 65, 64]) True\n",
3403 | "encoder.dcrnn_cells.0.update.bias_pool torch.Size([10, 64]) True\n",
3404 | "encoder.dcrnn_cells.1.gate.weights_pool torch.Size([10, 2, 128, 128]) True\n",
3405 | "encoder.dcrnn_cells.1.gate.bias_pool torch.Size([10, 128]) True\n",
3406 | "encoder.dcrnn_cells.1.update.weights_pool torch.Size([10, 2, 128, 64]) True\n",
3407 | "encoder.dcrnn_cells.1.update.bias_pool torch.Size([10, 64]) True\n",
3408 | "get_mu.0.weight torch.Size([32, 1, 1, 1]) True\n",
3409 | "get_mu.0.bias torch.Size([32]) True\n",
3410 | "get_mu.3.weight torch.Size([12, 32, 1, 64]) True\n",
3411 | "get_mu.3.bias torch.Size([12]) True\n",
3412 | "get_log_var.0.weight torch.Size([32, 1, 1, 1]) True\n",
3413 | "get_log_var.0.bias torch.Size([32]) True\n",
3414 | "get_log_var.3.weight torch.Size([12, 32, 1, 64]) True\n",
3415 | "get_log_var.3.bias torch.Size([12]) True\n",
3416 | "Total params num: 797335\n",
3417 | "*****************Finish Parameter****************\n",
3418 | "Load PEMS04 Dataset shaped: (16992, 307, 1) 919.0 0.0 211.7007794815878 180.0\n",
3419 | "Normalize the dataset by Standard Normalization\n",
3420 | "Train: (10173, 12, 307, 1) (10173, 12, 307, 1)\n",
3421 | "Val: (3375, 12, 307, 1) (3375, 12, 307, 1)\n",
3422 | "Test: (3375, 12, 307, 1) (3375, 12, 307, 1)\n"
3423 | ]
3424 | }
3425 | ],
3426 | "source": [
3427 | "from lib.metrics import MAE_torch\n",
3428 | "def masked_mae_loss(scaler, mask_value):\n",
3429 | " def loss(preds, labels):\n",
3430 | " if scaler:\n",
3431 | " preds = scaler.inverse_transform(preds)\n",
3432 | " labels = scaler.inverse_transform(labels)\n",
3433 | " mae = MAE_torch(pred=preds, true=labels, mask_value=mask_value)\n",
3434 | " return mae\n",
3435 | " return loss\n",
3436 | "\n",
3437 | "#parser\n",
3438 | "args = argparse.ArgumentParser(description='arguments')\n",
3439 | "args.add_argument('--dataset', default=DATASET, type=str)\n",
3440 | "args.add_argument('--mode', default=Mode, type=str)\n",
3441 | "args.add_argument('--device', default=DEVICE, type=str, help='indices of GPUs')\n",
3442 | "args.add_argument('--debug', default=DEBUG, type=eval)\n",
3443 | "args.add_argument('--model', default=MODEL, type=str)\n",
3444 | "args.add_argument('--cuda', default=True, type=bool)\n",
3445 | "#data\n",
3446 | "args.add_argument('--val_ratio', default=config['data']['val_ratio'], type=float)\n",
3447 | "args.add_argument('--test_ratio', default=config['data']['test_ratio'], type=float)\n",
3448 | "#args.add_argument('--val_ratio', default=0.1, type=float)\n",
3449 | "#args.add_argument('--test_ratio', default=0.85, type=float)\n",
3450 | "\n",
3451 | "args.add_argument('--lag', default=config['data']['lag'], type=int)\n",
3452 | "args.add_argument('--horizon', default=config['data']['horizon'], type=int)\n",
3453 | "args.add_argument('--num_nodes', default=config['data']['num_nodes'], type=int)\n",
3454 | "args.add_argument('--tod', default=config['data']['tod'], type=eval)\n",
3455 | "args.add_argument('--normalizer', default=config['data']['normalizer'], type=str)\n",
3456 | "args.add_argument('--column_wise', default=config['data']['column_wise'], type=eval)\n",
3457 | "args.add_argument('--default_graph', default=config['data']['default_graph'], type=eval)\n",
3458 | "#model\n",
3459 | "args.add_argument('--input_dim', default=config['model']['input_dim'], type=int)\n",
3460 | "args.add_argument('--output_dim', default=config['model']['output_dim'], type=int)\n",
3461 | "args.add_argument('--embed_dim', default=config['model']['embed_dim'], type=int)\n",
3462 | "args.add_argument('--rnn_units', default=config['model']['rnn_units'], type=int)\n",
3463 | "args.add_argument('--num_layers', default=config['model']['num_layers'], type=int)\n",
3464 | "args.add_argument('--cheb_k', default=config['model']['cheb_order'], type=int)\n",
3465 | "#train\n",
3466 | "args.add_argument('--loss_func', default=config['train']['loss_func'], type=str)\n",
3467 | "#args.add_argument('--loss_func', default='mse', type=str)\n",
3468 | "args.add_argument('--seed', default=config['train']['seed'], type=int)\n",
3469 | "args.add_argument('--batch_size', default=config['train']['batch_size'], type=int)\n",
3470 | "args.add_argument('--epochs', default=config['train']['epochs'], type=int)\n",
3471 | "#args.add_argument('--epochs', default=500, type=int)\n",
3472 | "args.add_argument('--lr_init', default=config['train']['lr_init'], type=float)\n",
3473 | "#args.add_argument('--lr_init', default=1e-2, type=float)\n",
3474 | "args.add_argument('--lr_decay', default=config['train']['lr_decay'], type=eval)\n",
3475 | "args.add_argument('--lr_decay_rate', default=config['train']['lr_decay_rate'], type=float)\n",
3476 | "args.add_argument('--lr_decay_step', default=config['train']['lr_decay_step'], type=str)\n",
3477 | "args.add_argument('--early_stop', default=config['train']['early_stop'], type=eval)\n",
3478 | "args.add_argument('--early_stop_patience', default=config['train']['early_stop_patience'], type=int)\n",
3479 | "args.add_argument('--grad_norm', default=config['train']['grad_norm'], type=eval)\n",
3480 | "args.add_argument('--max_grad_norm', default=config['train']['max_grad_norm'], type=int)\n",
3481 | "args.add_argument('--teacher_forcing', default=False, type=bool)\n",
3482 | "args.add_argument('--tf_decay_steps', default=2000, type=int, help='teacher forcing decay steps')\n",
3483 | "args.add_argument('--real_value', default=config['train']['real_value'], type=eval, help = 'use real value for loss calculation')\n",
3484 | "#test\n",
3485 | "args.add_argument('--mae_thresh', default=config['test']['mae_thresh'], type=eval)\n",
3486 | "args.add_argument('--mape_thresh', default=config['test']['mape_thresh'], type=float)\n",
3487 | "#log\n",
3488 | "args.add_argument('--log_dir', default='./', type=str)\n",
3489 | "args.add_argument('--log_step', default=config['log']['log_step'], type=int)\n",
3490 | "args.add_argument('--plot', default=config['log']['plot'], type=eval)\n",
3491 | "args.add_argument('--model_name', default=MODEL_NAME, type=str)\n",
3492 | "args.add_argument('--p1', default=P1, type=float)\n",
3493 | "\n",
3494 | "\n",
3495 | "args = args.parse_args([])\n",
3496 | "init_seed(args.seed)\n",
3497 | "if torch.cuda.is_available():\n",
3498 | " torch.cuda.set_device(int(args.device[5]))\n",
3499 | "else:\n",
3500 | " args.device = 'cpu'\n",
3501 | "\n",
3502 | "#init model\n",
3503 | "model = Network(args).to(args.device)\n",
3504 | "for p in model.parameters():\n",
3505 | " if p.dim() > 1:\n",
3506 | " nn.init.xavier_uniform_(p)\n",
3507 | " else:\n",
3508 | " nn.init.uniform_(p)\n",
3509 | "print_model_parameters(model, only_num=False)\n",
3510 | "\n",
3511 | "#load dataset\n",
3512 | "train_loader, val_loader, test_loader, scaler = get_dataloader(args,\n",
3513 | " normalizer=args.normalizer,\n",
3514 | " tod=args.tod, dow=False,\n",
3515 | " weather=False, single=False)\n",
3516 | "#init loss function, optimizer\n",
3517 | "if args.loss_func == 'mask_mae':\n",
3518 | " loss = masked_mae_loss(scaler, mask_value=0.0)\n",
3519 | "elif args.loss_func == 'mae':\n",
3520 | " loss = torch.nn.L1Loss().to(args.device)\n",
3521 | "elif args.loss_func == 'mse':\n",
3522 | " loss = torch.nn.MSELoss().to(args.device)\n",
3523 | "else:\n",
3524 | " raise ValueError\n",
3525 | "\n",
3526 | "# optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8,\n",
3527 | "# weight_decay=1e-6, amsgrad=False)\n",
3528 | "\n",
3529 | "#basic\n",
3530 | "optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr_init, eps=1.0e-8,\n",
3531 | " weight_decay=0, amsgrad=False)\n",
3532 | "#learning rate decay\n",
3533 | "lr_scheduler = None\n",
3534 | "if args.lr_decay:\n",
3535 | " print('Applying learning rate decay.')\n",
3536 | " lr_decay_steps = [int(i) for i in list(args.lr_decay_step.split(','))]\n",
3537 | " lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer,\n",
3538 | " milestones=lr_decay_steps,\n",
3539 | " gamma=args.lr_decay_rate)\n",
3540 | "#start training\n",
3541 | "trainer = Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler,\n",
3542 | " args, lr_scheduler=lr_scheduler)\n"
3543 | ]
3544 | },
3545 | {
3546 | "cell_type": "markdown",
3547 | "id": "e640ac10",
3548 | "metadata": {},
3549 | "source": [
3550 | "# Pre-train model"
3551 | ]
3552 | },
3553 | {
3554 | "cell_type": "code",
3555 | "execution_count": 73,
3556 | "id": "206aff0e",
3557 | "metadata": {},
3558 | "outputs": [],
3559 | "source": [
3560 | "#trainer.train()"
3561 | ]
3562 | },
3563 | {
3564 | "cell_type": "markdown",
3565 | "id": "ae42fc74",
3566 | "metadata": {},
3567 | "source": [
3568 | "# AWA re-train model"
3569 | ]
3570 | },
3571 | {
3572 | "cell_type": "code",
3573 | "execution_count": null,
3574 | "id": "6ee40969",
3575 | "metadata": {},
3576 | "outputs": [],
3577 | "source": [
3578 | "#trainer.model = swa_train_combined(trainer,epoch_swa=20)"
3579 | ]
3580 | },
3581 | {
3582 | "cell_type": "markdown",
3583 | "id": "9bdea352",
3584 | "metadata": {},
3585 | "source": [
3586 | "# Save and load model"
3587 | ]
3588 | },
3589 | {
3590 | "cell_type": "code",
3591 | "execution_count": 1,
3592 | "id": "e0c72752",
3593 | "metadata": {},
3594 | "outputs": [],
3595 | "source": [
3596 | "#save_model_(model,args.model_name,args.dataset,args.horizon)\n",
3597 | "#trainer.model = load_model_(model,args.model_name,args.dataset,args.horizon)"
3598 | ]
3599 | },
3600 | {
3601 | "cell_type": "markdown",
3602 | "id": "48f3a147",
3603 | "metadata": {},
3604 | "source": [
3605 | "# MHCC: online calibration"
3606 | ]
3607 | },
3608 | {
3609 | "cell_type": "code",
3610 | "execution_count": 93,
3611 | "id": "d6179d41",
3612 | "metadata": {},
3613 | "outputs": [],
3614 | "source": [
3615 | "\n",
3616 | "\"\"\"\n",
3617 | "Online inference as validation.\n",
3618 | "\"\"\"\n",
3619 | "\n",
3620 | "def combined_conf_val(model,num_samples,args, data_loader, scaler, q,logger=None, path=None):\n",
3621 | " model.eval()\n",
3622 | " enable_dropout(model)\n",
3623 | " nll_fun = nn.GaussianNLLLoss()\n",
3624 | " y_true = []\n",
3625 | " with torch.no_grad():\n",
3626 | " for batch_idx, (_, target) in enumerate(data_loader):\n",
3627 | " label = target[..., :args.output_dim]\n",
3628 | " y_true.append(label)\n",
3629 | " y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)).squeeze(3)\n",
3630 | " \n",
3631 | " mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()\n",
3632 | " mc_log_vars = torch.empty(0, y_true.size(0),y_true.size(1), y_true.size(2)).cuda()\n",
3633 | " \n",
3634 | " with torch.no_grad():\n",
3635 | " for i in tqdm(range(num_samples)):\n",
3636 | " mu_pred = []\n",
3637 | " log_var_pred = []\n",
3638 | " for batch_idx, (data, _) in enumerate(data_loader):\n",
3639 | " data = data[..., :args.input_dim]\n",
3640 | " mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)\n",
3641 | " #print(mu.size())\n",
3642 | " mu_pred.append(mu.squeeze(3))\n",
3643 | " log_var_pred.append(log_var.squeeze(3))\n",
3644 | " \n",
3645 | " if args.real_value:\n",
3646 | " mu_pred = torch.cat(mu_pred, dim=0)\n",
3647 | " else:\n",
3648 | " mu_pred = scaler.inverse_transform(torch.cat(mu_pred, dim=0)) \n",
3649 | " log_var_pred = torch.cat(log_var_pred, dim=0) \n",
3650 | "\n",
3651 | " #print(mc_mus.size(),mu_pred.size()) \n",
3652 | " mc_mus = torch.vstack((mc_mus,mu_pred.unsqueeze(0))) \n",
3653 | " mc_log_vars = torch.vstack((mc_log_vars,log_var_pred.unsqueeze(0))) \n",
3654 | " \n",
3655 | " y_pred = torch.mean(mc_mus, axis=0)\n",
3656 | " #total_var = (torch.var(mc_mus, axis=0)+torch.exp(torch.mean(mc_log_vars, axis=0)))#/temperature \n",
3657 | " total_var = torch.exp(torch.mean(mc_log_vars, axis=0))\n",
3658 | " total_std = total_var**0.5 \n",
3659 | " \n",
3660 | " mpiw = 2*torch.mean(torch.mul(total_std,q)) \n",
3661 | " nll = nll_fun(y_pred.ravel(), y_true.ravel(), total_var.ravel())\n",
3662 | " lower_bound = y_pred-torch.mul(total_std,q)\n",
3663 | " upper_bound = y_pred+torch.mul(total_std,q) \n",
3664 | "\n",
3665 | " #in_num = torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ))\n",
3666 | " #print(torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ),dim=0))\n",
3667 | " #picp = in_num/(y_true.size(0)*y_true.size(1)*y_true.size(2))\n",
3668 | " in_num = torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ),dim=0)\n",
3669 | " #picp = in_num/(y_true.size(0)*y_true.size(1)*y_true.size(2))\n",
3670 | " in_num = torch.sum(in_num,dim=1)\n",
3671 | " picp = in_num/(y_true.size(0)*y_true.size(2))#.shape\n",
3672 | " return y_true, y_pred, total_std, picp.detach().cpu().numpy()\n",
3673 | " "
3674 | ]
3675 | },
3676 | {
3677 | "cell_type": "code",
3678 | "execution_count": 94,
3679 | "id": "6665b648",
3680 | "metadata": {},
3681 | "outputs": [
3682 | {
3683 | "name": "stderr",
3684 | "output_type": "stream",
3685 | "text": [
3686 | "100%|██████████| 10/10 [00:55<00:00, 5.59s/it]\n"
3687 | ]
3688 | }
3689 | ],
3690 | "source": [
3691 | "y_true_val, y_pred_val, std_val, p = combined_conf_val(model,10,args,val_loader, scaler, q=1.96) \n",
3692 | "#y_true_val, mc_mus_val, mc_log_vars_val = combined_conf_val(model,10,args,test_loader, scaler, q=1.96) #"
3693 | ]
3694 | },
3695 | {
3696 | "cell_type": "code",
3697 | "execution_count": 96,
3698 | "id": "81e913cd",
3699 | "metadata": {},
3700 | "outputs": [],
3701 | "source": [
3702 | "scores = abs(y_pred_val-y_true_val)/std_val#.shape\n",
3703 | "n = y_true_val.shape[0]\n",
3704 | "def quantile_lwci(scores,n,alpha):\n",
3705 | " q = torch.empty(y_true_val.size(1)).cuda()\n",
3706 | " for i in range(y_true_val.shape[1]): \n",
3707 | " qq=np.quantile(scores[:,i,:].detach().cpu().numpy(),min((n+1.0)*(1-alpha)/n,1))\n",
3708 | " q[i]=qq\n",
3709 | " return q\n",
3710 | "def quantile_mhcc(scores,n,alpha_new):\n",
3711 | " q = torch.empty(y_true_val.size(1)).cuda()\n",
3712 | " for i in range(y_true_val.shape[1]): \n",
3713 | " qq=np.quantile(scores[:,i,:].detach().cpu().numpy(),min((n+1.0)*(1-alpha_new[i])/n,1))\n",
3714 | " q[i]=qq\n",
3715 | " return q"
3716 | ]
3717 | },
3718 | {
3719 | "cell_type": "code",
3720 | "execution_count": 97,
3721 | "id": "736f4549",
3722 | "metadata": {},
3723 | "outputs": [],
3724 | "source": [
3725 | "def combined_conf_test(model,num_samples,args, data_loader, scaler, q,logger=None, path=None):\n",
3726 | " model.eval()\n",
3727 | " enable_dropout(model)\n",
3728 | " nll_fun = nn.GaussianNLLLoss()\n",
3729 | " y_true = []\n",
3730 | " with torch.no_grad():\n",
3731 | " for batch_idx, (_, target) in enumerate(data_loader):\n",
3732 | " label = target[..., :args.output_dim]\n",
3733 | " y_true.append(label)\n",
3734 | " y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)).squeeze(3)\n",
3735 | " \n",
3736 | " mc_mus = torch.empty(0, y_true.size(0), y_true.size(1), y_true.size(2)).cuda()\n",
3737 | " mc_log_vars = torch.empty(0, y_true.size(0),y_true.size(1), y_true.size(2)).cuda()\n",
3738 | " \n",
3739 | " with torch.no_grad():\n",
3740 | " for i in tqdm(range(num_samples)):\n",
3741 | " mu_pred = []\n",
3742 | " log_var_pred = []\n",
3743 | " for batch_idx, (data, _) in enumerate(data_loader):\n",
3744 | " data = data[..., :args.input_dim]\n",
3745 | " mu, log_var = model.forward(data, target, teacher_forcing_ratio=0)\n",
3746 | " #print(mu.size())\n",
3747 | " mu_pred.append(mu.squeeze(3))\n",
3748 | " log_var_pred.append(log_var.squeeze(3))\n",
3749 | " \n",
3750 | " if args.real_value:\n",
3751 | " mu_pred = torch.cat(mu_pred, dim=0)\n",
3752 | " else:\n",
3753 | " mu_pred = scaler.inverse_transform(torch.cat(mu_pred, dim=0)) \n",
3754 | " log_var_pred = torch.cat(log_var_pred, dim=0) \n",
3755 | "\n",
3756 | " #print(mc_mus.size(),mu_pred.size()) \n",
3757 | " mc_mus = torch.vstack((mc_mus,mu_pred.unsqueeze(0))) \n",
3758 | " mc_log_vars = torch.vstack((mc_log_vars,log_var_pred.unsqueeze(0))) \n",
3759 | " \n",
3760 | " y_pred = torch.mean(mc_mus, axis=0)\n",
3761 | " total_var = (torch.var(mc_mus, axis=0)+torch.exp(torch.mean(mc_log_vars, axis=0)))#/temperature \n",
3762 | " total_std = total_var**0.5 \n",
3763 | " \n",
3764 | " mpiw = 2*torch.mean(torch.mul(total_std,q)) \n",
3765 | " nll = nll_fun(y_pred.ravel(), y_true.ravel(), total_var.ravel())\n",
3766 | " lower_bound = y_pred-torch.mul(total_std,q)\n",
3767 | " upper_bound = y_pred+torch.mul(total_std,q) \n",
3768 | " \n",
3769 | " in_num = torch.sum((y_true >= lower_bound)&(y_true <= upper_bound ),dim=0)\n",
3770 | " in_num = torch.sum(in_num,dim=1)\n",
3771 | " picp = in_num/(y_true.size(0)*y_true.size(2))#.shape\n",
3772 | " print(picp*100, torch.mean(picp).item()*100, mpiw.item())\n",
3773 | " #return y_true, y_pred, total_std"
3774 | ]
3775 | },
3776 | {
3777 | "cell_type": "markdown",
3778 | "id": "9988a0e6",
3779 | "metadata": {},
3780 | "source": [
3781 | "# Correct target significance level $\\alpha$"
3782 | ]
3783 | },
3784 | {
3785 | "cell_type": "code",
3786 | "execution_count": 103,
3787 | "id": "10579167",
3788 | "metadata": {},
3789 | "outputs": [
3790 | {
3791 | "name": "stdout",
3792 | "output_type": "stream",
3793 | "text": [
3794 | "[0.05340891 0.05050783 0.05023377 0.05054503 0.04958587 0.04828899\n",
3795 | " 0.04636179 0.04609852 0.0462424 0.04623555 0.04712426 0.04645186]\n",
3796 | "[0.05340891 0.05092526 0.05106862 0.0517973 0.05125556 0.05037611\n",
3797 | " 0.04886633 0.04902048 0.04958179 0.04999236 0.05129849 0.05104351]\n"
3798 | ]
3799 | }
3800 | ],
3801 | "source": [
3802 | "h = np.arange(args.horizon)\n",
3803 | "alpha = 0.05\n",
3804 | "gamma = 0.03#04/07/08:0.03, 03:0\n",
3805 | "alpha_new = p-(1-alpha)+alpha\n",
3806 | "alpha_new = alpha_new+(p[0]-p[-1])*gamma*h*2 #04/07/08:0.03\n",
3807 | "q = quantile_mhcc(scores,n,alpha_new)"
3808 | ]
3809 | },
3810 | {
3811 | "cell_type": "code",
3812 | "execution_count": 104,
3813 | "id": "30993a93",
3814 | "metadata": {},
3815 | "outputs": [
3816 | {
3817 | "name": "stderr",
3818 | "output_type": "stream",
3819 | "text": [
3820 | ":2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
3821 | " q = torch.tensor(q.reshape(-1,12,1)).to(args.device)\n",
3822 | "100%|██████████| 10/10 [00:56<00:00, 5.67s/it]"
3823 | ]
3824 | },
3825 | {
3826 | "name": "stdout",
3827 | "output_type": "stream",
3828 | "text": [
3829 | "tensor([95.0162, 95.1789, 95.1004, 94.9924, 94.9851, 95.0227, 95.1640, 95.0996,\n",
3830 | " 95.0323, 95.0006, 94.8244, 94.8621], device='cuda:0') 95.0232207775116 103.91312408447266\n"
3831 | ]
3832 | },
3833 | {
3834 | "name": "stderr",
3835 | "output_type": "stream",
3836 | "text": [
3837 | "\n"
3838 | ]
3839 | }
3840 | ],
3841 | "source": [
3842 | "q = torch.tensor(q.reshape(-1,args.horizon,1)).to(args.device)\n",
3843 | "combined_conf_test(model,10,args,test_loader,scaler,q)"
3844 | ]
3845 | },
3846 | {
3847 | "cell_type": "markdown",
3848 | "id": "fee6a368",
3849 | "metadata": {},
3850 | "source": [
3851 | "# Online MHCC: online calibration\n",
3852 | "\n",
3853 | "### 1. Update nonconformity score\n",
3854 | "\n",
3855 | "### 2. Update empirical prediction interval percentage coverage\n",
3856 | "\n",
3857 | "### 3. Update alpha"
3858 | ]
3859 | },
3860 | {
3861 | "cell_type": "code",
3862 | "execution_count": null,
3863 | "id": "ce5242e3",
3864 | "metadata": {},
3865 | "outputs": [],
3866 | "source": [
3867 | "online_size = y_true_val.shape[0]#1000\n",
3868 | "y_true_online = y_true_val[:online_size,...]\n",
3869 | "y_pred_online = y_pred_val[:online_size,...]\n",
3870 | "std_online = std_val[:online_size,...]\n",
3871 | "score_online = scores[:online_size,...]\n",
3872 | "\n",
3873 | "\n",
3874 | "gamma_online = 0.03 #04/07/08:0.03, 03:0\n",
3875 | "h = np.arange(args.horizon)\n",
3876 | "update_freq = 1000\n",
3877 | "picp_online = p"
3878 | ]
3879 | },
3880 | {
3881 | "cell_type": "code",
3882 | "execution_count": null,
3883 | "id": "67984ca2",
3884 | "metadata": {},
3885 | "outputs": [],
3886 | "source": [
3887 | "alpha = 0.05\n",
3888 | "scores_online = scores\n",
3889 | "\n",
3890 | "def mhcc_online_(y_true_test, y_pred_test, std_test, q_online = q):\n",
3891 | " ii = 0\n",
3892 | " mpiw_ls = [] \n",
3893 | " picp_ls = [] \n",
3894 | " for y_t, p_t, s_t in tqdm(zip(y_true_test, y_pred_test, std_test)): \n",
3895 | " \n",
3896 | "        #empirical picp\n",
3897 | " mpiw = 2*torch.mean(torch.mul(s_t,q_online))\n",
3898 | " lower_bound = p_t-torch.mul(s_t,q_online)\n",
3899 | " upper_bound = p_t+torch.mul(s_t,q_online) \n",
3900 | "\n",
3901 | " in_num = torch.sum((y_t >= lower_bound)&(y_t <= upper_bound),dim=0)\n",
3902 | " in_num = torch.sum(in_num,dim=1)\n",
3903 | " picp = in_num/y_true_test.size(2)\n",
3904 | "\n",
3905 | " mpiw_ls.append(mpiw) \n",
3906 | " picp_ls.append(picp) \n",
3907 | " picp = picp.detach().cpu().numpy()\n",
3908 | " picp_online = (picp_online*(online_size+ii) + picp)/(online_size+ii +1)\n",
3909 | " \n",
3910 | " #update alpha\n",
3911 | " score_new = abs(p_t-y_t)/s_t\n",
3912 | " scores_online = torch.cat([scores_online[1:,...],score_new.unsqueeze(0)],dim=0)\n",
3913 | "\n",
3914 | "\n",
3915 | " if (ii +1) % update_freq ==0:\n",
3916 | " #print(ii)\n",
3917 | " alpha_new = picp_online-(1-alpha)+alpha\n",
3918 | " alpha_new = alpha_new+(picp_online[0]-picp_online[-1])*gamma_online*h*2 #04/07/08:0.03\n",
3919 | " for i in range(y_true_test.shape[1]): \n",
3920 | " qq = np.quantile(scores_online[:,i,:].detach().cpu().numpy(),min((online_size+1.0)*(1-alpha_new[i])/online_size,1))\n",
3921 | " #print(qq.shape) \n",
3922 | " q_online[0,i,0]=torch.from_numpy(qq.reshape(-1,1))\n",
3923 | "\n",
3924 | " ii = ii+1 \n",
3925 | " \n",
3926 | " picp_test = torch.stack(picp_ls,dim=1) \n",
3927 | " mpiw_test = torch.stack(mpiw_ls,dim=0) \n",
3928 | " \n",
3929 | " ### update q\n",
3930 | " q_online = q\n",
3931 | " mpiw = 2*torch.mean(torch.mul(std_test,q_online))#.reshape(-1,y_true_test.shape[1],1)) ) \n",
3932 | " lower_bound = y_pred_test-torch.mul(std_test,q_online)\n",
3933 | " upper_bound = y_pred_test+torch.mul(std_test,q_online) \n",
3934 | "\n",
3935 | " in_num = torch.sum((y_true_test >= lower_bound)&(y_true_test <= upper_bound),dim=0)\n",
3936 | " in_num = torch.sum(in_num,dim=1)\n",
3937 | " picp = in_num/(y_true_test.size(0)*y_true_test.size(2)) \n",
3938 | "\n",
3939 | " print(\"PICP: {:.4f}, MPIW: {:.4f}\".format(picp,mpiw))\n",
3940 | " \n",
3941 | " "
3942 | ]
3943 | },
3944 | {
3945 | "cell_type": "code",
3946 | "execution_count": null,
3947 | "id": "ca062c76",
3948 | "metadata": {},
3949 | "outputs": [],
3950 | "source": [
3951 | " mhcc_online_(y_true_test, y_pred_test, std_test, q_online = q)"
3952 | ]
3953 | },
3954 | {
3955 | "cell_type": "code",
3956 | "execution_count": null,
3957 | "id": "87de181b",
3958 | "metadata": {},
3959 | "outputs": [],
3960 | "source": []
3961 | },
3962 | {
3963 | "cell_type": "code",
3964 | "execution_count": null,
3965 | "id": "32915a3c",
3966 | "metadata": {},
3967 | "outputs": [],
3968 | "source": []
3969 | },
3970 | {
3971 | "cell_type": "code",
3972 | "execution_count": null,
3973 | "id": "48a8c8bc",
3974 | "metadata": {},
3975 | "outputs": [],
3976 | "source": []
3977 | },
3978 | {
3979 | "cell_type": "code",
3980 | "execution_count": null,
3981 | "id": "c15307bc",
3982 | "metadata": {},
3983 | "outputs": [],
3984 | "source": []
3985 | }
3986 | ],
3987 | "metadata": {
3988 | "kernelspec": {
3989 | "display_name": "Python 3",
3990 | "language": "python",
3991 | "name": "python3"
3992 | },
3993 | "language_info": {
3994 | "codemirror_mode": {
3995 | "name": "ipython",
3996 | "version": 3
3997 | },
3998 | "file_extension": ".py",
3999 | "mimetype": "text/x-python",
4000 | "name": "python",
4001 | "nbconvert_exporter": "python",
4002 | "pygments_lexer": "ipython3",
4003 | "version": "3.8.10"
4004 | }
4005 | },
4006 | "nbformat": 4,
4007 | "nbformat_minor": 5
4008 | }
4009 |
--------------------------------------------------------------------------------
/notebooks/readme:
--------------------------------------------------------------------------------
1 | DeepSTUQ with MHCC
2 |
--------------------------------------------------------------------------------
/slide/uq_slide.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WeizhuQIAN/DeepSTUQ_Pytorch/4f2281531d07b48ba81aa1c519a90e65a71993e2/slide/uq_slide.pdf
--------------------------------------------------------------------------------