├── exp ├── __init__.py ├── exp_basic.py └── exp_air.py ├── models ├── __init__.py ├── layers │ ├── tools.py │ ├── DiffAdv_Fusion.py │ ├── Diffeq_solver.py │ ├── soft_losses.py │ ├── timelags.py │ ├── Unk_Dynamics.py │ ├── Decoder.py │ ├── Dynamics_Funsion.py │ ├── Encoder.py │ ├── Embed.py │ └── Explicit_odefunc.py └── Air_DualODE.py ├── utils ├── __init__.py ├── tools.py ├── metrics.py ├── geo_utils.py └── utils.py ├── Evaluation ├── __init__.py └── evaluation.py ├── Data_Provider ├── __init__.py ├── data_factory.py └── data_loader.py ├── fig └── Air-DualODE.png ├── dataset ├── KnowAir │ ├── graph_data.npz │ └── station.csv └── Beijing1718 │ ├── graph_data.npz │ └── station.csv ├── requirements.txt ├── Model_Config ├── KnowAir │ └── Air-DualODE_config.yaml └── Beijing │ └── Air-DualODE_config.yaml ├── README.md └── Run ├── train.py ├── eval.py └── create_graph.py /exp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Data_Provider/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fig/Air-DualODE.png: -------------------------------------------------------------------------------- 
class Simple_Gated_Fusion(nn.Module):
    """Blend the diffusion and advection gradients with a learned sigmoid gate.

    For every node the two gradients are concatenated, passed through a linear
    layer and a sigmoid, and the resulting gate ``g`` mixes them as
    ``g * grad_diff + (1 - g) * grad_adv``.

    :param num_nodes: number of graph nodes N
    :param var_dim: number of variables per node
    """

    def __init__(self, num_nodes, var_dim):
        super(Simple_Gated_Fusion, self).__init__()

        self.num_nodes = num_nodes
        self.var_dim = var_dim
        # The concatenated input carries 2 * var_dim features per node. The layer used to be
        # hard-coded as nn.Linear(2, 1), which only matches var_dim == 1 and raises a shape
        # error otherwise. For var_dim == 1 the layer below is exactly nn.Linear(2, 1).
        self.gated_fc = nn.Linear(2 * var_dim, var_dim)

    def forward(self, grad_diff, grad_adv):
        """
        Fuse the two gradient fields.

        :param grad_diff: B x (N * var_dim)
        :param grad_adv: B x (N * var_dim)
        :return: B x (N * var_dim)
        """
        B = grad_diff.shape[0]
        grad_diff = grad_diff.reshape(B, self.num_nodes, self.var_dim)
        grad_adv = grad_adv.reshape(B, self.num_nodes, self.var_dim)
        concat = torch.cat((grad_diff, grad_adv), dim=-1)  # B x N x (2 * var_dim)
        g = torch.sigmoid(self.gated_fc(concat))  # B x N x var_dim, one gate per variable

        grad_diff_adv = g * grad_diff + (1 - g) * grad_adv  # B x N x var_dim

        return grad_diff_adv.reshape(B, self.num_nodes * self.var_dim)
class DiffeqSolver:
    """Thin wrapper around a torchdiffeq integrator with fixed method and tolerances."""

    def __init__(self, method, odeint_rtol=1e-5,
                 odeint_atol=1e-5, adjoint=False):
        # Pick the integrator lazily so torchdiffeq is only imported when a solver is built.
        if adjoint:
            from torchdiffeq import odeint_adjoint as integrator
        else:
            from torchdiffeq import odeint as integrator

        self.ode_method = method
        self.odeint = integrator
        self.rtol = odeint_rtol
        self.atol = odeint_atol

    def solve(self, odefunc, first_point, time_steps_to_pred):
        """
        Integrate ``odefunc`` from ``first_point`` over ``time_steps_to_pred``.

        The function-evaluation counter ``odefunc.nfe`` is reset to 0 before integrating.

        :param odefunc: ODE function object; must expose a writable ``nfe`` attribute
        :param first_point: initial state handed to the integrator
        :param time_steps_to_pred: time points handed to the integrator
        :return: (pred_y, (nfe, elapsed_seconds)) where pred_y is whatever the integrator returns
        """
        tic = time.time()
        odefunc.nfe = 0
        trajectory = self.odeint(
            odefunc,
            first_point,
            time_steps_to_pred,
            rtol=self.rtol,
            atol=self.atol,
            method=self.ode_method,
        )
        elapsed = time.time() - tic
        stats = (odefunc.nfe, elapsed)

        return trajectory, stats
def temp_CL_soft(z1, z2, timelag_L, timelag_R, num_nodes):
    """Soft temporal contrastive loss between two aligned sequences.

    The two inputs are stacked along time, L2-normalised, and every time step is
    scored against all other steps (self-similarity removed). The negative
    log-softmax scores are weighted by ``timelag_L`` (rows of ``z1``) and
    ``timelag_R`` (rows of ``z2``), summed, and divided by ``2 * B * T * num_nodes``.

    :param z1: B x T x C
    :param z2: B x T x C
    :param timelag_L: weights for the first T rows, T x (2T - 1)
    :param timelag_R: weights for the last T rows, T x (2T - 1)
    :param num_nodes: extra divisor applied to the summed loss
    :return: scalar tensor; 0 when T == 1
    """
    batch, steps = z1.shape[0], z1.shape[1]
    if steps == 1:
        return z1.new_tensor(0.)

    stacked = F.normalize(torch.cat((z1, z2), dim=1), p=2, dim=-1)  # B x 2T x C
    similarity = stacked @ stacked.transpose(1, 2)  # B x 2T x 2T

    # Drop the diagonal: entries left of it keep their column, entries right of it shift by one.
    below = similarity.tril(diagonal=-1)[:, :, :-1]
    above = similarity.triu(diagonal=1)[:, :, 1:]
    neg_log_prob = -F.log_softmax(below + above, dim=-1)  # B x 2T x (2T-1)

    left_term = torch.sum(neg_log_prob[:, :steps] * timelag_L)
    right_term = torch.sum(neg_log_prob[:, steps:2 * steps] * timelag_R)
    return (left_term + right_term) / (2 * batch * steps * num_nodes)
class EarlyStopping:
    """Stop training once the validation loss has not improved for ``patience`` calls.

    Call the instance after every validation pass. The first call and every
    call that does not trigger the patience counter saves a checkpoint to
    ``path + '/checkpoint.pth'``; otherwise the counter is increased, and
    ``early_stop`` becomes True once it reaches ``patience``.

    :param patience: number of consecutive non-improving calls tolerated
    :param verbose: if True, report every checkpoint save
    :param delta: margin added to the best score when testing for improvement
    :param logger: optional logger; messages are printed when it is None
    """

    def __init__(self, patience=7, verbose=False, delta=0, logger: Optional[Logger]=None):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf instead of the np.Inf alias, which no longer exists in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.logger = logger

    def _log(self, message):
        """Send a message to the logger, or to stdout when no logger was given."""
        if self.logger is not None:
            self.logger.info(message)
        else:
            print(message)

    def __call__(self, val_loss, model, path):
        """Record a new validation loss and checkpoint or count it.

        :param val_loss: validation loss of the current epoch (lower is better)
        :param model: object exposing ``state_dict()``
        :param path: existing directory that receives ``checkpoint.pth``
        """
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self._log(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        """Write ``model.state_dict()`` to ``path`` and remember ``val_loss`` as the minimum."""
        if self.verbose:
            self._log(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
        self.val_loss_min = val_loss
class Unk_odefunc_ATT(nn.Module):
    """ODE function built from a node-wise self-attention block followed by a feed-forward block.

    :param latent_dim: feature size per node
    :param num_nodes: number of nodes N
    :param n_heads: attention heads
    :param device: device on which the attention mask is created
    :param adj_mask: optional N x N adjacency; stored with the identity added
    :param d_f: accepted but not used by this module
    """

    def __init__(self, latent_dim, num_nodes, n_heads, device, adj_mask=None, d_f=32):
        super().__init__()
        self.nfe = 0
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes

        self.adj_mask = None
        if adj_mask is not None:
            adjacency = torch.tensor(adj_mask, dtype=torch.int8, device=device)
            self_loops = torch.eye(self.num_nodes, device=device)
            # int8 + float identity promotes to a floating-point mask.
            # NOTE(review): PyTorch documents float attn_mask values as being *added* to the
            # attention scores, so this looks like a +1 bias on neighbours rather than a hard
            # mask -- confirm this is intended.
            self.adj_mask = adjacency + self_loops

        self.fc = nn.Linear(latent_dim, latent_dim)
        self.spatial_att = nn.MultiheadAttention(latent_dim, num_heads=n_heads, batch_first=True)
        self.layer_norm_1 = nn.LayerNorm(latent_dim)
        self.layer_norm_2 = nn.LayerNorm(latent_dim)
        # residual connections (identity)
        self.residual_1 = nn.Identity()
        self.residual_2 = nn.Identity()

    def forward(self, t, z):
        """
        F^D with attention block.

        :param t: time value; not used in the computation
        :param z: B x N*latent_dim
        :return: B x N*latent_dim
        """
        self.nfe += 1
        batch = z.shape[0]
        h = z.reshape(batch, self.num_nodes, self.latent_dim)  # B x N x latent_dim

        # attention + add: first pass uses the adjacency mask
        masked_out, _ = self.spatial_att(h, h, h, attn_mask=self.adj_mask)
        h = self.residual_1(h) + masked_out
        # NOTE(review): a second pass through the same attention layer without the mask
        # follows; kept as-is, but confirm it is not a leftover.
        plain_out, _ = self.spatial_att(h, h, h)
        h = self.residual_1(h) + plain_out
        h = self.layer_norm_1(h)

        # feed-forward + add & norm
        h = self.residual_2(h) + F.relu(self.fc(h))
        h = self.layer_norm_2(h)

        return h.reshape(batch, self.num_nodes * self.latent_dim)
def _zeroed_targets_and_weights(y_true):
    """Return (targets with entries below 1e-4 set to 0, weights that are 0 there and rescaled elsewhere).

    Works on a copy so the caller's tensor is left untouched.
    """
    cleaned = torch.where(y_true < 1e-4, torch.zeros_like(y_true), y_true)
    mask = (cleaned != 0).float()
    mask = mask / mask.mean()  # assign the sample weights of zeros to nonzero-values
    return cleaned, mask


def masked_loss(y_pred, y_true, loss_func):
    """Mean of ``loss_func`` weighted so that (near-)zero targets do not contribute.

    Targets below 1e-4 are treated as missing: they are replaced by 0 before the
    loss is evaluated and receive weight 0. NaNs produced along the way (for
    example when every target is missing) are replaced by 0.

    :param y_pred: predictions
    :param y_true: ground truth; not modified
    :param loss_func: callable ``(pred, true) -> tensor`` broadcastable against the mask
    :return: scalar tensor
    """
    y_true, mask = _zeroed_targets_and_weights(y_true)
    loss = loss_func(y_pred, y_true)
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return loss.mean()


def masked_rmse_loss(y_pred, y_true):
    """Root of the masked mean squared error; same masking rules as ``masked_loss``.

    :param y_pred: predictions
    :param y_true: ground truth; not modified
    :return: scalar tensor
    """
    y_true, mask = _zeroed_targets_and_weights(y_true)
    loss = torch.pow(y_pred - y_true, 2)
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.sqrt(loss.mean())
def get_elevation_online(lat, lon, timeout=10):
    """Look up the elevation of a coordinate through the Open-Elevation web API.

    :param lat: latitude in degrees
    :param lon: longitude in degrees
    :param timeout: seconds to wait for the server before ``requests`` raises;
        without it a stalled connection would block the caller indefinitely
    :return: the ``elevation`` field of the first result in the JSON response
    :raises Exception: if the server answers with a status code other than 200
    """
    print("Online Querying")
    url = f"https://api.open-elevation.com/api/v1/lookup?locations={lat}, {lon}"
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        elevation_data = response.json()
        return elevation_data['results'][0]['elevation']
    else:
        raise Exception("Error in API request: " + str(response.status_code))
input_dim: 6 22 | embed_dim: 6 23 | X_dim: 1 24 | 25 | embedding: 26 | hour2day: 4 27 | day2week: 3 28 | day2month: 4 29 | month2year: 3 30 | station: 4 31 | 32 | phy_func: 33 | enable: True 34 | # gnn para in phy_func 35 | knowledge: diff_adv 36 | gnn_layers: 3 37 | cheb_k: 3 38 | gnn_hid_dim: 64 39 | coeff_estimator: False 40 | rnn_layers: 1 41 | rnn_dim: 64 42 | input_dim: 1 43 | latent_dim: 64 44 | ode_method: dopri5 45 | odeint_atol: 1e-2 46 | odeint_rtol: 1e-2 47 | adjoint: True 48 | 49 | unk_func: 50 | enable: True 51 | rnn_layers: 1 52 | rnn_dim: 64 53 | input_dim: 6 54 | latent_dim: 64 55 | n_heads: 4 56 | d_f: 32 57 | ode_method: rk4 58 | odeint_atol: 1e-4 59 | odeint_rtol: 1e-4 60 | adjoint: True 61 | 62 | fusion: 63 | latent_dim: 64 64 | output_dim: 64 65 | num_layers: 3 66 | gnn_type: 'GCN' 67 | 68 | decoder: 69 | enable: True 70 | 71 | loss: 72 | kl_loss: False 73 | recon_loss: False 74 | pred_loss: True 75 | cl_loss: True 76 | cl_coeff: 5 77 | criterion: mae 78 | 79 | train: 80 | lr: 0.005 81 | lradj: 'default' 82 | epochs: 100 83 | pct_start: 0.4 84 | patience: 20 85 | steps: [20, 30, 40, 50] 86 | lr_decay_ratio: 0.1 87 | log_every: 1 88 | 89 | GPU: 90 | use_gpu: True 91 | gpu: 0 92 | use_multi_gpu: False 93 | devices: '0, 1, 2, 3' -------------------------------------------------------------------------------- /Model_Config/Beijing/Air-DualODE_config.yaml: -------------------------------------------------------------------------------- 1 | log_base_dir: ./logs 2 | to_log_file: True 3 | to_stdout: True 4 | TB_dir: False 5 | checkpoints: ./checkpoints 6 | log_level: INFO 7 | model_name: Air-DualODE 8 | 9 | data: 10 | batch_size: 32 11 | interval: '3h' 12 | data_name: Beijing1718 13 | root_path: ../dataset/Beijing1718 14 | num_workers: 0 15 | normalized_columns: [ "PM2.5", "temperature", "pressure", "humidity", "wind_speed", "wind_direction" ] 16 | embed: 0 17 | 18 | model: 19 | seq_len: 24 20 | horizon: 24 21 | input_dim: 6 22 | embed_dim: 6 23 | 
X_dim: 1 24 | 25 | embedding: 26 | hour2day: 3 27 | day2week: 2 28 | day2month: 3 29 | month2year: 2 30 | station: 3 31 | 32 | phy_func: 33 | enable: True 34 | knowledge: diff_adv 35 | gnn_layers: 2 36 | cheb_k: 3 37 | gnn_hid_dim: 64 38 | coeff_estimator: False 39 | rnn_layers: 1 40 | rnn_dim: 64 41 | input_dim: 1 42 | latent_dim: 64 43 | ode_method: dopri5 44 | odeint_atol: 1e-3 45 | odeint_rtol: 1e-3 46 | adjoint: False 47 | 48 | unk_func: 49 | enable: True 50 | rnn_layers: 1 51 | rnn_dim: 64 52 | input_dim: 6 53 | latent_dim: 64 54 | n_heads: 2 55 | d_f: 32 56 | ode_method: rk4 57 | odeint_atol: 1e-3 58 | odeint_rtol: 1e-3 59 | adjoint: False 60 | 61 | fusion: 62 | latent_dim: 128 63 | output_dim: 128 64 | num_layers: 3 65 | gnn_type: 'GCN' 66 | 67 | decoder: 68 | enable: True 69 | 70 | loss: 71 | kl_loss: False 72 | recon_loss: False 73 | cl_loss: True 74 | pred_loss: True 75 | cl_coeff: 1 76 | criterion: mae 77 | 78 | train: 79 | lr: 0.005 80 | lradj: 'default' 81 | epochs: 100 82 | pct_start: 0.4 83 | patience: 20 84 | steps: [20, 30, 40, 50] 85 | lr_decay_ratio: 0.1 86 | log_every: 1 87 | 88 | GPU: 89 | use_gpu: True 90 | gpu: 0 91 | use_multi_gpu: False 92 | devices: '0, 1, 2, 3' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Air-DualODE 2 | 3 | This repository contains the PyTorch implementation of our ICLR'25 paper, **“Air Quality Prediction with Physics-Guided Dual Neural ODEs in Open Systems.”** In this work, we introduce **Air-DualODE** for predicting air quality at both city and national levels. Our model is composed of three key components: **Physics Dynamics, Data-Driven Dynamics,** and **Dynamics Fusion.** 4 | 5 | 🚩 News (2025.1) Air-DualODE has been accepted by ICLR 2025 (poster). 
6 | 7 | 🚩 News (2025.10) We provide [create_graph.py](https://github.com/decisionintelligence/Air-DualODE/tree/main/Run/create_graph.py) for generating the graph_data.npz file on other air quality datasets. Feel free to test Air-DualODE on different datasets. 8 | 9 | ![image-20250225200329134](./fig/Air-DualODE.png) 10 | 11 | ## Requirements 12 | 13 | * python >= 3.9 14 | 15 | ```shell 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Data Preparation 20 | 21 | Beijing: https://www.biendata.xyz/competition/kdd_2018/ 22 | 23 | KnowAir: https://github.com/shuowang-ai/PM2.5-GNN 24 | 25 | ## Training and Evaluation 26 | 27 | ```shell 28 | cd Run 29 | ``` 30 | 31 | **Beijing** 32 | 33 | ```shell 34 | python train.py --config_filename ../Model_Config/Beijing/Air-DualODE_config.yaml --des 1 35 | python eval.py --config_filename ../Model_Config/Beijing/Air-DualODE_config.yaml --des 1 36 | ``` 37 | 38 | **KnowAir** 39 | 40 | ```shell 41 | python train.py --config_filename ../Model_Config/KnowAir/Air-DualODE_config.yaml --des 1 42 | python eval.py --config_filename ../Model_Config/KnowAir/Air-DualODE_config.yaml --des 1 43 | ``` 44 | 45 | ## Citation 46 | 47 | If you find this repo useful, please cite our paper. 
48 | 49 | ``` 50 | @inproceedings{tian2024air-dualode, 51 | title={Air quality prediction with Physics-Guided dual neural odes in open systems}, 52 | author={Tian, Jindong and Liang, Yuxuan and Xu, Ronghui and Chen, Peng and Guo, Chenjuan and Zhou, Aoying and Pan, Lujia and Rao, Zhongwen and Yang, Bin}, 53 | journal={ICLR}, 54 | year={2025} 55 | } 56 | ``` 57 | 58 | -------------------------------------------------------------------------------- /Run/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import yaml 3 | import os 4 | 5 | gpu_list = "0,1,2,3,4,5,6,7" # GPU lst 6 | device_map = {gpu: i for i, gpu in enumerate(gpu_list.split(','))} 7 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list 8 | sys.path.append('../../Air-DualODE') 9 | 10 | import argparse 11 | import torch 12 | import random 13 | from utils.utils import parsing_syntax, ConfigDict, load_config, update_config, fix_seed 14 | from exp.exp_air import Exp_Air_Pollution 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser(description='Air-DualODE') 19 | 20 | parser.add_argument('--config_filename', type=str, default='../Model_Config/basic_config.yaml', help='Configuration yaml file') 21 | parser.add_argument('--itr', type=int, default=1, help='Number of experiments.') 22 | parser.add_argument('--random_seed', type=int, default=2024, help='Random seed.') 23 | parser.add_argument('--des', type=str, help="description of experiment.") 24 | args, unknown = parser.parse_known_args() 25 | unknown = parsing_syntax(unknown) 26 | 27 | config = load_config(args.config_filename) 28 | config = ConfigDict(config) 29 | config = update_config(config, unknown) 30 | for attr, value in config.items(): 31 | setattr(args, attr, value) 32 | 33 | # random seed 34 | fix_seed(args.random_seed) 35 | 36 | args.GPU.use_gpu = True if torch.cuda.is_available() and args.GPU.use_gpu else False 37 | 38 | if args.GPU.use_gpu and not 
class Linear_Decoder(nn.Module):
    """Decode all nodes at once with a single fully connected layer.

    :param latent_dim: latent features per node
    :param output_dim: output features per node
    :param num_nodes: number of nodes N
    """

    def __init__(self, latent_dim, output_dim, num_nodes):
        super().__init__()
        flat_in = latent_dim * num_nodes
        flat_out = output_dim * num_nodes
        self.linear = nn.Linear(flat_in, flat_out)

    def forward(self, z):
        """Map ``... x N*latent_dim`` to ``... x N*output_dim``."""
        decoded = self.linear(z)
        return decoded
class Conv_seq_Decoder(nn.Module):
    """Decode a whole latent sequence with one 1x1 convolution over the node axis.

    All time steps and latent features of a node are flattened into the channel
    axis, so the layer is tied to a fixed sequence length.

    :param latent_dim: latent features per node and step
    :param output_dim: output features per node and step
    :param num_nodes: number of nodes N
    :param seq_len: number of time steps T the decoder is built for
    """

    def __init__(self, latent_dim, output_dim, num_nodes, seq_len=24):
        super(Conv_seq_Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        self.num_nodes = num_nodes
        # Was hard-coded to 24, which made forward() reject every input whenever the
        # convolution had been sized for a different seq_len.
        self.seq_len = seq_len
        self.spatial_conv = nn.Sequential(
            nn.Conv1d(in_channels=seq_len*latent_dim,
                      out_channels=seq_len*output_dim,
                      kernel_size=1,
                      bias=True)
        )

    def forward(self, z):
        """
        :param z: T x B x N*latent_dim with T == seq_len
        :return: T x B x N*output_dim
        """
        # spatial_conv's input: B x T*latent_dim x N
        T, B, _ = z.shape
        assert T == self.seq_len, f"expected {self.seq_len} time steps, got {T}"
        z = z.reshape(T, B, self.num_nodes, self.latent_dim)
        z = z.permute(1, 0, 3, 2)  # B x T x latent_dim x N
        z = z.reshape(B, T*self.latent_dim, self.num_nodes)
        z = self.spatial_conv(z)  # B x T*Dout x N
        z = z.reshape(B, T, self.output_dim, self.num_nodes)
        z = z.permute(1, 0, 3, 2)  # T x B x N x Dout
        z = z.reshape(T, B, self.num_nodes*self.output_dim)
        return z
if __name__ == '__main__':
    # Entry point: evaluate a trained Air-DualODE model `--itr` times and append a summary
    # of the per-window MAE / MAPE / RMSE (mean and std over runs) to a results file.

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
        non-empty value -- including "False" -- used to be read as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)

    parser = argparse.ArgumentParser(description='Air-DualODE')

    parser.add_argument('--config_filename', type=str, default='../Model_Config/basic_config.yaml', help='Configuration yaml file')
    parser.add_argument('--itr', type=int, default=1, help='Number of experiments.')
    parser.add_argument('--random_seed', type=int, default=2024, help='Random seed.')
    parser.add_argument('--des', type=str, help="description of experiment.")
    parser.add_argument("--report_filepath", type=str, default=None, help="evaluation report output")
    parser.add_argument("--save_results", type=_str2bool, default=False, help="whether to save results")
    parser.add_argument("--save_plots", type=_str2bool, default=False, help="whether to save plots")
    args, unknown = parser.parse_known_args()
    unknown = parsing_syntax(unknown)

    # Merge the YAML configuration (plus command-line overrides) into the argparse namespace.
    config = load_config(args.config_filename)
    config = ConfigDict(config)
    config = update_config(config, unknown)
    for attr, value in config.items():
        setattr(args, attr, value)

    # random seed
    fix_seed(args.random_seed)

    args.GPU.use_gpu = True if torch.cuda.is_available() and args.GPU.use_gpu else False

    if args.GPU.use_gpu and not args.GPU.use_multi_gpu:
        # Translate the configured GPU id into its index within CUDA_VISIBLE_DEVICES.
        try:
            args.GPU.gpu = device_map[str(args.GPU.gpu)]
        except KeyError:
            raise KeyError("This GPU isn't available.")

    if args.GPU.use_gpu and args.GPU.use_multi_gpu:
        args.GPU.devices = args.GPU.devices.replace(' ', '')
        device_ids = args.GPU.devices.split(',')
        args.GPU.device_ids = [int(id_) for id_ in device_ids]
        args.GPU.gpu = args.GPU.device_ids[0]

    rmse_list, mae_list, mape_list = [], [], []
    for exp_idx in range(args.itr):
        args.exp_idx = exp_idx
        print('\nNo%d experiment ~~~' % exp_idx)

        exp = Evaluation_Air_Pollution(args)
        exp.vali()

        # Evaluate on the test set
        mae, mape, rmse, preds, truths = exp.test()
        mae_list.append(mae)
        mape_list.append(mape)
        rmse_list.append(rmse)

    mae_list = np.array(mae_list)  # num_exp x num_seq
    mape_list = np.array(mape_list)
    rmse_list = np.array(rmse_list)

    # Evaluation windows given as (start, end) step indices.
    # Original note: seq_len * 3 hours (one point every 3 hours).
    seq_len = [(0, 8), (8, 16), (16, 24)]
    output_text = ''
    output_text += '--------- Air-DualODE Final Results ------------\n'
    for i, (start, end) in enumerate(seq_len):
        output_text += 'Evaluation seq {}h-{}h:\n'.format(start, end)
        output_text += 'MAE | mean: {:.4f} std: {:.4f}\n'.format(get_mean_std(mae_list[:, i])[0],
                                                               get_mean_std(mae_list[:, i])[1])
        output_text += 'MAPE | mean: {:.4f} std: {:.4f}\n'.format(get_mean_std(mape_list[:, i])[0],
                                                                get_mean_std(mape_list[:, i])[1])
        output_text += 'RMSE | mean: {:.4f} std: {:.4f}\n\n'.format(get_mean_std(rmse_list[:, i])[0],
                                                                  get_mean_std(rmse_list[:, i])[1])

    # Write the output text to a file. Create the directory first so a missing
    # logs/ folder cannot discard the results of a finished evaluation.
    os.makedirs('logs', exist_ok=True)
    with open('logs/air-dualode_results.txt', 'a') as file:
        file.write(output_text)
class GNN_Knowledge_Fusion(nn.Module):
    """
    Fuse two per-node hidden states with a stack of graph convolutions.

    The two inputs are concatenated along the feature axis, passed through the
    first convolution, then through the intermediate convolutions (each with a
    skip connection), and finally projected to ``output_dim`` by the last one.

    :param num_nodes: number of graph nodes N
    :param phy_dim: per-node feature size of the first input
    :param unk_dim: per-node feature size of the second input
    :param output_dim: per-node feature size of the fused output
    :param edge_index: graph connectivity, 2 x M
    :param edge_attr: edge features, M x D; D > 1 is only accepted with 'GAT'
    :param hid_dim: width of the hidden convolutions
    :param gnn_type: one of 'GCN', 'Cheb', 'GAT'
    :param num_layers: requested number of convolutions (at least two are built)
    """

    def __init__(self, num_nodes, phy_dim, unk_dim, output_dim, edge_index, edge_attr,
                 hid_dim=64, gnn_type='GCN', num_layers=3):
        super().__init__()
        self.num_nodes = num_nodes
        self.phy_dim, self.unk_dim = phy_dim, unk_dim
        self.concat_dim = phy_dim + unk_dim
        self.output_dim = output_dim
        self.edge_index = edge_index  # 2 x M
        self.edge_attr = edge_attr
        self.gnn_type = gnn_type
        # Multi-feature edge attributes are only supported by the GAT variant.
        if edge_attr.shape[1] > 1:
            assert gnn_type == 'GAT'
        self.hid_dim = hid_dim
        self.num_layers = num_layers
        self.residual = nn.Identity()

        self.fusion = self.get_fusion()

    def forward(self, phy_hidden, unk_hidden):
        """
        :param phy_hidden: T x B x N*phy_dim
        :param unk_hidden: T x B x N*unk_dim (same T and B as ``phy_hidden``)
        :return: T x B x N*output_dim
        """
        T, B, _ = phy_hidden.shape
        flat = T * B
        # Fold time and batch together so every (t, b) slice is one graph signal.
        stacked = torch.cat(
            (phy_hidden.reshape(flat, self.num_nodes, self.phy_dim),
             unk_hidden.reshape(flat, self.num_nodes, self.unk_dim)),
            dim=2,
        )

        first, *middle, last = self.fusion
        hidden = F.relu(first(stacked, self.edge_index, self.edge_attr))

        for layer in middle:
            skip = self.residual(hidden)
            hidden = F.relu(layer(hidden, self.edge_index, self.edge_attr)) + skip

        fused = last(hidden, self.edge_index, self.edge_attr)
        return fused.reshape(T, B, self.num_nodes * self.output_dim)

    def get_fusion(self):
        """
        Build the convolution stack for ``self.gnn_type``.

        :return: nn.ModuleList [concat_dim -> hid_dim, (hid_dim -> hid_dim) * k, hid_dim -> output_dim]
                 with k = max(num_layers - 2, 0)
        :raises NotImplementedError: for an unknown ``gnn_type``
        """
        kind = self.gnn_type
        if kind == "GCN":
            def build(c_in, c_out):
                return GCNConv(in_channels=c_in, out_channels=c_out,
                               cached=True)
        elif kind == "Cheb":
            def build(c_in, c_out):
                return ChebConv(in_channels=c_in, out_channels=c_out,
                                K=3)
        elif kind == "GAT":
            def build(c_in, c_out):
                return GATConv(in_channels=c_in, out_channels=c_out,
                               heads=4, concat=False, edge_dim=self.edge_attr.shape[1])
        else:
            raise NotImplementedError

        n_hidden = max(self.num_layers - 2, 0)
        widths = [self.concat_dim] + [self.hid_dim] * (n_hidden + 1) + [self.output_dim]

        # Layers are created front to back, in the same order as the stack runs.
        fusion = nn.ModuleList()
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            fusion.append(build(c_in, c_out))
        return fusion
class Encoder_phy_z(nn.Module):
    """
    GRU encoder applied independently to every node's input sequence.

    :param input_dim: per-node feature size D of the input
    :param latent_dim: GRU hidden size (stored as ``phy_latent_dim``)
    :param num_layers: number of stacked GRU layers
    :param num_nodes: number of graph nodes N
    """

    def __init__(self, input_dim, latent_dim, num_layers, num_nodes):
        super().__init__()

        self.input_dim = input_dim
        self.phy_latent_dim = latent_dim
        self.num_layers = num_layers
        self.num_nodes = num_nodes

        self.gru_rnn = GRU(input_dim, latent_dim, num_layers=self.num_layers)

    def forward(self, X_p):
        """
        :param X_p: T x B x N x D
        :return: Z: T x B x N*phy_latent_dim
        """
        steps, batch = X_p.shape[:2]
        # Treat every (batch, node) pair as its own sequence: T x B*N x D.
        per_node_seq = X_p.reshape(steps, batch * self.num_nodes, self.input_dim)

        hidden_seq = self.gru_rnn(per_node_seq)[0]  # T x B*N x phy_latent_dim

        return hidden_seq.reshape(steps, batch, self.num_nodes * self.phy_latent_dim)
def parsing_syntax(unknown):
    """
    Turn leftover command-line tokens into a ``{key: value}`` dict.

    Accepted forms are ``--key value``, ``--key=value`` and a bare ``--key``
    (stored as ``None``). A positional token that does not directly follow a
    ``--key`` is ignored.

    :param unknown: iterable of raw argument strings (e.g. the second item
                    returned by ``ArgumentParser.parse_known_args``)
    :return: dict mapping option names (without the leading ``--``) to their
             string value, or ``None`` when no value was given
    """
    unknown_dict = {}
    key = None
    for arg in unknown:
        if arg.startswith('--'):
            # Drop exactly the two-dash prefix. ``str.lstrip('--')`` would strip
            # *every* leading '-' because it treats its argument as a char set.
            key, has_value, value = arg[2:].partition('=')
            if has_value:
                # ``--key=value``: the option is complete, so the next
                # positional token must not be consumed as its value.
                unknown_dict[key] = value
                key = None
            else:
                unknown_dict[key] = None
        elif key:
            unknown_dict[key] = arg
            key = None
    return unknown_dict
def update_config(config, unknown_args):
    """
    Apply command-line overrides to a nested configuration, in place.

    Each key of ``unknown_args`` is a '-'-separated path into ``config``
    (``"model-lr"`` addresses ``config["model"]["lr"]``). The string value is
    parsed as a Python literal (number, bool, None, list, dict, quoted string);
    anything that is not a literal is stored unchanged as a string.

    A path that ends on a nested mapping instead of a leaf leaves the
    configuration untouched.

    :param config: nested mapping (``ConfigDict`` or plain ``dict``)
    :param unknown_args: dict of path -> raw value, as built by ``parsing_syntax``
    :return: the same ``config`` object, updated
    :raises AssertionError: if a path component does not exist
    """
    # Function-scope import so this block does not depend on the module header.
    import ast

    for key, value in unknown_args.items():
        cur = config
        for node in key.split('-'):
            if node not in cur:
                # Raised explicitly (not via ``assert``) so the check is not
                # stripped when Python runs with -O.
                raise AssertionError("path not exist")
            if isinstance(cur[node], dict):
                # ConfigDict subclasses dict, so this also descends into it.
                cur = cur[node]
                continue
            if not isinstance(value, str):
                # e.g. ``None`` for a flag given without a value.
                cur[node] = value
                continue
            try:
                # SECURITY/CORRECTNESS: this used to be ``eval(value)``, which
                # executes arbitrary expressions from the command line, maps
                # names such as "sum" or "max" to builtins, and crashes with
                # SyntaxError on path-like values ("./checkpoints").
                cur[node] = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError):
                cur[node] = value
    return config
def exchange_df_column(df, col1, col2):
    """
    Return a new DataFrame in which ``col1`` and ``col2`` have swapped positions.

    Every column keeps its own name and data; only the column order changes.
    The input frame is left untouched. The previous implementation swapped the
    data of the two columns in the caller's frame before renaming, which
    corrupted the input, and it broke when a column was already named 'temp'.

    :param df: source DataFrame
    :param col1: label of the first column
    :param col2: label of the second column
    :return new_df
    :raises AssertionError: if either label is not a column of ``df``
    """
    if col1 not in df.columns or col2 not in df.columns:
        # Same exception type as the original ``assert``, but not stripped by -O.
        raise AssertionError("both columns must exist in df")
    order = list(df.columns)
    i, j = order.index(col1), order.index(col2)
    order[i], order[j] = order[j], order[i]
    return df[order]
class Graph:
    """
    Station graph derived from a sensor table and elevation data.

    Two stations are connected when they lie within ``dist_thres`` km of each
    other and the terrain sampled along the connecting segment does not rise
    more than ``alti_thres`` above the higher of the segment's two end samples.
    """

    def __init__(self, sensor_path, srtm_path, dist_thres=300, alti_thres=1200, middle=15):
        """
        :param sensor_path: CSV file read with pandas; must provide 'latitude'
                            and 'longitude' columns
        :param srtm_path: forwarded to ``get_elevation`` (presumably the
                          directory holding SRTM tiles; verify against geo_utils)
        :param dist_thres: km
        :param alti_thres: m
        :param middle: number of points sampled along every candidate edge

        Attributes built here:
        adj_mx: N x N
        edge_index: M x 2
        edge_attr: M x D
        node_attr: N x D
        """
        self.dist_thres = dist_thres
        self.alti_thres = alti_thres
        self.middle = middle

        self.sensor = pd.read_csv(sensor_path)
        self.nodes_num = self.sensor.shape[0]
        self.srtm_path = srtm_path

        # Order matters: altitude must exist before node attributes are built.
        self.query_altitude()
        self.adj_mx = self._gen_edges()
        self.edge_index, self.edge_attr, self.node_attr = self._gen_attr()

    def query_altitude(self):
        """Add an 'altitude' column to ``self.sensor`` using the vectorized ``get_elevation``."""
        lat = self.sensor['latitude'].values
        lon = self.sensor['longitude'].values
        alt = get_elevation(lat, lon, self.srtm_path)
        self.sensor['altitude'] = alt

    def _gen_edges(self):
        """
        Build the adjacency matrix from the distance and altitude criteria.

        :return: N x N uint8 matrix, 1 where both criteria hold, 0 elsewhere
                 (the diagonal is always 0)
        """
        points = self.sensor[['latitude', 'longitude']].values  # N x 2

        # Pairwise great-circle distances between all stations.
        dist_km = np.zeros((self.nodes_num, self.nodes_num))
        for i, A in enumerate(points):
            for j, B in enumerate(points):
                dist_km[i][j] = haversine(A, B, unit=Unit.KILOMETERS)
        dist_adj = np.zeros((self.nodes_num, self.nodes_num), dtype=np.uint8)
        dist_adj[dist_km <= self.dist_thres] = 1
        np.fill_diagonal(dist_adj, 0)  # no self loops

        alt_adj = np.zeros_like(dist_adj)
        edge_index, _ = dense_to_sparse(torch.tensor(dist_adj))
        edges = edge_index.t().tolist()
        # Sort each pair so (i, j) and (j, i) collapse to one undirected edge.
        edges = [sorted(e) for e in edges]
        unique_edges = list(map(list, set(map(tuple, edges))))
        unique_edge_index = np.array(unique_edges).T  # 2 x M

        # NOTE(review): with no distance-qualified pair the array above is 1-D
        # and ``shape[1]`` would raise IndexError; confirm whether that can occur.
        M = unique_edge_index.shape[1]
        middle_points = np.zeros((M, self.middle, 2))  # (lat, lon) samples per edge
        for i in range(M):
            src = self.sensor[['latitude', 'longitude']].iloc[unique_edge_index[0][i]]
            dest = self.sensor[['latitude', 'longitude']].iloc[unique_edge_index[1][i]]
            middle = interpolate_points(src, dest, num_points=self.middle)
            middle_points[i] = middle
        middle_points_alt = get_elevation(middle_points[:, :, 0], middle_points[:, :, 1], self.srtm_path)
        # Higher of the first and last sample on each edge (presumably the two
        # stations themselves; depends on interpolate_points).
        max_alt = np.max(middle_points_alt[:, [0, -1]], axis=1)[:, None]  # M x 1
        # Largest rise of any sample above that reference, per edge.
        alt_gap = np.max(middle_points_alt - max_alt, axis=1)
        for i in range(M):
            src_index = unique_edge_index[0][i]
            dest_index = unique_edge_index[1][i]
            if alt_gap[i] <= self.alti_thres:
                alt_adj[src_index][dest_index] = 1
                alt_adj[dest_index][src_index] = 1
        # An edge survives only if it satisfies both criteria.
        adj = dist_adj * alt_adj

        return adj

    def _gen_attr(self):
        """
        Derive edge and node attributes from ``self.adj_mx``.

        :return: (edge_index, edge_attr, node_attr) where
                 edge_index is M x 2 (src, dest) for every directed edge,
                 edge_attr is M x 3 with columns [1 / dist_km, dist_km, direction],
                 node_attr is N x 1 holding the station altitude
        """
        edge_index, _ = dense_to_sparse(torch.tensor(self.adj_mx))
        edge_index = edge_index.numpy()
        node_attr = self.sensor['altitude'].values
        edge_attr = []
        M = edge_index.shape[1]

        for i in range(M):
            src_index = edge_index[0][i]
            dest_index = edge_index[1][i]
            src = self.sensor[['latitude', 'longitude']].iloc[src_index]
            dest = self.sensor[['latitude', 'longitude']].iloc[dest_index]
            dist_km = haversine(src, dest, unit=Unit.KILOMETERS)
            # Alternative weighting that is currently disabled:
            # diff_dist = np.exp(-dist_km) / 0.001
            # NOTE(review): divides by zero if two connected stations share
            # identical coordinates.
            diff_dist = 1 / dist_km

            # The latitude/longitude differences are fed to metpy's
            # wind_direction as (u, v) components; the m/s units are attached
            # only because the call receives unit-tagged quantities.
            v, u = src['latitude'] - dest['latitude'], src['longitude'] - dest['longitude']
            u = u * units.meter / units.second
            v = v * units.meter / units.second
            direction = mpcalc.wind_direction(u, v)._magnitude

            edge_attr.append([diff_dist, dist_km, direction])

        edge_attr = np.array(edge_attr)
        edge_index = edge_index.T  # 2 x M -> M x 2
        node_attr = node_attr[:, None]  # N -> N x 1

        return edge_index, edge_attr, node_attr

    def save_npz(self, save_path):
        """Write adj_mx, node_attr, edge_index and edge_attr to ``<save_path>/graph_data.npz``."""
        save_path = os.path.join(save_path, 'graph_data.npz')
        np.savez(save_path, adj_mx=self.adj_mx, node_attr=self.node_attr,
                 edge_index=self.edge_index, edge_attr=self.edge_attr)
graph = Graph("../dataset/{}/station.csv".format(dataset_name), "../dataset/srtm", 121 | dist_thres=300, alti_thres=1200, middle=15) 122 | graph.save_npz('../dataset/{}'.format(dataset_name)) 123 | print(graph.adj_mx.shape) 124 | print(graph.edge_index.shape) 125 | print(graph.edge_attr.shape) 126 | print(graph.node_attr.shape) -------------------------------------------------------------------------------- /dataset/KnowAir/station.csv: -------------------------------------------------------------------------------- 1 | station,longitude,latitude 2 | Beijing,116.39824999999998,40.045975000000006 3 | Tianjin,117.32216666666666,39.07768333333333 4 | Shijiazhuang,114.49317142857144,38.033628571428565 5 | Tangshan,118.18291666666666,39.64495333333334 6 | Qinhuangdao,119.606875,39.93625 7 | Handan,114.51387499999998,36.60786 8 | Baoding,115.4852,38.876266666666666 9 | Zhangjiakou,114.90085,40.80275 10 | Chengde,117.92774,40.96416 11 | Langfang,116.7151,39.52605 12 | Cangzhou,116.87156666666668,38.31576666666667 13 | Hengshui,115.6761,37.744800000000005 14 | Xingtai,114.506675,37.0771 15 | Taiyuan,112.509,37.8489 16 | Huhehaote,111.66738571428571,40.80768571428571 17 | Dalian,121.62808148888888,38.95904567888889 18 | Shanghai,121.45294433333332,31.210789 19 | Nanjing,118.7745,32.0488 20 | Suzhou,119.52118,32.002930000000006 21 | Nantong,120.8786,31.99044 22 | Lianyungang,119.176,34.5885 23 | Xuzhou,117.23046000000002,34.260400000000004 24 | Yangzhou,119.3965,32.39305 25 | Wuxi,120.28957142857143,31.570185714285717 26 | Changzhou,119.94533333333334,31.7911 27 | Zhenjiang,119.53366666666666,32.1784 28 | Taizhou,120.62981666666668,30.54818333333333 29 | Huaian,119.03584,33.58214 30 | Yancheng,120.16633333333334,33.39703333333333 31 | Suqian,118.3072,33.9517 32 | Hangzhou,120.0819090909091,30.21117272727272 33 | Ningbo,121.62942857142858,29.8667 34 | Shaoxing,120.605,29.9919 35 | Huzhou,120.07,30.8244 36 | Jiaxing,120.735,30.7712 37 | Jinhua,119.6665,29.08995 38 | 
Quzhou,118.86566666666668,28.960433333333334 39 | Lishui,119.90766666666666,28.444366666666667 40 | Hefei,117.2478,31.848240000000004 41 | Nanchang,115.873075,28.702825 42 | Jinan,117.00885,36.6576 43 | Qingdao,120.36067142857142,36.14881428571428 44 | Zhengzhou,113.66537142857143,34.760628571428576 45 | Wuhan,114.2870125,30.5767625 46 | Zhangsha,112.98363,28.205220000000004 47 | Zhongqing,106.50192307692306,29.617815384615387 48 | Chengdou,104.002,30.7155 49 | Xian,108.99450909090908,34.30269090909091 50 | Lanzhou,103.86275,36.0405 51 | Yinchuan,106.1066,38.53423333333333 52 | Baotou,109.90826666666668,40.627783333333326 53 | Eerduosi,109.90196,39.72498 54 | Huludao,120.864825,40.7384 55 | Laiwu,117.6789,36.2289 56 | Linyi,118.3286,35.04769999999999 57 | Liaocheng,115.98415,36.4584 58 | Binzhou,117.9952,37.37833333333333 59 | Zibo,118.02022000000002,36.75354 60 | Zaozhuang,117.51355,34.806025 61 | Yantai,121.36026666666667,37.5074 62 | Weifang,119.15213333333332,36.72526666666667 63 | Jining,116.59023333333332,35.41543333333333 64 | Taian,117.11326666666666,36.18806666666666 65 | Rizhao,119.49195,35.4206 66 | Dongying,118.5019,37.4658 67 | Heze,115.45083066666666,35.25212966666667 68 | Datong,113.31258,40.0977 69 | Zhangzhi,113.0949861,36.18655 70 | Linfen,111.518825,36.07985 71 | Yangquan,113.56226666666667,37.8607 72 | Jinzhou,121.1185,41.05882 73 | Wuhu,118.37395,31.36785 74 | Maanshan,118.52462,31.69678 75 | Jiujiang,115.98975714285714,29.673300000000005 76 | Luoyang,112.41968333333334,34.65716666666667 77 | Anyang,114.37566666666665,36.086 78 | Kaifeng,114.3336,34.79076666666667 79 | Jiaozuo,113.2242,35.21613333333333 80 | Pingdingshan,113.26013333333331,33.73246666666667 81 | Sanmenxia,111.1754,34.7856 82 | Yichang,111.31052,30.70758 83 | Jingzhou,112.2502,30.32483333333333 84 | Yueyang,113.15066666666664,29.403616666666665 85 | Changde,111.70665,29.05795 86 | Zhangjiajie,110.50303333333332,29.20593333333333 87 | Mianyang,104.73285,31.47535 88 | 
Yibin,104.62036666666668,28.77513833333333 89 | Luzhou,105.43465,28.9 90 | Zigong,104.76195,29.35195 91 | Deyang,104.39235,31.1204 92 | Nanchong,106.08621666666666,30.803066666666663 93 | Xianyang,108.6997,34.3399 94 | Tongchuan,108.98766666666666,34.949533333333335 95 | Yanan,109.48803333333336,36.5967 96 | Baoji,107.18958750000002,34.354125 97 | Weinan,109.46503333333334,34.4985 98 | Shizuishan,106.547775,39.047425 99 | Zhangqiu,117.541,36.71 100 | Jimo,120.47,36.39 101 | Jiaonan,120.005,35.878 102 | Jiaozhou,120.014,36.253 103 | Laixi,120.515,36.885 104 | Pingdu,119.952,36.792 105 | Penglai,120.76,37.817 106 | Zhaoyuan,120.399,37.374 107 | Laizhou,119.95,37.1785 108 | Rushan,121.531,36.913 109 | Wujiang,120.643,31.165666666666667 110 | Taicang,121.14,31.422 111 | Jurong,119.146,31.955 112 | Jiangyin,120.269,31.916333333333338 113 | Yixing,119.818,31.354 114 | Jintan,119.579,31.744 115 | Liyang,119.44899999999998,31.3585 116 | Linan,119.7183,30.2366 117 | Fuyang,117.2097,31.93403333333333 118 | Yiwu,120.0655,29.3195 119 | Xinyang,114.061525,32.13015 120 | Zhoukou,114.65715,33.621500000000005 121 | Jincheng,112.84638333333332,35.50058333333333 122 | Shuozhou,112.4312,39.34474 123 | Jinzhong,112.732,37.707233333333335 124 | Yuncheng,111.007025,35.055325 125 | Xinzhou,112.72473333333336,38.45443333333333 126 | Lu:liang,111.1406,37.5211 127 | Wuhai,106.80796666666669,39.67533333333333 128 | Bayannaoer,107.50735,40.8384 129 | Wulanchabu,113.1306,40.9923 130 | Alashanmeng,105.69966666666666,38.84296666666666 131 | Chaoyang,120.42083333333336,41.59093333333334 132 | Bangbu,117.35681666666667,32.92925 133 | Huainan,116.8364,32.66645 134 | Huaibei,116.7973,33.940266666666666 135 | Tongling,117.80895,30.929616666666664 136 | Anqing,117.02473333333334,30.546900000000004 137 | Huangshan,118.25546666666666,29.903033333333337 138 | Chuzhou,118.31653333333334,32.3 139 | Liuan,116.514725,31.76245 140 | Chizhou,117.48355,30.6578 141 | Xuancheng,118.73806666666668,30.954 142 | 
Jingdezhen,117.2242,29.30422 143 | Yingtan,117.01344,28.20856 144 | Shangrao,117.9583,28.448825 145 | Hebi,114.25113333333331,35.794533333333334 146 | Xinxiang,113.86766666666666,35.294999999999995 147 | Puyang,115.057,35.76573333333334 148 | Xuchang,113.81185,34.0035 149 | Luohe,114.0268,33.570325000000004 150 | Nanyang,112.5317,32.990125000000006 151 | Shangqiu,115.6547,34.421533333333336 152 | Zhumadian,114.00903333333332,32.98930000000001 153 | Huangshi,115.015125,30.213050000000003 154 | Shiyan,110.8579,32.563675 155 | Xiangfan,112.171825,32.07015 156 | Ezhou,114.8666,30.39303333333333 157 | Jingmen,112.20646666666669,31.025366666666667 158 | Xiaogan,113.92865,30.918 159 | Huanggang,114.9028,30.4742 160 | Xianning,114.313325,29.840425 161 | Suizhou,113.37996666666668,31.71636666666667 162 | Enshizhou,109.4864,30.2904 163 | Yiyang,112.34146666666668,28.58183333333333 164 | Xiangxizhou,109.71415,28.2922 165 | Guangyuan,105.8491,32.437375 166 | Suining,105.69723333333332,30.57433333333333 167 | Neijiang,105.052675,29.588475 168 | Leshan,103.7602,29.5738 169 | Meishan,103.8701,30.069 170 | Guangan,106.63206666666667,30.459666666666667 171 | Dazhou,107.50338,31.22138 172 | Yaan,103.0109,29.9834 173 | Ziyang,104.6405,30.133 174 | Hanzhong,107.01043333333332,33.07223333333334 175 | Ankang,109.02213333333331,32.68973333333333 176 | Shangluo,109.9154,33.8715 177 | Baiyin,104.1731,36.54695 178 | Tianshui,105.79475,34.573350000000005 179 | Pingliang,106.68655,35.53775 180 | Qingyang,107.65336666666668,35.65306666666667 181 | Dingxi,104.61985,35.58445 182 | Linxiazhou,103.21015,35.6018 183 | Wuzhong,106.19925,37.97835 184 | Zhongwei,105.18855,37.258700000000005 185 | Guyuan,106.24953333333332,36.055 186 | -------------------------------------------------------------------------------- /Evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | from exp.exp_basic import Exp_Basic 2 | import torch 3 | import torch.nn as nn 
class Evaluation_Air_Pollution(Exp_Basic):
    """
    Evaluate a trained model checkpoint on the validation and test splits.

    Metrics are computed on inverse-transformed values. Every evaluation run
    appends one row of rounded metrics to ``self.result``.
    """

    def __init__(self, args):
        """
        :param args: experiment configuration; graph arrays and a logger are
                     attached to it before the base class builds the model
        """
        adj_mx, edge_index, edge_attr, node_attr = load_graph_data(args.data.root_path)
        args.adj_mx = adj_mx  # N x N
        args.edge_index = edge_index  # 2 x M
        args.edge_attr = edge_attr  # M x D
        args.node_attr = node_attr  # N x D

        self._logger = get_logger(None, args.model_name, 'info.log',
                                  level=args.log_level, to_stdout=args.to_stdout)
        args.logger = self._logger

        # The embedding channels are appended to the raw input features.
        if args.data.embed:
            args.model.input_dim = int(args.model.input_dim) + int(args.model.embed_dim)

        super(Evaluation_Air_Pollution, self).__init__(args)

        self.num_nodes = adj_mx.shape[0]
        self.input_var = int(self.args.model.input_dim)
        self.input_dim = int(self.args.model.X_dim)
        self.seq_len = int(self.args.model.seq_len)
        self.horizon = int(self.args.model.horizon)
        self.output_dim = int(self.args.model.X_dim)

        self.report_filepath = self.args.report_filepath
        self.result = []
        self.result.append([self.model.setting])
        self.result.append([self.model_parameters])

    def _build_model(self):
        """
        Instantiate the configured model.

        The scaler statistics of the validation dataset are stored on
        ``args.data`` first, so the model constructor can read them.

        :return: the model, wrapped in ``nn.DataParallel`` for multi-GPU runs
        """
        dataset, _ = self._get_data('val')
        self.args.data.mean_ = dataset.scaler.mean_
        self.args.data.std_ = dataset.scaler.scale_
        model = self.model_dict[self.args.model_name].Model(self.args).float()
        self.model_parameters = count_parameters(model)
        if self.args.GPU.use_multi_gpu and self.args.GPU.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.GPU.device_ids)
        return model

    def _get_data(self, flag):
        """
        :param flag: split name passed to ``data_provider`` ('val' / 'test' are used here)
        :return: (dataset, data_loader)
        """
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _load_eval_checkpoint(self):
        """Load ``./checkpoints/<setting>/checkpoint.pth`` into ``self.model``."""
        print('loading model')
        checkpoint_path = os.path.join('./checkpoints/' + self.model.setting, 'checkpoint.pth')
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # restored on whatever device this evaluation runs on.
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state_dict)

    def _collect_predictions(self, loader):
        """
        Run the model over ``loader`` without gradients.

        :return: (truths, preds), both concatenated over batches as B x T x N
        """
        truths = []
        preds = []

        self.model.eval()
        with torch.no_grad():
            for x, gt in loader:
                x, gt, y_embed = self._prepare_data(x, gt)
                # The second model output is not needed for evaluation.
                pred, _ = self.model(x, y_embed)

                truths.append(gt.cpu().permute(1, 0, 2))  # B x T x N
                preds.append(pred.cpu().permute(1, 0, 2))

        return torch.cat(truths, dim=0), torch.cat(preds, dim=0)

    def test(self):
        """
        Evaluate on the test split.

        Metrics are reported for three consecutive windows of 8 steps and then
        for the whole horizon, so each returned list has four entries.

        :return: (all_mae, all_smape, all_rmse, preds, truths); the last two
                 are inverse-transformed numpy arrays
        """
        test_data, test_loader = self._get_data(flag='test')
        self.inverse_transform = test_data.inverse_transform
        self._load_eval_checkpoint()

        truths, preds = self._collect_predictions(test_loader)  # B x T x N

        all_mae = []
        all_smape = []
        all_rmse = []

        # The window loop below assumes exactly 24 forecast steps.
        assert self.horizon == 24
        for i in range(0, self.horizon, 8):
            pred = preds[:, i: i + 8]
            truth = truths[:, i: i + 8]
            mae, smape, rmse = self._compute_loss_eval(truth, pred)
            all_mae.append(mae)
            all_smape.append(smape)
            all_rmse.append(rmse)
            # NOTE: the value logged under the 'mape' label is the sMAPE.
            self._logger.info('Evaluation {}h-{}h: - mae - {:.4f} - rmse - {:.4f} - mape - {:.4f}'.format(
                i*3, (i+8)*3, mae, rmse, smape))

        # three days
        mae, smape, rmse = self._compute_loss_eval(truths, preds)
        all_mae.append(mae)
        all_smape.append(smape)
        all_rmse.append(rmse)
        self._logger.info('Evaluation all: - mae - {:.4f} - rmse - {:.4f} - mape - {:.4f}'.format(
            mae, rmse, smape))

        all_metrics = {'mae': all_mae, 'rmse': all_rmse, 'smape': all_smape}

        # Flatten to [mae, rmse, smape] per window, windows in order.
        test_res = list(np.array([v for k, v in all_metrics.items()]).T.flatten())
        self.result.append(list(map(lambda x: round(x, 4), test_res)))

        truths_scaled = self.inverse_transform(truths).numpy()
        preds_scaled = self.inverse_transform(preds).numpy()

        return all_mae, all_smape, all_rmse, preds_scaled, truths_scaled

    def vali(self):
        """Evaluate on the validation split, log the metrics and append them to ``self.result``."""
        vali_data, vali_loader = self._get_data(flag='val')
        self.inverse_transform = vali_data.inverse_transform
        self._load_eval_checkpoint()

        truths, preds = self._collect_predictions(vali_loader)

        mae, smape, rmse = self._compute_loss_eval(truths, preds)

        self._logger.info('Evaluation: - mae - {:.4f} - smape - {:.4f} - rmse - {:.4f}'
                          .format(mae, smape, rmse))
        val_res = [mae, rmse, smape]
        self.result.append(list(map(lambda x: round(x, 4), val_res)))

    def _prepare_data(self, x, y):
        """
        Convert a loader batch to the time-major layout the model consumes.

        :return: (x, y, y_embed); x and y are moved to ``self.device``
        """
        x, y = self._get_x_y(x, y)  # B x 24(72 hours) x N x D
        x, y, y_embed = self._get_x_y_in_correct_dims(x, y)  # 24 x B x N x D
        # NOTE(review): y_embed is returned without a device move; confirm the
        # model handles its placement.
        return x.to(self.device), y.to(self.device), y_embed  # 24 x B x 35 * 11

    def _get_x_y(self, x, y):
        """Cast to float and swap the first two axes of both tensors."""
        x = x.float()
        y = y.float()
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        Flatten the node and feature axes.

        With ``args.data.embed`` set, a station-index channel is appended to
        both tensors and the channels listed in ``embed`` are extracted from
        ``y`` as ``y_embed``; otherwise ``y_embed`` is None.

        :param x: seq_len x B x N x D
        :param y: horizon x B x N x D
        :return: x: seq_len x B x N*input_var,
                 y: horizon x B x N*output_dim,
                 y_embed: horizon x B x N*len(embed) or None
        """
        batch_size = x.size(1)
        if self.args.data.embed:
            station_x = torch.arange(0, self.num_nodes).unsqueeze(0).unsqueeze(0).unsqueeze(-1).repeat(self.seq_len, batch_size, 1, 1)
            station_y = torch.arange(0, self.num_nodes).unsqueeze(0).unsqueeze(0).unsqueeze(-1).repeat(self.horizon, batch_size, 1, 1)
            x = torch.cat([x, station_x], dim=-1)
            y = torch.cat([y, station_y], dim=-1)
            x = x.reshape(self.seq_len, batch_size, self.num_nodes * self.input_var)
            # Channel positions of the embedding inputs inside y (hard-coded).
            embed = [6, 7, 8, 9, 10, 11]
            y_embed = y[..., embed].reshape(self.horizon, batch_size, self.num_nodes*len(embed))
            y = y[..., :self.output_dim].reshape(self.horizon, batch_size,
                                                 self.num_nodes*self.output_dim)
        else:
            x = x[..., :self.input_var].reshape(self.seq_len, batch_size, self.num_nodes * self.input_var)
            y = y[..., :self.output_dim].reshape(self.horizon, batch_size,
                                                 self.num_nodes * self.output_dim)
            y_embed = None
        return x, y, y_embed

    def _compute_loss_eval(self, y_true, y_predicted):
        """Inverse-transform both tensors and return ``compute_all_metrics(pred, true)``."""
        y_true = self.inverse_transform(y_true)
        y_predicted = self.inverse_transform(y_predicted)
        return compute_all_metrics(y_predicted, y_true)
class FixedEmbedding(nn.Module):
    """
    Non-trainable sinusoidal lookup table.

    Row ``i`` holds ``sin(i * f_k)`` in the even columns and ``cos(i * f_k)``
    in the odd columns, with ``f_k = exp(-2k * ln(10000) / d_model)``.

    :param c_in: number of table rows (valid indices are 0 .. c_in - 1)
    :param d_model: embedding width; odd widths are supported
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term  # c_in x ceil(d_model / 2)

        w = torch.zeros(c_in, d_model).float()
        w[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns;
        # without the slice the assignment raises a shape error.
        w[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # The table is frozen here; the previous ``w.require_grad = False`` was
        # a misspelled attribute with no effect and has been dropped.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """
        :param x: integer index tensor of any shape
        :return: embeddings of shape ``x.shape + (d_model,)``, detached from autograd
        """
        return self.emb(x).detach()
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of per-step time features to ``d_model``.

    The expected number of input features is chosen by ``freq`` (for example
    4 for ``'h'`` and 5 for ``'t'``); an unknown ``freq`` raises ``KeyError``.
    ``embed_type`` is accepted for signature parity but is not used here.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super().__init__()

        # Input width per sampling-frequency code.
        n_time_features = dict(h=4, t=5, s=6, m=1, a=1, w=2, d=3, b=3)[freq]
        self.embed = nn.Linear(in_features=n_time_features, out_features=d_model, bias=False)

    def forward(self, x):
        """Project the trailing feature axis of ``x`` to ``d_model``."""
        projected = self.embed(x)
        return projected
class DataEmbedding_wo_pos_temp(nn.Module):
    """Value-only embedding: token convolution followed by dropout.

    The positional and temporal sub-modules are still constructed (so the
    module's registered children match the sibling ``DataEmbedding*`` classes)
    but ``forward`` does not use them, and ``x_mark`` is ignored.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            temporal = TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            temporal = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        self.temporal_embedding = temporal
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed the values of ``x`` and apply dropout; ``x_mark`` is unused."""
        embedded = self.value_embedding(x)
        return self.dropout(embedded)
class AirEmbedding(nn.Module):
    """
    Embed categorical variables.

    The last axis of the input is read as six integer-coded columns:
    0 hour-of-day, 1 day-of-week, 2 day-of-month, 3 month-of-year,
    4 holiday flag (passed through unchanged), 5 station id.
    """

    def __init__(self, h2d, d2w, d2m, m2y, station, num_nodes):
        super().__init__()
        # One lookup table per categorical column; the first argument is the
        # number of categories, the second the embedding width.
        self.embed_h2d = nn.Embedding(24, h2d)
        self.embed_d2w = nn.Embedding(7, d2w)
        self.embed_d2m = nn.Embedding(31, d2m)
        self.embed_m2y = nn.Embedding(12, m2y)
        self.embed_station = nn.Embedding(num_nodes, station)

    def forward(self, x):
        """Concatenate the embedded columns along the last axis.

        Output order: hour, weekday, day-of-month, month, holiday flag, station.
        """
        calendar_tables = (self.embed_h2d, self.embed_d2w, self.embed_d2m, self.embed_m2y)
        pieces = [table(x[..., col]) for col, table in enumerate(calendar_tables)]
        pieces.append(x[..., 4:5])  # holiday flag keeps its own single column
        pieces.append(self.embed_station(x[..., 5]))
        return torch.cat(pieces, dim=-1)
18 | :param adj_mx: for diff_adj_mx and adv_adj_mx 19 | :param edge_index: M x 2 20 | :param edge_attr: M x D {diff_dist, dist_km, direction} 21 | :param K_neighbour: 22 | :param num_nodes: 23 | :param num_layers: hidden layers in each ode func. 24 | :param activation: 25 | :param filter_type: 26 | """ 27 | super(ODEFunc, self).__init__() 28 | self.device = device 29 | 30 | self._activation = torch.tanh if activation == 'tanh' else torch.relu 31 | 32 | self.num_nodes = num_nodes 33 | self.gcn_hidden_dim = gcn_hidden_dim 34 | self.input_dim = input_dim 35 | self.num_layers = num_layers 36 | self.nfe = 0 # Number of function integration 37 | 38 | self._filter_type = filter_type 39 | 40 | self.adj_mx = adj_mx 41 | self.edge_index = torch.tensor(edge_index, dtype=torch.int64).to(self.device) 42 | self.edge_attr = edge_attr 43 | self.K_neighbour = K_neighbour 44 | 45 | self.diff_edge_attr = self.edge_attr[:, 0] 46 | self.adv_edge_attr = None 47 | 48 | self.estimate = estimate 49 | if not estimate: 50 | self.diff_coeff = 0.1 51 | self.beta = nn.Parameter(torch.zeros(self.num_nodes * self.input_dim)) 52 | else: 53 | self.diff_coeff, self.beta = None, None 54 | 55 | self.residual = nn.Identity() 56 | 57 | if self._filter_type == "diff": 58 | self.diff_cheb_conv = self.laplacian_operator() 59 | 60 | elif self._filter_type == "adv": 61 | self.adv_cheb_conv = self.laplacian_operator() 62 | 63 | elif self._filter_type == "diff_adv": 64 | self.gated_fusion = Gated_Fusion(self.num_nodes, self.input_dim) 65 | 66 | self.diff_cheb_conv = self.laplacian_operator() 67 | self.adv_cheb_conv = self.laplacian_operator() 68 | else: 69 | raise "Knowledge not registered" 70 | 71 | def create_adv_matrix(self, last_wind_vars, wind_mean, wind_std): 72 | """ 73 | 74 | :param last_wind_vars: last_wind_vars: B x N x 2 (wind_speed[Norm], wind_direction) 75 | :return: adv_edge_attr B x M x 1 based on adj_mx 76 | """ 77 | batch_size = last_wind_vars.shape[0] 78 | edge_src, edge_target = 
def get_ode_gradient_nn(self, t, Xt):
    """Evaluate the right-hand side of the explicit ODE at state ``Xt``.

    Depending on ``self._filter_type`` the gradient is the negated diffusion
    term scaled by ``self.diff_coeff`` ("diff"), the negated advection term
    ("adv"), or the gated fusion of both ("diff_adv"). ``self.beta * Xt`` is
    then added as a linear term.

    :param t: current integration time; not used by this function.
    :param Xt: current state, passed to the diffusion/advection networks.
    :return: gradient returned to the ODE solver.
    :raises ValueError: if ``self._filter_type`` is not one of the three
        supported values.
    """
    if self._filter_type == "diff":
        grad = - self.diff_coeff * self.ode_func_net_diff(Xt, self.diff_edge_attr)
    elif self._filter_type == "adv":
        grad = - self.ode_func_net_adv(Xt, self.adv_edge_attr)
    elif self._filter_type == "diff_adv":
        grad_diff = - self.diff_coeff * self.ode_func_net_diff(Xt, self.diff_edge_attr)
        grad_adv = - self.ode_func_net_adv(Xt, self.adv_edge_attr)
        grad = self.gated_fusion(grad_diff, grad_adv)
    else:
        # Raising a plain string is itself a TypeError ("exceptions must derive
        # from BaseException"), which hid the real problem from the caller.
        raise ValueError(f"Invalid Filter Type: {self._filter_type!r}")

    return grad + self.beta * Xt
def ode_func_net_adv(self, x, edge_attr):
    """Apply the advection Chebyshev stack to a batch of graph signals.

    The batch is unrolled into one large disconnected graph: node features
    are stacked to ``B*N`` rows, the shared ``self.edge_index`` is repeated
    once per sample with node ids shifted by ``i * num_nodes``, and
    ``edge_attr`` is flattened so every sample keeps its own edge weights.

    :param x: state with a leading batch axis, reshaped to ``B*N x input_dim``.
    :param edge_attr: per-sample edge weights, flattened to ``B*M``.
    :return: tensor of shape ``B x N*input_dim``.
    """
    n_samples = x.shape[0]
    n_nodes = self.num_nodes

    # Sample id of every stacked node row: 0,0,...,1,1,...
    node_owner = torch.repeat_interleave(torch.arange(0, n_samples), n_nodes).to(self.device)
    h = x.reshape(n_samples * n_nodes, -1)  # B*N x input_dim
    stacked_edges = torch.cat(
        [self.edge_index + sample * n_nodes for sample in range(n_samples)], dim=1
    )  # 2 x B*M
    flat_weights = edge_attr.flatten()  # B*M

    def conv(layer, signal):
        return layer(signal, stacked_edges, flat_weights, batch=node_owner, lambda_max=2)

    layers = self.adv_cheb_conv
    h = self._activation(conv(layers[0], h))

    # Hidden layers carry a skip connection around conv + activation.
    for hidden_layer in layers[1:-1]:
        skip = self.residual(h)
        h = self._activation(conv(hidden_layer, h)) + skip

    h = conv(layers[-1], h)

    h = h.reshape(n_samples, n_nodes, self.input_dim)
    return h.reshape((n_samples, n_nodes * self.input_dim))
def laplacian_operator(self):
    """Build the stack of Chebyshev graph convolutions used as the approximate Laplacian.

    The stack maps ``input_dim -> gcn_hidden_dim``, then applies
    ``num_layers - 2`` hidden-to-hidden layers (none when ``num_layers <= 2``),
    and finally maps ``gcn_hidden_dim -> input_dim``. Every layer uses
    ``K=self.K_neighbour``, symmetric normalization and a bias.

    :return: ``nn.ModuleList`` of ``ChebConv`` layers in application order.
    """
    hidden = self.gcn_hidden_dim
    width = self.input_dim

    # (in_channels, out_channels) for each layer, first to last.
    channel_plan = [(width, hidden)]
    channel_plan += [(hidden, hidden)] * (self.num_layers - 2)
    channel_plan.append((hidden, width))

    return nn.ModuleList([
        ChebConv(in_channels=c_in, out_channels=c_out,
                 K=self.K_neighbour, normalization='sym',
                 bias=True)
        for c_in, c_out in channel_plan
    ])
def _get_valid_indices(self):
    """Find every window start whose ``window_size`` rows contain no NaN.

    One station file is read, optionally thinned to every third row
    (``freq == '3h'``), and cut to the 70/10/20 train/val/test split selected
    by ``self.set_type``. As side effects this sets ``self.train_border``,
    ``self.time_idx`` and ``self.time_info``.

    NOTE(review): validity is judged from a single station file picked with an
    unseeded ``random.choice``; gaps in other stations are not checked here.
    Confirm this is intended.

    :return: ``(valid_indices, (border1, border2))`` where ``valid_indices``
        are start offsets relative to ``border1``.
    """
    random_file = random.choice(os.listdir(self.stations_dir))
    station_data = pd.read_csv(os.path.join(self.stations_dir, random_file),
                               usecols=self.columns)
    station_data['time'] = pd.to_datetime(station_data['time'])
    station_data.set_index('time', inplace=True)
    # sampling frequency
    if self.freq == '3h':
        station_data = station_data[::3]

    # 7:1:2 split
    total_rows = len(station_data)
    train_end, val_end = int(total_rows * 0.7), int(total_rows * 0.8)
    border1s = [0, train_end, val_end]
    border2s = [train_end, val_end, total_rows]
    self.train_border = (border1s[0], border2s[0])
    border1 = border1s[self.set_type]
    border2 = border2s[self.set_type]

    # Restrict to the selected split before searching for valid windows.
    station_data = station_data.iloc[border1: border2]
    self.time_idx = station_data.index  # time
    self.time_info = self.cal_time_info(self.time_idx)

    # Scan for NaNs once and use prefix sums: a window is valid exactly when
    # the count of NaN rows does not change across it. This replaces a fresh
    # isna() scan of every window.
    row_has_nan = station_data.isna().any(axis=1).to_numpy()
    nan_prefix = np.concatenate(([0], np.cumsum(row_has_nan)))
    n_windows = max(len(station_data) - self.window_size + 1, 0)
    valid_indices = [
        start for start in range(n_windows)
        if nan_prefix[start + self.window_size] == nan_prefix[start]
    ]
    return valid_indices, (border1, border2)
check_holiday(date): 80 | return 1 if calendar.is_holiday(date) or calendar.is_in_lieu(date) else 0 81 | 82 | time_info = pd.DataFrame({ 83 | 'time': time_idx, 84 | 'hour_of_day': time_idx.hour, # hour-day 85 | 'day_of_week': time_idx.dayofweek, # day-week 86 | 'day_of_month': time_idx.day - 1, # day-month 87 | 'month_of_year': time_idx.month - 1, # month-year 88 | }) 89 | time_info['is_holiday'] = [check_holiday(d.date()) for d in time_idx] 90 | 91 | time_info.set_index('time', inplace=True) 92 | return time_info 93 | 94 | def __read_data__(self): 95 | train_set = [] 96 | for csv in os.listdir(self.stations_dir): 97 | station_df = pd.read_csv(os.path.join(self.stations_dir, csv), 98 | usecols=self.columns) 99 | station_df = exchange_df_column(station_df, 'wind_direction', 'wind_speed') 100 | station_df['time'] = pd.to_datetime(station_df['time']) 101 | station_df = station_df.set_index('time') 102 | if self.freq == '3h': 103 | station_df = station_df.iloc[::3] 104 | train_set.append(station_df[self.normalized_col].iloc[self.train_border[0]: self.train_border[1]]) 105 | station_df = station_df.iloc[self.border[0]: self.border[1]] 106 | station = csv.split(".")[0] 107 | self.station_df_dict[station] = station_df 108 | 109 | train_set = pd.concat(train_set, axis=0) 110 | if self.scale: 111 | self.scaler.fit(train_set) 112 | for station, df in self.station_df_dict.items(): 113 | df[self.normalized_col] = self.scaler.transform(df[self.normalized_col]) 114 | if self.embed: 115 | self.station_df_dict[station] = pd.concat([df, self.time_info], axis=1) 116 | else: 117 | self.station_df_dict[station] = pd.DataFrame(df, columns=self.columns[1:], index=self.time_idx) 118 | 119 | def __len__(self): 120 | return len(self.valid_indices) 121 | 122 | def __getitem__(self, idx): 123 | seq_x = [] 124 | seq_y = [] 125 | x_start = self.valid_indices[idx] 126 | x_end = x_start + self.seq_len 127 | y_start = self.valid_indices[idx] + self.seq_len 128 | y_end = y_start + self.pred_len 
129 | 130 | for station in self.station_info['station']: 131 | df = self.station_df_dict[station] 132 | seq_x.append(df.iloc[x_start: x_end].values) 133 | seq_y.append(df.iloc[y_start: y_end].values) 134 | seq_x = np.stack(seq_x).transpose(1, 0, 2) 135 | seq_y = np.stack(seq_y).transpose(1, 0, 2) 136 | return seq_x, seq_y 137 | 138 | def inverse_transform(self, data): 139 | assert self.scale is True 140 | pm25_mean = self.scaler.mean_[0] 141 | pm25_std = self.scaler.scale_[0] 142 | return (data * pm25_std) + pm25_mean 143 | 144 | 145 | class Dataset_KnowAir(Dataset): 146 | def __init__(self, root_path, flag='train', seq_len=24, pred_len=24, 147 | freq='3h', scale=True, embed=0, 148 | normalized_col: Union[str, List[int]]='default'): 149 | if normalized_col == 'default': 150 | self.normalized_col = np.arange(0, 6) 151 | else: 152 | self.normalized_col = normalized_col 153 | 154 | self.seq_len = seq_len 155 | self.pred_len = pred_len 156 | self.window_size = seq_len + pred_len 157 | self.scale = scale 158 | self.embed = embed 159 | if scale: 160 | self.scaler = StandardScaler() 161 | else: 162 | self.scaler = None 163 | 164 | assert flag in ['train', 'test', 'val'] 165 | type_map = {'train': 0, 'val': 1, 'test': 2} 166 | self.set_type = type_map[flag] 167 | 168 | self.root_path = root_path 169 | self.station_info = pd.read_csv(os.path.join(self.root_path, "station.csv")) 170 | self.stations_npy = os.path.join(self.root_path, "KnowAir.npy") 171 | metero_var = ['100m_u_component_of_wind', '100m_v_component_of_wind', '2m_dewpoint_temperature', 172 | '2m_temperature', 'boundary_layer_height', 'k_index', 'relative_humidity+950', 173 | 'relative_humidity+975', 'specific_humidity+950', 'surface_pressure', 174 | 'temperature+925', 'temperature+950', 'total_precipitation', 'u_component_of_wind+950', 175 | 'v_component_of_wind+950', 'vertical_velocity+950', 'vorticity+950'] 176 | metero_use = ['2m_temperature', 'surface_pressure', 'relative_humidity+950', 177 | 
'100m_u_component_of_wind', '100m_v_component_of_wind'] 178 | self.metero_idx = [metero_var.index(var) for var in metero_use] 179 | self.time_idx = pd.date_range(start='2015-01-01', end='2018-12-31 21:00', freq='3H') 180 | 181 | self.__process_raw_data__() 182 | self.__read_data__() 183 | 184 | def __process_raw_data__(self): 185 | raw_data = np.load(self.stations_npy) 186 | self.pm25 = raw_data[:, :, -1:] 187 | self.feature = raw_data[:, :, :-1] 188 | self.feature = self.feature[:, :, self.metero_idx] 189 | u = self.feature[:, :, -2] * units.meter / units.second # m/s 190 | v = self.feature[:, :, -1] * units.meter / units.second # m/s 191 | speed = 3.6 * mpcalc.wind_speed(u, v)._magnitude # km/h 192 | direc = mpcalc.wind_direction(u, v)._magnitude 193 | self.feature[:, :, -2] = speed 194 | self.feature[:, :, -1] = direc 195 | 196 | self.raw_data = np.concatenate([self.pm25, self.feature], axis=-1) # T x N x D 197 | 198 | def __read_data__(self): 199 | # 2:1:1 200 | border1s = [0, int(len(self.raw_data) * 0.5), int(len(self.raw_data) * 0.75)] 201 | border2s = [int(len(self.raw_data) * 0.5), int(len(self.raw_data) * 0.75), len(self.raw_data)] 202 | self.train_border = (border1s[0], border2s[0]) 203 | border1 = border1s[self.set_type] 204 | border2 = border2s[self.set_type] 205 | self.data = self.raw_data[border1: border2] 206 | if self.embed: 207 | self.time_info = self.cal_time_info(self.time_idx[border1: border2]).values 208 | 209 | if self.scale: 210 | train_set = self.raw_data[self.train_border[0]: self.train_border[1], :, :] 211 | T, N, D = self.data.shape 212 | self.scaler.fit(train_set.reshape(-1, D)[:, self.normalized_col]) 213 | self.data = self.data.reshape(-1, D) 214 | self.data[:, self.normalized_col] = self.scaler.transform(self.data[:, self.normalized_col]) 215 | self.data = self.data.reshape(T, N, D) 216 | 217 | def cal_time_info(self, time_idx): 218 | def check_holiday(date): 219 | return 1 if calendar.is_holiday(date) or calendar.is_in_lieu(date) 
def __getitem__(self, idx):
    """Return the ``(history, future)`` pair of windows starting at ``idx``.

    ``history`` covers ``seq_len`` steps and ``future`` the following
    ``pred_len`` steps of ``self.data``. When ``self.embed`` is truthy the
    matching rows of ``self.time_info`` are copied to every node and appended
    along the last axis of both windows.
    """
    split = idx + self.seq_len
    stop = split + self.pred_len
    history = self.data[idx: split]
    future = self.data[split: stop]
    if not self.embed:
        return history, future

    n_nodes = history.shape[1]

    def with_time(values, lo, hi):
        # Broadcast the per-step time features across the node axis.
        stamps = np.expand_dims(self.time_info[lo: hi], axis=1).repeat(n_nodes, axis=1)
        return np.concatenate([values, stamps], axis=2)

    return with_time(history, idx, split), with_time(future, split, stop)
def __init__(self, args: Optional[ConfigDict] = None):
    """Read the configuration and build the enabled sub-networks.

    Sub-modules are created conditionally: the physics branch
    (``phy_odefunc``, optional ``coeff_estimator``, ``RNN_encoder_pred``) when
    ``phy_func.enable`` is set, the data-driven branch (``encoder``,
    ``unk_odefunc``) when ``unk_func.enable`` is set, the fusion module when
    both are set, and the decoder when ``decoder.enable`` is set.

    :param args: configuration object; its ``logger``, ``GPU``, ``data``,
        ``model``, ``adj_mx``, ``edge_index`` and ``edge_attr`` entries are read.
    :raises NotImplementedError: if the decoder is enabled while neither
        dynamics branch is enabled.
    """
    nn.Module.__init__(self)
    Air_Attrs.__init__(self, args)
    self._logger = args.logger
    self.setting = self.get_setting(args)

    # gpu device
    self.device = torch.device("cuda:{}".format(args.GPU.gpu))

    # embedding
    self.embed_h2d = self.args.model.embedding.hour2day
    self.embed_d2w = self.args.model.embedding.day2week
    self.embed_d2m = self.args.model.embedding.day2month
    self.embed_m2y = self.args.model.embedding.month2year
    self.embed_station = self.args.model.embedding.station
    self.embedding_dim = 0

    if args.data.embed:
        self.embedding_dim = self.embed_h2d + self.embed_d2w + self.embed_d2m + \
                             self.embed_m2y + self.embed_station + 1  # holiday
        self.embedding_air = AirEmbedding(self.embed_h2d, self.embed_d2w, self.embed_d2m,
                                          self.embed_m2y, self.embed_station, self.num_nodes)

    # phy_func
    self.is_phy = self.args.model.phy_func.enable
    self.knowledge = self.args.model.phy_func.knowledge
    self.phy_gnn_layers = self.args.model.phy_func.gnn_layers
    self.phy_gnn_hid_dim = self.args.model.phy_func.gnn_hid_dim  # GCN: N x D --> N x D'
    self.cheb_k = self.args.model.phy_func.cheb_k
    self.is_estimator = self.args.model.phy_func.coeff_estimator
    self.phy_rnn_layers = self.args.model.phy_func.rnn_layers
    self.phy_rnn_dim = self.args.model.phy_func.rnn_dim
    self.phy_input_dim = self.args.model.phy_func.input_dim + self.embedding_dim
    self.phy_latent_dim = self.args.model.phy_func.latent_dim
    self.phy_atol = float(self.args.model.phy_func.odeint_atol)
    self.phy_rtol = float(self.args.model.phy_func.odeint_rtol)
    self.phy_solver = DiffeqSolver(
        method=self.args.model.phy_func.ode_method,
        odeint_atol=self.phy_atol,
        odeint_rtol=self.phy_rtol,
        adjoint=self.args.model.phy_func.adjoint
    )

    # unk_func
    self.is_unk = self.args.model.unk_func.enable
    self.unk_rnn_layers = self.args.model.unk_func.rnn_layers
    self.unk_rnn_dim = self.args.model.unk_func.rnn_dim
    self.unk_input_dim = self.args.model.unk_func.input_dim + self.embedding_dim
    self.unk_latent_dim = self.args.model.unk_func.latent_dim  # Z^D_t: N x unk_latent_dim
    self.unk_n_heads = self.args.model.unk_func.n_heads  # in NODE ATT
    self.unk_d_f = self.args.model.unk_func.d_f
    self.unk_atol = float(self.args.model.unk_func.odeint_atol)
    self.unk_rtol = float(self.args.model.unk_func.odeint_rtol)
    self.unk_solver = DiffeqSolver(
        method=self.args.model.unk_func.ode_method,
        odeint_atol=self.unk_atol,
        odeint_rtol=self.unk_rtol,
        adjoint=self.args.model.unk_func.adjoint
    )

    # fusion
    self.fusion_latent_dim = self.args.model.fusion.latent_dim
    self.fusion_output_dim = self.args.model.fusion.output_dim
    self.fusion_gnn_layers = self.args.model.fusion.num_layers
    self.fusion_gnn_type = self.args.model.fusion.gnn_type

    # decoder
    self.is_decoder = self.args.model.decoder.enable

    # Adjacent Matrix
    self.adj_mx = args.adj_mx
    self.edge_index = torch.tensor(args.edge_index, dtype=torch.int32).to(self.device)
    self.edge_attr = torch.from_numpy(args.edge_attr).float().to(self.device)

    # wind mean and std (the last two entries of the scaler statistics)
    self.wind_mean = args.data.mean_[-2:]
    self.wind_std = args.data.std_[-2:]

    if self.is_phy:
        self.phy_odefunc = ODEFunc(self.phy_gnn_hid_dim, self.X_dim, self.adj_mx, self.edge_index, self.edge_attr,
                                   self.cheb_k, self.num_nodes, self.device, num_layers=self.phy_gnn_layers,
                                   filter_type=self.knowledge, estimate=self.is_estimator)
        self.coeff_estimator = None
        if self.is_estimator:
            self.coeff_estimator = Coeff_Estimator(input_dim=self.input_dim,
                                                   coeff_dim=self.X_dim, num_nodes=self.num_nodes,
                                                   rnn_dim=self.phy_rnn_dim, n_layers=self.phy_rnn_layers)

        self.RNN_encoder_pred = Encoder_phy_z(self.phy_input_dim, self.phy_latent_dim,
                                              self.phy_rnn_layers, self.num_nodes)

    if self.is_unk:
        self.encoder = Encoder_unk_z(self.unk_input_dim, self.unk_latent_dim, self.num_nodes,
                                     self.unk_rnn_dim, self.unk_rnn_layers)
        self.unk_odefunc = Unk_odefunc(self.unk_latent_dim, self.num_nodes, self.unk_n_heads,
                                       self.device, self.adj_mx, self.unk_d_f)

    # Knowledge_Fusion
    if self.is_phy and self.is_unk:
        self.gatef_fusion = GNN_Knowledge_Fusion(self.num_nodes, self.phy_latent_dim, self.unk_latent_dim,
                                                 self.fusion_output_dim, self.edge_index, self.edge_attr[:, :1],
                                                 hid_dim=self.fusion_latent_dim, gnn_type=self.fusion_gnn_type,
                                                 num_layers=self.fusion_gnn_layers)  # T x B x N*output_dim

    if self.is_decoder:
        # The decoder width must match whatever forward() feeds it: the fused
        # state when both branches run, otherwise the latent state of the one
        # enabled branch. The two single-branch cases were previously swapped.
        if self.is_phy and self.is_unk:
            ld = self.fusion_output_dim
        elif self.is_phy and not self.is_unk:
            ld = self.phy_latent_dim
        elif not self.is_phy and self.is_unk:
            ld = self.unk_latent_dim
        else:
            # `assert NotImplementedError` always passed (a class is truthy),
            # letting construction continue with latent_dim=None.
            raise NotImplementedError(
                "At least one of phy_func.enable / unk_func.enable must be set")

        self.decoder = Conv_Decoder(latent_dim=ld,
                                    output_dim=self.X_dim,
                                    num_nodes=self.num_nodes)
    else:
        # Without a decoder the fused state is returned directly as the prediction.
        assert self.fusion_output_dim == self.X_dim
self.X_dim)) # T x B x N*X_dim 143 | last_X = X[-1] # B x N*X_dim 144 | wind_vars = inputs[:, :, :, 4: 6] # T x B x N x 2 wind speed and wind direction 145 | last_wind_vars = wind_vars[-1] # B x N x 2 146 | 147 | if self.embedding_dim: 148 | x_embed = self.embedding_air(inputs[..., 6:].long()) 149 | inputs = torch.cat((inputs[..., :6], x_embed), -1) # after embedding 150 | 151 | # MOL on PDEs and solve 152 | if self.is_phy: 153 | alpha, beta = None, None 154 | if self.is_estimator: # estimate alpha and beta 155 | alpha, beta = self.coeff_estimator(inputs) 156 | phy_y, phy_fe = self.phy_part(last_X, last_wind_vars, alpha, beta) # T x B x N*X_dim 157 | if self.embedding_dim: 158 | y_embed = y_embed.reshape(seq_len, batch_size, self.num_nodes, -1) 159 | y_embed = self.embedding_air(y_embed.long().to(self.device)) 160 | phy_y = torch.cat([phy_y.unsqueeze(-1), y_embed], -1) 161 | phy_z = self.RNN_encoder_pred(phy_y) # T x B x N*phy_latent_dim 162 | else: 163 | phy_z, phy_fe = None, (0, 0) 164 | 165 | # Data-Driven Dynamics 166 | if self.is_unk: 167 | Z0 = self.encoder(inputs) # B x N*latent_dim 168 | unk_z, unk_fe = self.unk_part(Z0) # T x B x N*unk_latent_dim 169 | else: 170 | unk_z, unk_fe = None, (0, 0) 171 | 172 | # Dynamics Fusion 173 | self.loss_CL = None 174 | if self.is_phy and self.is_unk: 175 | assert self.unk_latent_dim == self.phy_latent_dim 176 | self.loss_CL = temporal_alignment(phy_z, unk_z, self.num_nodes, self.unk_latent_dim) 177 | Z = self.gatef_fusion(phy_z, unk_z) # T x B x N x latent_dim 178 | if self.is_decoder: 179 | self.pred_y = self.decoder(Z) 180 | else: 181 | self.pred_y = Z 182 | elif self.is_unk and not self.is_phy: 183 | self.pred_y = self.decoder(unk_z) 184 | elif self.is_phy and not self.is_unk: 185 | self.pred_y = self.decoder(phy_z) 186 | 187 | return self.pred_y, phy_fe + unk_fe 188 | 189 | def phy_part(self, last_X, last_wind_vars, alpha=None, beta=None): 190 | self.phy_odefunc.create_equation(last_wind_vars, self.wind_mean, 
def unk_part(self, Z0):
    """Integrate the data-driven dynamics forward from the latent state ``Z0``.

    The solver is queried on ``horizon + 1`` evenly spaced time points
    ``k / (horizon + 1)`` for ``k = 0..horizon``; the first returned state
    (the one at time 0) is dropped.

    :param Z0: initial latent state handed to ``self.unk_solver.solve``.
    :return: ``(pred_z, fe)`` — the states for the ``horizon`` later time
        points and the second value returned by the solver, unchanged.
    """
    n_points = self.horizon + 1  # initial point + one per forecast step
    time_grid = torch.arange(start=0, end=n_points, step=1).float() / n_points
    trajectory, fe = self.unk_solver.solve(self.unk_odefunc, Z0, time_grid)  # T x B x N*D
    return trajectory[1:], fe
class Exp_Air_Pollution(Exp_Basic):
    """Experiment driver that builds, trains and validates an air-pollution model.

    NOTE(review): ``torch`` is referenced throughout this class but is not
    imported by name in this module's visible imports; presumably it arrives
    through ``from utils.metrics import *`` -- confirm, or import it explicitly.
    """

    def __init__(self, args):
        """Attach graph data and a logger to ``args``, then initialise the base class.

        ``args`` is mutated in place: graph tensors and the logger are stored on
        it, and ``args.model.input_dim`` is increased by ``args.model.embed_dim``
        when ``args.data.embed`` is truthy (so reusing the same ``args`` object
        for a second experiment would apply that increase again).
        """
        adj_mx, edge_index, edge_attr, node_attr = load_graph_data(args.data.root_path)
        args.adj_mx = adj_mx  # N x N
        args.edge_index = edge_index  # adjacency list: 2 x M
        args.edge_attr = edge_attr  # M x D
        args.node_attr = node_attr  # N x D

        # The logger must exist before the base constructor runs.
        self._log_dir = self._get_log_dir(args) if args.to_log_file else None
        self._logger = get_logger(self._log_dir, args.model_name, 'info.log',
                                  level=args.log_level, to_stdout=args.to_stdout)
        args.logger = self._logger

        if args.data.embed:
            args.model.input_dim = int(args.model.input_dim) + int(args.model.embed_dim)

        super(Exp_Air_Pollution, self).__init__(args)

        self.num_nodes = adj_mx.shape[0]
        self.input_var = int(self.args.model.input_dim)  # features per node in x
        self.input_dim = int(self.args.model.X_dim)
        self.seq_len = int(self.args.model.seq_len)
        self.horizon = int(self.args.model.horizon)
        self.output_dim = int(self.args.model.X_dim)

    def _unwrap_model(self):
        """Return the bare model, looking through an ``nn.DataParallel`` wrapper.

        ``nn.DataParallel`` does not forward custom attributes such as
        ``setting`` or ``pred_y``; they live on ``.module``.
        """
        model = self.model
        return model.module if isinstance(model, nn.DataParallel) else model

    def _build_model(self):
        """Instantiate the configured model and optionally wrap it for multi-GPU use.

        The validation dataset is loaded first so that its scaler statistics can
        be stored on ``args.data`` before the model is constructed.
        """
        dataset, _ = self._get_data('val')
        self.args.data.mean_ = dataset.scaler.mean_
        self.args.data.std_ = dataset.scaler.scale_
        model = self.model_dict[self.args.model_name].Model(self.args).float()
        self._logger.info("Model created")
        self._logger.info(
            "Total trainable parameters {}".format(count_parameters(model))
        )
        if self.args.GPU.use_multi_gpu and self.args.GPU.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.GPU.device_ids)
        return model

    def _get_data(self, flag):
        """Return ``(dataset, data_loader)`` for the split named by ``flag``."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        """Return an Adam optimizer over the model parameters."""
        return optim.Adam(self.model.parameters(), lr=self.args.train.lr, eps=1e-8)

    def _select_criterion(self):
        """Return ``nn.MSELoss`` for ``"mse"``; any other name yields ``nn.L1Loss``."""
        if self.args.model.loss.criterion == "mse":
            return nn.MSELoss()
        # "mae" and every unrecognised name share the L1 fallback.
        return nn.L1Loss()

    def _select_lr_scheduler(self, optimizer, train_loader):
        """Build the learning-rate scheduler selected by ``args.train.lradj``.

        ``'MultiStep'`` -> ``MultiStepLR``; ``'TST'`` -> ``OneCycleLR`` (stepped
        once per batch by ``train``); anything else -> ``ReduceLROnPlateau``.
        """
        lradj = self.args.train.lradj
        if lradj == 'MultiStep':
            return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.args.train.steps,
                                                        gamma=self.args.train.lr_decay_ratio)
        if lradj == 'TST':
            return torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer,
                                                       steps_per_epoch=len(train_loader),
                                                       pct_start=self.args.train.pct_start,
                                                       epochs=self.args.train.epochs,
                                                       max_lr=self.args.train.lr)
        # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
        # releases -- confirm against the pinned torch version before upgrading.
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                          patience=3, verbose=True)

    def vali(self, vali_data, vali_loader, epoch_num, save=False):
        """Run one pass over ``vali_loader`` and return the validation loss.

        Relies on ``self.criterion`` and ``self.inverse_transform``, which are
        assigned in ``train``. ``vali_data`` and ``save`` are accepted but not
        used here.

        :return: ``self.criterion`` evaluated on all batches concatenated.
        """
        # `TB_logger` is only assigned inside train(); tolerate its absence.
        tb_logger = getattr(self, "TB_logger", None)
        steps_per_pass = len(vali_loader)

        with torch.no_grad():
            self.model.eval()

            preds = []
            truths = []
            for i, (x, gt) in enumerate(vali_loader):
                x, gt, y_embed = self._prepare_data(x, gt)
                pred, fe = self.model(x, y_embed)

                truths.append(gt.cpu())  # T x B x N
                preds.append(pred.cpu())

                if self.args.model_name in ["AirPhyNet", "Origin_AirPhyNet"]:
                    loss = self.pred_loss(gt, pred)
                else:
                    loss = self.get_loss(x, gt)

                if tb_logger:
                    tb_logger.add_scalar("val/loss", loss, epoch_num * steps_per_pass + i)

            truths = torch.cat(truths, dim=1)
            preds = torch.cat(preds, dim=1)  # T x B x N
            val_loss = self.criterion(truths, preds)

            truths = truths.permute(1, 0, 2)
            preds = preds.permute(1, 0, 2)  # B x T x N
            mae, smape, rmse = self._compute_loss_eval(truths, preds)

            self._logger.info('Evaluation: - mae - {:.4f} - smape - {:.4f} - rmse - {:.4f}'
                              .format(mae, smape, rmse))

        return val_loss

    def train(self):
        """Train the model with per-epoch validation, LR scheduling and early stopping."""
        # `setting` lives on the bare model, not on a DataParallel wrapper.
        setting = self._unwrap_model().setting

        if self.args.TB_dir:
            self.TB_logger = self._build_TB_logger(setting)
        else:
            self.TB_logger = None
        self._logger.info('Model mode: train')
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        self.inverse_transform = train_data.inverse_transform
        self.criterion = self._select_criterion()

        model_save_path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(model_save_path, exist_ok=True)
        optimizer = self._select_optimizer()
        early_stopping = EarlyStopping(patience=self.args.train.patience, verbose=True, logger=self._logger)
        lr_scheduler = self._select_lr_scheduler(optimizer, train_loader)

        time_now = time.time()
        train_steps = len(train_loader)

        self._logger.info('Start training ...')
        # The original logged the batch size under the label "num_batches".
        self._logger.info("batch_size: {}, num_batches: {}".format(self.args.data.batch_size, train_steps))

        # Iterations since the last progress message. It is reset together with
        # `time_now` so that the s/iter figure stays consistent when
        # `log_every` spans several epochs.
        iter_count = 0
        for epoch_num in range(1, self.args.train.epochs + 1):
            if self.args.to_stdout:
                print('\nTrain epoch %s:' % (epoch_num))
            self.model.train()

            losses = []
            for i, (batch_x, batch_y) in enumerate(train_loader):
                iter_count += 1
                optimizer.zero_grad()

                batch_x, batch_y, y_embed = self._prepare_data(batch_x, batch_y)
                output, fe = self.model(batch_x, y_embed)

                if self.args.model_name in ["AirPhyNet", "Origin_AirPhyNet"]:
                    loss = self.pred_loss(batch_y, output)
                else:
                    loss = self.get_loss(batch_x, batch_y)
                    # NOTE(review): the source indentation was lost; this debug
                    # line is assumed to belong to the `else` branch -- confirm.
                    self._logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))

                self._logger.debug(loss.item())
                losses.append(loss.item())

                loss.backward()
                optimizer.step()

                if self.TB_logger:
                    self.TB_logger.add_scalar("train/loss", loss.item(), epoch_num * train_steps + i)

                # OneCycleLR is stepped once per batch.
                if self.args.train.lradj == 'TST':
                    lr_scheduler.step()

                del output, loss, batch_x, batch_y
                torch.cuda.empty_cache()

            val_loss = self.vali(vali_data, vali_loader, epoch_num)

            if (epoch_num % self.args.train.log_every) == self.args.train.log_every - 1:
                # max(..., 1) guards against an empty training loader.
                speed = (time.time() - time_now) / max(iter_count, 1)
                # The epoch is finished here, so only whole epochs remain.
                left_time = speed * (self.args.train.epochs - epoch_num) * train_steps
                message = ('Epoch [{}/{}] train_loss: {:.4f}, val_loss: {:.4f}, lr: {:.6f}'
                           .format(epoch_num, self.args.train.epochs,
                                   np.mean(losses), val_loss, optimizer.param_groups[0]['lr']))
                self._logger.info(message)
                self._logger.info('speed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                iter_count = 0
                time_now = time.time()

            # Dynamic learning-rate adjustment ('TST' was already stepped per batch).
            if self.args.train.lradj == 'MultiStep':
                lr_scheduler.step()
            elif self.args.train.lradj != 'TST':
                lr_scheduler.step(val_loss)

            early_stopping(val_loss, self.model, model_save_path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            # NOTE(review): the source indentation was lost; this separator is
            # assumed to be emitted once per epoch -- confirm.
            self._logger.info("---" * 30)

    @staticmethod
    def _get_log_dir(args):
        """Return the log directory, creating it if needed.

        Uses ``args.train.get('log_dir')`` when set; otherwise builds
        ``<args.log_base_dir>/<model_name>_<month-day-hour-minute-second>/``.
        """
        log_dir = args.train.get('log_dir')
        if log_dir is None:
            run_id = '%s_%s/' % (
                args.model_name, time.strftime('%m-%d-%H-%M-%S'))
            log_dir = os.path.join(args.log_base_dir, run_id)
        os.makedirs(log_dir, exist_ok=True)
        return log_dir

    def _prepare_data(self, x, y):
        """Convert a raw batch into time-major tensors on ``self.device``.

        :return: ``(x, y, y_embed)``; ``y_embed`` is ``None`` unless
            ``args.data.embed`` is enabled.
        """
        x, y = self._get_x_y(x, y)  # B x 24(72 hours) x N x D
        x, y, y_embed = self._get_x_y_in_correct_dims(x, y)  # 24 x B x N x D
        if y_embed is not None:
            # The original left y_embed on its source device while x and y
            # were moved; keep all three tensors on the same device.
            y_embed = y_embed.to(self.device)
        return x.to(self.device), y.to(self.device), y_embed  # 24 x B x 35 * 11

    def _get_x_y(self, x, y):
        """Cast both tensors to float and swap their first two axes."""
        x = x.float().permute(1, 0, 2, 3)
        y = y.float().permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """Flatten the node and feature axes of ``x`` and ``y`` into one axis.

        With ``args.data.embed`` enabled, a station-index feature is appended to
        every node and the columns listed in ``embed`` are extracted from ``y``
        as ``y_embed``; otherwise ``y_embed`` is ``None``.
        """
        batch_size = x.size(1)
        if self.args.data.embed:
            # Station index as an extra trailing feature: 1 x 1 x N x 1.
            station_ids = torch.arange(0, self.num_nodes)[None, None, :, None]
            station_x = station_ids.repeat(self.seq_len, batch_size, 1, 1)
            station_y = station_ids.repeat(self.horizon, batch_size, 1, 1)
            x = torch.cat([x, station_x], dim=-1)
            y = torch.cat([y, station_y], dim=-1)
            x = x.reshape(self.seq_len, batch_size, self.num_nodes * self.input_var)
            # NOTE(review): hard-coded feature columns; presumably the
            # embedding inputs of the dataset -- confirm against the loader.
            embed = [6, 7, 8, 9, 10, 11]
            y_embed = y[..., embed].reshape(self.horizon, batch_size, self.num_nodes * len(embed))
            y = y[..., :self.output_dim].reshape(self.horizon, batch_size,
                                                 self.num_nodes * self.output_dim)
        else:
            x = x[..., :self.input_var].reshape(self.seq_len, batch_size, self.num_nodes * self.input_var)
            y = y[..., :self.output_dim].reshape(self.horizon, batch_size,
                                                 self.num_nodes * self.output_dim)
            y_embed = None
        return x, y, y_embed

    def pred_loss(self, y_true, y_predicted):
        """Return the masked MAE between inverse-transformed truth and prediction."""
        y_true = self.inverse_transform(y_true)
        y_predicted = self.inverse_transform(y_predicted)
        return masked_loss(y_predicted, y_true, MAE)

    def _compute_loss_eval(self, y_true, y_predicted):
        """Return ``compute_all_metrics`` on the inverse-transformed tensors."""
        y_true = self.inverse_transform(y_true)
        y_predicted = self.inverse_transform(y_predicted)
        return compute_all_metrics(y_predicted, y_true)

    def kl_loss(self, mu, logvar):
        """Return ``0.5 * (exp(logvar) + mu**2 - logvar - 1)`` summed over the
        last two axes and averaged over the remaining ones."""
        # n_traj x B x N x Latent_dim
        var = torch.exp(logvar)
        loss = 1 / 2 * (var + mu ** 2 - logvar - 1)
        return torch.mean(loss.sum(dim=(-1, -2)))

    def get_loss(self, x_true, y_true):
        """Sum the loss terms enabled in ``args.model.loss``.

        Each term reads a tensor the model stored on itself during its last
        forward pass (``means_z0``/``logvar_z0``, ``recon_x``, ``pred_y``,
        ``loss_CL``).

        :return: a one-element tensor on ``self.device``.
        """
        model = self._unwrap_model()
        loss_cfg = self.args.model.loss
        loss = torch.zeros(1, device=self.device)

        if loss_cfg.kl_loss:
            loss = loss + self.kl_loss(model.means_z0, model.logvar_z0)
        if loss_cfg.recon_loss:
            # Undo the (N * input_var) flattening done in
            # _get_x_y_in_correct_dims and keep feature 0 of every node. The
            # original reshaped with X_dim, which only matches when
            # X_dim == input_dim, and did so even with this term disabled.
            x_first = torch.reshape(
                x_true, (self.seq_len, -1, self.num_nodes, self.input_var))[:, :, :, 0]
            loss = loss + loss_cfg.recon_coeff * self.criterion(x_first, model.recon_x)
        if loss_cfg.pred_loss:
            loss = loss + self.criterion(y_true, model.pred_y)
        if loss_cfg.cl_loss:
            loss = loss + loss_cfg.cl_coeff * model.loss_CL
        return loss