├── assets
    └── model.png
├── requirements.txt
├── utils
    ├── random_seed.py
    └── graph_build.py
├── LICENSE
├── .gitignore
├── model
    ├── FNN.py
    ├── GAT.py
    └── GFN.py
├── README.md
├── main.py
└── graph_shap.py


/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjtu-gwdg/GraphFourierNet/HEAD/assets/model.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | geopandas==1.0.1
2 | joblib==1.4.2
3 | matplotlib==3.10.1
4 | numpy==2.2.3
5 | pandas==2.2.3
6 | scikit_learn==1.6.1
7 | seaborn==0.13.2
8 | shap==0.47.0
9 | torch==2.5.1
10 | torch_geometric==2.6.1
11 | tqdm==4.67.1
12 |
--------------------------------------------------------------------------------
/utils/random_seed.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def set_random_seed(seed=42):
8 |     """Seed every random number generator the project relies on.
9 |
10 |     Args:
11 |         seed (int): Seed for Python's ``random``, NumPy and PyTorch (CPU and
12 |             every CUDA device). Defaults to 42, the value that used to be
13 |             hard-coded, so existing ``set_random_seed()`` calls are unchanged.
14 |     """
15 |     random.seed(seed)
16 |     np.random.seed(seed)
17 |     torch.manual_seed(seed)
18 |     if torch.cuda.is_available():
19 |         torch.cuda.manual_seed_all(seed)
20 |     # Prefer deterministic cuDNN kernels over autotuned (faster but
21 |     # non-deterministic) ones so repeated runs give identical results.
22 |     torch.backends.cudnn.deterministic = True
23 |     torch.backends.cudnn.benchmark = False
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 xjtu-gwdg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice
shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/utils/graph_build.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def build_station_graph(code_mapping, device):
5 |     """
6 |     Build an undirected graph that links stations sharing the same group code.
7 |
8 |     Stations are grouped by code in the iteration order of ``code_mapping``,
9 |     and consecutive stations of a group are joined in both directions, so
10 |     every group becomes a bidirectional chain.
11 |
12 |     Args:
13 |         code_mapping (dict): Mapping from station id to station group code
14 |             (e.g. stations on the same line share a code). The ids are used
15 |             directly as node indices, so they presumably are integers in
16 |             ``[0, num_nodes)`` -- TODO confirm against load_and_preprocess.
17 |         device (torch.device): Device on which the edge index is created.
18 |
19 |     Returns:
20 |         torch.LongTensor: De-duplicated edge index of shape [2, num_edges] on
21 |         ``device`` (columns sorted by ``unique``); shape [2, 0] when no
22 |         group contains more than one station.
23 |     """
24 |     code_groups = {}  # group code -> station ids, in insertion order
25 |     for sid, code in code_mapping.items():
26 |         code_groups.setdefault(code, []).append(sid)
27 |
28 |     src, dst = [], []
29 |     for group in code_groups.values():
30 |         # Join each station to the next one in its group, in both directions.
31 |         for a, b in zip(group, group[1:]):
32 |             src += (a, b)
33 |             dst += (b, a)
34 |
35 |     if not src:
36 |         # No intra-group pairs: return an explicit empty edge index rather
37 |         # than calling unique(dim=1) on a zero-width tensor.
38 |         return torch.empty((2, 0), dtype=torch.long, device=device)
39 |
40 |     # Build directly on the target device (instead of the legacy CPU-only
41 |     # torch.LongTensor constructor followed by .to(device)).
42 |     edge_index = torch.tensor([src, dst], dtype=torch.long, device=device)
43 |     return edge_index.unique(dim=1)
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Virtual environment folders
7 | venv/
8 | .venv/
9 | env/
10 | .env/
11 |
12 | # VSCode project settings
13 | .vscode/
14 |
15 | # PyCharm project settings
16 | .idea/
17 |
18 | # Python egg files
19 | *.egg
20 | *.egg-info/
21 | dist/
22 | build/
23 | eggs/
24 | parts/
25 | var/
26 | sdist/
27 | develop-eggs/
28 |
29 | # Distribution / packaging
30 | *.whl
31 | pip-wheel-metadata/
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .nox/
41 | .coverage
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Pytest cache
49 | .pytest_cache/
50 |
51 | # MyPy cache
52 | .mypy_cache/
53 |
54 | # Other caches
55 | .pyre/
56 | .pytype/
57 |
58 | # Temporary files
59 | *.log
60 | *.tmp
61 | *.bak
62 | *.swp
63 | *.swo
64 |
65 | # Jupyter Notebook checkpoints
66 | .ipynb_checkpoints/
67 |
68 | # OS-specific files
69 | .DS_Store
70 | Thumbs.db
71 |
72 | # Ignore data/output folders (optional if you generate files during training)
73 | data/
74 | output/
75 | results/
76 | logs/
77 |
78 | # Ignore model checkpoints (optional)
79 | *.pth
80 | *.ckpt
81 | *.h5
82 |
83 | # Ignore shap analysis output (if any)
84 | *.json
85 | *.npy
86 | mc_graph_shap_results/
87 | shap_results/
88 |
89 | # Ignore backups or crash reports
90 | *.orig
91 | *.rej
92 |
93 | # Ignore markdown temp files
94 | *.md~
95 |
96 | # Ignore Visual Studio temporary files
97 | *.suo
98 | *.user
99 | *.userosscache
100 | *.sln.docstates
101 |
--------------------------------------------------------------------------------
/model/FNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 |
5 |
6 | class FNN(nn.Module):
7 |     """Fourier neural network: encodes each sample's sequence in the frequency domain.
8 |
9 |     Expects input of shape [B, seq_length, feature_size] and returns
10 |     [B, hidden_size]. ``pre_length`` is stored but not used by this module.
11 |     """
12 |
13 |     def __init__(self, pre_length, embed_size, feature_size, seq_length, hidden_size):
14 |         super().__init__()
15 |         self.embed_size = embed_size
16 |         self.pre_length = pre_length
17 |         self.feature_size = feature_size
18 |         self.seq_length = seq_length
19 |         self.scale = 0.1  # Multiplier on the random init of the complex weights/biases.
20 |
21 |         # Learnable embedding matrix mapping feature_size -> embed_size.
22 |         self.embeddings = nn.Parameter(torch.randn(feature_size, embed_size))
23 |
24 |         # Two complex linear layers; index 0 holds the real part and index 1
25 |         # the imaginary part of each weight/bias.
26 |         self.w1 = nn.Parameter(self.scale * torch.randn(2, embed_size, embed_size))
27 |         self.b1 = nn.Parameter(self.scale * torch.randn(2, embed_size))
28 |         self.w2 = nn.Parameter(self.scale * torch.randn(2, embed_size, embed_size))
29 |         self.b2 = nn.Parameter(self.scale * torch.randn(2, embed_size))
30 |
31 |         # Head applied to the flattened [seq_length * embed_size] time-domain features.
32 |         self.fc = nn.Sequential(
33 |             nn.Linear(seq_length * embed_size, hidden_size * 2),
34 |             nn.LayerNorm(hidden_size * 2),
35 |             nn.LeakyReLU(),
36 |             nn.Linear(hidden_size * 2, hidden_size)
37 |         )
38 |
39 |     def tokenEmb(self, x):
40 |         """Embed the last (feature) dimension: x @ embeddings."""
41 |         return torch.matmul(x, self.embeddings)
42 |
43 |     def fourierGC(self, x):
44 |         """Apply two complex-valued linear layers to the complex spectrum ``x``.
45 |
46 |         Real and imaginary parts are handled explicitly via
47 |         (a + ib)(c + id) = (ac - bd) + i(ad + bc); ReLU is applied to the real
48 |         and imaginary parts of the second layer separately.
49 |         """
50 |         # First complex linear layer (no activation).
51 |         o1_real = torch.einsum('bli,io->blo', x.real, self.w1[0]) - \
52 |             torch.einsum('bli,io->blo', x.imag, self.w1[1]) + self.b1[0]
53 |         o1_imag = torch.einsum('bli,io->blo', x.imag, self.w1[0]) + \
54 |             torch.einsum('bli,io->blo', x.real, self.w1[1]) + self.b1[1]
55 |
56 |         # Second complex linear layer with ReLU on each part.
57 |         o2_real = f.relu(torch.einsum('bli,io->blo', o1_real, self.w2[0]) -
58 |                          torch.einsum('bli,io->blo', o1_imag, self.w2[1]) + self.b2[0])
59 |         o2_imag = f.relu(torch.einsum('bli,io->blo', o1_imag, self.w2[0]) +
60 |                          torch.einsum('bli,io->blo', o1_real, self.w2[1]) + self.b2[1])
61 |
62 |         # Recombine the two parts into a complex tensor.
63 |         return torch.view_as_complex(torch.stack([o2_real, o2_imag], dim=-1))
64 |
65 |     def forward(self, x):
66 |         """Embed, transform in the frequency domain, then project to hidden_size."""
67 |         B, seq_len, _ = x.size()
68 |
69 |         x = self.tokenEmb(x)  # [B, seq_len, embed_size]
70 |         x = torch.fft.rfft(x, dim=1, norm='ortho')  # real FFT along time
71 |         x = self.fourierGC(x)
72 |         x = torch.fft.irfft(x, n=seq_len, dim=1, norm="ortho")  # back to the time domain
73 |
74 |         return self.fc(x.reshape(B, -1))  # flatten time and embedding dims
75 |
--------------------------------------------------------------------------------
/model/GAT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 |
5 |
6 | class GATLayer(nn.Module):
7 |     """Multi-head graph attention layer with a gated residual and LayerNorm.
8 |
9 |     ``x`` is added to the concatenated head outputs, so the layer assumes
10 |     ``in_features == out_features * heads``.
11 |     """
12 |
13 |     def __init__(self, in_features, out_features, heads=1):
14 |         super().__init__()
15 |         self.heads = heads
16 |         self.out_features = out_features
17 |
18 |         # LayerNorm over the concatenated heads, for training stability.
19 |         self.norm = nn.LayerNorm(out_features * heads)
20 |
21 |         # Learnable residual gate; starts at 0, i.e. the residual is initially off.
22 |         self.res_gate = nn.Parameter(torch.zeros(1))
23 |
24 |         # Shared projection producing every head at once.
25 |         self.W = nn.Linear(in_features, out_features * heads, bias=False)
26 |
27 |         # Edge scorer: [h_src || h_dst] -> one logit; the same weights serve every head.
28 |         self.attn = nn.Linear(2 * out_features, 1, bias=False)
29 |
30 |         self.reset_parameters()
31 |
32 |     def reset_parameters(self):
33 |         """Xavier-uniform initialization of the projection and attention weights."""
34 |         nn.init.xavier_uniform_(self.W.weight)
35 |         nn.init.xavier_uniform_(self.attn.weight)
36 |
37 |     def grouped_softmax(self, alpha, dst):
38 |         """
39 |         Softmax of edge scores over the incoming edges of each destination node.
40 |
41 |         :param alpha: Edge scores of shape [E, heads].
42 |         :param dst: Destination node index of every edge, shape [E].
43 |         :return: Tensor of shape [E, heads]; per head, the weights of edges that
44 |             share a destination sum to 1 (up to the 1e-8 denominator guard).
45 |         """
46 |         # Group ids depend only on dst, so compute them once for all heads
47 |         # (previously recomputed inside a per-head loop).
48 |         unique_dst, inverse_indices = torch.unique(dst, return_inverse=True)
49 |         group_shape = (unique_dst.size(0), alpha.size(1))
50 |         index = inverse_indices.unsqueeze(-1).expand_as(alpha)
51 |
52 |         # Subtract each group's maximum for numerical stability.
53 |         max_values = alpha.new_zeros(group_shape).scatter_reduce_(
54 |             0, index, alpha, reduce='amax', include_self=False
55 |         )
56 |         exp_alpha = torch.exp(alpha - max_values[inverse_indices])
57 |
58 |         # Normalize by each group's sum of exponentials.
59 |         sum_exp = alpha.new_zeros(group_shape).scatter_add_(0, index, exp_alpha)
60 |         return exp_alpha / (sum_exp[inverse_indices] + 1e-8)
61 |
62 |     def forward(self, x, edge_index):
63 |         """
64 |         Forward pass of GATLayer.
65 |
66 |         :param x: Node features of shape [N, in_features].
67 |         :param edge_index: Edge list of shape [2, E]; row 0 holds source and
68 |             row 1 target node indices.
69 |         :return: Tuple of (node features [N, heads * out_features], attention
70 |             weights [heads, E], the unchanged edge_index).
71 |         """
72 |         residual = x
73 |         x = x.contiguous()
74 |         n = x.size(0)
75 |
76 |         # Project and split into heads: [N, heads, out_features].
77 |         h = self.W(x).view(n, self.heads, self.out_features)
78 |
79 |         src, dst = edge_index
80 |         h_src = h[src]  # [E, heads, out_features]
81 |         h_dst = h[dst]
82 |
83 |         # Score every edge from its concatenated source/target embeddings.
84 |         alpha = torch.cat([h_src, h_dst], dim=-1)
85 |         alpha = f.leaky_relu(self.attn(alpha), 0.2).squeeze(-1)  # [E, heads]
86 |
87 |         # Normalize the scores over the incoming edges of each target node.
88 |         alpha = self.grouped_softmax(alpha, dst).unsqueeze(-1)  # [E, heads, 1]
89 |
90 |         # Sum attention-weighted source messages into their target nodes. The
91 |         # buffer follows h's dtype (the old float32 buffer broke scatter_add_
92 |         # for half/double inputs).
93 |         out = h.new_zeros((n, self.heads, self.out_features))
94 |         expanded_dst = dst.view(-1, 1, 1).expand(-1, self.heads, self.out_features)
95 |         out.scatter_add_(0, expanded_dst, alpha * h_src)
96 |
97 |         # Gated residual connection followed by layer normalization.
98 |         out = self.norm(out.view(n, -1) + self.res_gate * residual)
99 |
100 |         return out, alpha.squeeze(-1).permute(1, 0), edge_index
101 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Fourier Network (GFN)
2 |
3 | Official PyTorch implementation for **_"A Graph Fourier-Based Deep Learning Model to Predict Regional Groundwater Level Variations with Hydrogeological and Spatio-Temporal Interpretability"_**.
4 |
5 | ## Features
6 | - Unified integration of **Fourier Neural Networks (FNN)** and **Graph Attention Networks (GATLayer)**.
7 | - **Dynamic Graph Learner (DGL)** module for adaptive spatial structure learning.
8 | - Supports multiple meteorological drivers: *precipitation*, *temperature*, *NDVI*, *evapotranspiration*.
9 | - Enhanced spatiotemporal interpretability via **gradient sensitivity** and **SHAP analysis**.
10 |
11 | ## Installation
12 | ```bash
13 | # Clone the repository
14 | git clone https://github.com/xjtu-gwdg/GraphFourierNet.git
15 | cd GraphFourierNet
16 |
17 | # Install required packages
18 | pip install -r requirements.txt
19 | ```
20 |
21 | **Requirements:**
22 | - torch==2.5.1
23 | - torch_geometric==2.6.1
24 | - numpy==2.2.3
25 | - pandas==2.2.3
26 | - scikit_learn==1.6.1
27 | - matplotlib==3.10.1
28 | - seaborn==0.13.2
29 | - shap==0.47.0
30 | - tqdm==4.67.1
31 | - joblib==1.4.2
32 |
33 |
34 | ## Dataset Preparation
35 | 1. Create the directory `data/YRB/`
36 | 2. Place **station-wise Excel files** under `data/YRB/`
37 | 3.
Each `.xlsx` file should contain the following columns:
38 |    - `yearcol`, `monthcol`, `daycol`
39 |    - `pre` (precipitation)
40 |    - `tmp` (temperature)
41 |    - `ndvi` (vegetation index)
42 |    - `et` (evapotranspiration)
43 |    - `sw` (target groundwater level, GWL)
44 |
45 |
46 | ## Training
47 | ```bash
48 | python main.py \
49 |   --data YRB \
50 |   --seq_len 6 \
51 |   --pred_len 5 \
52 |   --batch_size 8 \
53 |   --epochs 200 \
54 |   --lr 0.5 \
55 |   --gat_hidden 256 \
56 |   --gat_heads 4
57 | ```
58 |
59 |
60 | ## Testing
61 | `main.py` has no separate test mode: every run trains the model and then
62 | evaluates it on the test split, so the training command above also produces
63 | the test results (a `--test` flag is rejected by the argument parser).
64 |
65 |
66 | ## Main Configuration Options
67 |
68 | | Argument | Description | Default |
69 | |---------------------|-----------------------------------------------|---------|
70 | | `--data` | Dataset name | YRB |
71 | | `--seq_len` | Input sequence length (historical window) | 6 |
72 | | `--pred_len` | Prediction horizon (future steps) | 5 |
73 | | `--feat_dim` | Number of input features | 5 |
74 | | `--batch_size` | Batch size for training | 8 |
75 | | `--epochs` | Number of training epochs | 200 |
76 | | `--lr` | Initial learning rate | 0.5 |
77 | | `--gat_hidden` | Hidden size for GAT layers | 256 |
78 | | `--gat_heads` | Number of heads in GAT | 4 |
79 | | `--gat_layers` | Number of stacked GAT layers | 1 |
80 | | `--dropout` | Dropout rate | 0.5 |
81 | | `--edge_keep_ratio` | Fraction of candidate (non-static) node pairs kept as dynamic edges | 0.4 |
82 | | `--train_ratio` | Train set split ratio | 0.7 |
83 | | `--val_ratio` | Validation set split ratio | 0.2 |
84 |
85 |
86 | ## Model Architecture
87 |
88 | ### 1. Fourier Neural Network (FNN)
89 | - Frequency-domain modeling of temporal dynamics.
90 | - Two-layer complex-valued transformation in the Fourier domain.
91 | - Real-Imaginary recombination with inverse FFT back to time domain.
92 |
93 | ### 2. Graph Attention Networks (GATLayer)
94 | - Multi-head self-attention over neighboring nodes.
95 | - Residual connections and layer normalization for stability.
96 | - Normalizes attention scores with a **grouped softmax** over the incoming edges of each target node.
97 |
98 | ### 3. Dynamic Graph Learner (DGL)
99 | - Learns new edges based on node embeddings.
100 | - Combines **static HGU-based graphs** with **dynamic edges**.
101 | - Edge scoring network with trainable similarity predictor.
102 |
103 | ### 4. Final Decoder
104 | - Fully connected layers with **LayerNorm** + **ELU** + **Dropout** for robust forecasting.
105 |
106 |
107 | ## Overall Model Flow
108 | ![Model Architecture](assets/model.png)
109 |
110 |
111 | ## Model Interpretability
112 |
113 | - **Spatiotemporal Feature SHAP Analysis**:
114 |   Analyzes the contributions of the meteorological drivers (*precipitation*, *temperature*, *NDVI*, *evapotranspiration*) across multiple **time steps**, providing fine-grained temporal interpretability.
115 |
116 | - **Spatial Graph SHAP Analysis**:
117 |   A custom **Monte Carlo-based Edge SHAP method** (`graph_shap.py`) quantifies the importance of **edges** and **nodes** in the station graph. Each edge's score is the average change in the model output when that edge is added to randomly sampled subsets of the other edges. Each node receives half the score of every edge it touches. Together these give spatial interpretability of groundwater interactions.
118 |
119 | > ✨ This dual-stage SHAP framework enhances interpretability at both the feature-time dimension and the spatial graph dimension.
120 |
121 |
122 |
123 | ## Acknowledgements
124 | This research builds upon:
125 | - [Graph Attention Networks (GAT)](https://github.com/PetarV-/GAT)
126 | - [FourierGNN](https://github.com/aikunyi/FourierGNN)
127 | - [SHAP](https://github.com/shap/shap)
128 |
129 |
130 | ---
131 |
132 | Made with ❤️ by xjtu-gwdg Team.
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as f
7 | from torch import nn
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 |
11 | from data.data import load_and_preprocess, GWLDataset
12 | from model.GFN import GraphFourierNet
13 | from utils.graph_build import build_station_graph
14 | from utils.random_seed import set_random_seed
15 |
16 | parser = argparse.ArgumentParser(description='Graph Fourier Net:')
17 | parser.add_argument('--data', type=str, default='YRB', help='data set')
18 | parser.add_argument('--seq_len', type=int, default=6, help='input length')
19 | parser.add_argument('--pred_len', type=int, default=5, help='predict length')
20 | parser.add_argument('--feat_dim', type=int, default=5, help='feature size')
21 | parser.add_argument('--batch_size', type=int, default=8, help='input data batch size')
22 | parser.add_argument('--epochs', type=int, default=200, help='train epochs')
23 | parser.add_argument('--lr', type=float, default=0.5, help='learning rate')
24 | parser.add_argument('--train_ratio', type=float, default=0.7)
25 | parser.add_argument('--val_ratio', type=float, default=0.2)
26 | parser.add_argument('--gat_hidden', type=int, default=256, help='gat dimensions')
27 | parser.add_argument('--gat_heads', type=int, default=4, help='gat heads')
28 | parser.add_argument('--gat_layers', type=int, default=1, help='gat layers')
29 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout')
30 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='RAdam weight decay')
31 | parser.add_argument('--fnn_embed_size', type=int, default=256)
32 | parser.add_argument('--edge_keep_ratio', type=float, default=0.4)
33 | args = parser.parse_args()
34 |
35 |
36 | def checkpoint_path(data_root='data'):
37 |     """Return the checkpoint file for the current dataset and configuration.
38 |
39 |     Training saves to this path and evaluation loads from it, so both agree
40 |     on the file name (train() used to write a config-specific file while
41 |     test() read ``best_model.pth``).
42 |     """
43 |     file_name = f'pred_len={args.pred_len}_feat_dim={args.feat_dim}_best_model.pth'
44 |     return os.path.join(data_root, args.data, file_name)
45 |
46 |
47 | def collate_fn(batch):
48 |     """Stack a list of (x, y) samples into batched x and y tensors."""
49 |     x = torch.stack([item[0] for item in batch])
50 |     y = torch.stack([item[1] for item in batch])
51 |     return x, y
52 |
53 |
54 | def train():
55 |     """Train the global ``model`` and checkpoint it whenever validation loss improves.
56 |
57 |     Relies on the module-level ``model``, ``device``, ``train_loader`` and
58 |     ``val_loader`` created under ``__main__``.
59 |     """
60 |     print(f'Training configs: {args}')
61 |     optimizer = torch.optim.RAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
62 |     # `verbose` is deprecated since torch 2.2; the current LR is logged per epoch instead.
63 |     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
64 |     best_loss = float('inf')
65 |     ckpt_path = checkpoint_path()
66 |     for epoch in range(args.epochs):
67 |         model.train()
68 |         total_loss = 0.0
69 |         for x, y in tqdm(train_loader):
70 |             x, y = x.to(device), y.to(device)
71 |             optimizer.zero_grad()
72 |             pred, _, _ = model(x)
73 |             loss = f.smooth_l1_loss(pred, y)
74 |             loss.backward()
75 |             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
76 |             optimizer.step()
77 |             total_loss += loss.item()
78 |
79 |         val_loss = 0.0
80 |         model.eval()
81 |         with torch.no_grad():
82 |             for x, y in val_loader:
83 |                 x, y = x.to(device), y.to(device)
84 |                 pred, _, _ = model(x)
85 |                 val_loss += f.smooth_l1_loss(pred.view(-1), y.view(-1)).item()
86 |
87 |         avg_train = total_loss / len(train_loader)
88 |         avg_val = val_loss / len(val_loader)
89 |         scheduler.step(avg_val)
90 |         current_lr = optimizer.param_groups[0]['lr']
91 |         print(f'Epoch {epoch + 1:03d} | Train Loss: {avg_train:.7f} | Val Loss: {avg_val:.7f} | LR: {current_lr:.2e}')
92 |
93 |         # Keep only the best model (the old code overwrote it every epoch).
94 |         if avg_val < best_loss:
95 |             best_loss = avg_val
96 |             torch.save(model.state_dict(), ckpt_path)
97 |
98 |
99 | def test():
100 |     """Evaluate the best checkpoint on the test split and print MAE / RMSE.
101 |
102 |     Returns:
103 |         tuple[np.ndarray, np.ndarray]: predictions and targets mapped back to
104 |         original units, each of shape (num_samples, num_stations, pred_len).
105 |     """
106 |     state = torch.load(checkpoint_path(), map_location=device, weights_only=True)
107 |     model.load_state_dict(state)
108 |     model.eval()
109 |     all_preds, all_trues = [], []
110 |     with torch.no_grad():
111 |         for x, y in test_loader:
112 |             pred, _, _ = model(x.to(device))
113 |             all_preds.append(pred.cpu().numpy())
114 |             all_trues.append(y.cpu().numpy())
115 |
116 |     # Flat outputs: (num_samples, num_stations * pred_len).
117 |     preds = np.concatenate(all_preds, axis=0)
118 |     trues = np.concatenate(all_trues, axis=0)
119 |
120 |     # Reshape to (num_samples, num_stations, pred_len).
121 |     preds_3d = preds.reshape(-1, test_set.num_stations, args.pred_len)
122 |     trues_3d = trues.reshape(-1, test_set.num_stations, args.pred_len)
123 |
124 |     # Undo the per-station target scaling.
125 |     preds_inv = np.zeros_like(preds_3d)
126 |     trues_inv = np.zeros_like(trues_3d)
127 |     for i in range(test_set.num_stations):
128 |         scaler = test_set.target_scalers[i]
129 |         preds_inv[:, i, :] = scaler.inverse_transform(preds_3d[:, i, :].reshape(-1, 1)).reshape(-1, args.pred_len)
130 |         trues_inv[:, i, :] = scaler.inverse_transform(trues_3d[:, i, :].reshape(-1, 1)).reshape(-1, args.pred_len)
131 |
132 |     # The de-normalised arrays used to be computed and then discarded.
133 |     errors = preds_inv - trues_inv
134 |     mae = np.mean(np.abs(errors))
135 |     rmse = np.sqrt(np.mean(errors ** 2))
136 |     print(f'Test MAE: {mae:.4f} | Test RMSE: {rmse:.4f}')
137 |     return preds_inv, trues_inv
138 |
139 |
140 | if __name__ == "__main__":
141 |     set_random_seed()
142 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
143 |     # Sorted so node indices do not depend on the directory listing order
144 |     # (graph_shap.py uses the same ordering).
145 |     stations = sorted(os.path.splitext(name)[0] for name in os.listdir(f'data/{args.data}')
146 |                       if name.endswith('.xlsx'))
147 |     num_nodes = len(stations)
148 |     features, targets, code_map, full_df = load_and_preprocess(f'data/{args.data}', stations)
149 |     edge_index = build_station_graph(code_map, device)
150 |     train_set = GWLDataset(features, targets, mode='train', full_df=full_df, cfg=args)
151 |     val_set = GWLDataset(features, targets, feat_scalers=train_set.feat_scalers,
152 |                          target_scalers=train_set.target_scalers, mode='val', full_df=full_df, cfg=args)
153 |     test_set = GWLDataset(features, targets, feat_scalers=train_set.feat_scalers,
154 |                           target_scalers=train_set.target_scalers, mode='test', full_df=full_df, cfg=args)
155 |     # torch's DataLoader honours collate_fn (PyG's loader swaps in its own Collater).
156 |     train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
157 |     val_loader = DataLoader(val_set, batch_size=args.batch_size, collate_fn=collate_fn)
158 |     test_loader = DataLoader(test_set, batch_size=args.batch_size, collate_fn=collate_fn)
159 |     model = GraphFourierNet(num_nodes, edge_index, args).to(device)
160 |
161 |     def weights_init(m):
162 |         """Orthogonal weights (bias 0.1) for Linear, N(0, 0.1) for Embedding.
163 |
164 |         NOTE: applied via model.apply(), this overrides the Xavier init done in
165 |         the modules' reset_parameters().
166 |         """
167 |         if isinstance(m, nn.Linear):
168 |             nn.init.orthogonal_(m.weight)
169 |             if m.bias is not None:
170 |                 nn.init.constant_(m.bias, 0.1)
171 |         elif isinstance(m, nn.Embedding):
172 |             nn.init.normal_(m.weight, mean=0, std=0.1)
173 |
174 |     model.apply(weights_init)
175 |     train()
176 |     test()
177 |
--------------------------------------------------------------------------------
/graph_shap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import warnings
4 | import geopandas as gpd
5 | import seaborn as sns
6 | import pandas as pd
7 | import torch
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | import networkx as nx
11 | import matplotlib
12 | from tqdm import tqdm
13 | from torch import nn
14 | from torch.utils.data import DataLoader
15 | from scipy.spatial.distance import pdist, squareform
16 |
17 | from data.data import load_and_preprocess, GWLDataset
18 | from main import args, checkpoint_path, collate_fn
19 | from model.GFN import GraphFourierNet
20 | from utils.graph_build import build_station_graph
21 | from utils.random_seed import set_random_seed
22 |
23 | matplotlib.use("TkAgg")
24 |
25 | # === Utility Functions ===
26 |
27 |
28 | def load_shapefile(shapefile_path):
29 |     """Read a shapefile and cast its ``id`` column to ``str`` for matching."""
30 |     gdf = gpd.read_file(shapefile_path)
31 |     gdf['id'] = gdf['id'].astype(str)
32 |     return gdf
33 |
34 |
35 | def build_pos_from_shapefile(station_shp, stations):
36 |     """Map node index -> (x, y) point coordinates for stations found in ``station_shp``.
37 |
38 |     Coordinates are lon/lat when the shapefile uses a geographic CRS; stations
39 |     without a matching ``id`` are omitted.
40 |     """
41 |     gdf = load_shapefile(station_shp)
42 |     pos = {}
43 |     for idx, station_id in enumerate(map(str, stations)):
44 |         matched = gdf[gdf['id'] == station_id]
45 |         if not matched.empty:
46 |             point = matched.geometry.values[0]
47 |             pos[idx] = (point.x, point.y)
48 |     return pos
49 |
50 |
51 | def map_station_to_region(station_shp, region_shp):
52 |     """Return {station id: region ``GWZ`` value} using a point-in-polygon join."""
53 |     stations_gdf = load_shapefile(station_shp).set_geometry("geometry")
54 |     regions_gdf = load_shapefile(region_shp).set_geometry("geometry")
55 |     joined = gpd.sjoin(stations_gdf, regions_gdf, how="left", predicate="within")
56 |     # Both frames have an ``id`` column, so the station one is suffixed ``_left``.
57 |     return dict(zip(joined['id_left'], joined['GWZ']))
58 |
59 |
60 | # === Core Classes ===
61 |
62 |
63 | class FixedGraphModel(nn.Module):
64 |     """Wrap the model so the edge set used by its forward pass can be swapped.
65 |
66 |     ``current_edges`` is installed as the wrapped model's ``edge_index`` for
67 |     each forward call and the previous value is restored afterwards. The
68 |     ``dynamic_learner`` parameters (if any) are frozen.
69 |     """
70 |
71 |     def __init__(self, original_model):
72 |         super().__init__()
73 |         self.original_model = original_model
74 |         self.static_edge_index = original_model.edge_index
75 |         self.current_edges = self.static_edge_index
76 |         if hasattr(self.original_model, 'dynamic_learner'):
77 |             for param in self.original_model.dynamic_learner.parameters():
78 |                 param.requires_grad_(False)
79 |
80 |     def forward(self, x):
81 |         saved = self.original_model.edge_index
82 |         try:
83 |             self.original_model.edge_index = self.current_edges
84 |             return self.original_model(x)
85 |         finally:
86 |             self.original_model.edge_index = saved
87 |
88 |
89 | class MonteCarloEdgeExplainer:
90 |     """Monte Carlo estimate of each edge's marginal contribution to the model output.
91 |
92 |     For every edge, ``num_samples`` random subsets of the other edges are drawn
93 |     (each edge kept with probability 0.5) and the change in the mean model
94 |     output caused by adding the edge to the subset is averaged.
95 |     """
96 |
97 |     def __init__(self, fixed_model, edge_index, device='cuda', num_samples=50):
98 |         self.model = fixed_model
99 |         self.edge_index = edge_index.to(device)
100 |         self.device = device
101 |         self.num_edges = self.edge_index.size(1)
102 |         self.num_samples = num_samples
103 |
104 |     def _mean_output(self, x, edge_mask):
105 |         """Mean of the model's first output when only the edges in ``edge_mask`` are kept."""
106 |         self.model.current_edges = self.edge_index[:, edge_mask]
107 |         return self.model(x)[0].mean().item()
108 |
109 |     def explain(self, x):
110 |         """Return a float32 array with one attribution per edge.
111 |
112 |         The model's ``current_edges`` is restored afterwards, even on error
113 |         (it used to be left at the last sampled subset).
114 |         """
115 |         x = x.to(self.device)
116 |         shap_values = np.zeros(self.num_edges, dtype=np.float32)
117 |         previous_edges = self.model.current_edges
118 |         try:
119 |             with torch.no_grad():
120 |                 for e_idx in tqdm(range(self.num_edges), desc="MC Explaining"):
121 |                     sum_diff = 0.0
122 |                     for _ in range(self.num_samples):
123 |                         without_edge = torch.rand(self.num_edges, device=self.device) < 0.5
124 |                         without_edge[e_idx] = False
125 |                         with_edge = without_edge.clone()
126 |                         with_edge[e_idx] = True
127 |                         base_val = self._mean_output(x, without_edge)
128 |                         compare_val = self._mean_output(x, with_edge)
129 |                         sum_diff += compare_val - base_val
130 |                     shap_values[e_idx] = sum_diff / self.num_samples
131 |         finally:
132 |             self.model.current_edges = previous_edges
133 |         return shap_values
134 |
135 |
136 | class MCGraphResult:
137 |     """Per-edge attributions plus helpers to aggregate them per node and report them."""
138 |
139 |     def __init__(self, edge_index, edge_shap):
140 |         self.edge_index = edge_index  # NumPy array of shape [2, E]
141 |         self.edge_shap = edge_shap  # NumPy array of shape [E]
142 |
143 |     def compute_node_contrib(self, num_nodes):
144 |         """Give each endpoint half of its edge's value and sum the halves per node."""
145 |         node_contrib = np.zeros(num_nodes, dtype=np.float32)
146 |         half = np.asarray(self.edge_shap, dtype=np.float32) / 2
147 |         # np.add.at accumulates correctly when a node index repeats.
148 |         np.add.at(node_contrib, np.asarray(self.edge_index[0]), half)
149 |         np.add.at(node_contrib, np.asarray(self.edge_index[1]), half)
150 |         return node_contrib
151 |
152 |     def top_edges(self, k=10):
153 |         """Print the ``k`` edges with the largest attribution, highest first."""
154 |         for e in np.argsort(-self.edge_shap)[:k]:
155 |             print(f"Edge {e}: shap={self.edge_shap[e]:.4f}, (src={self.edge_index[0, e]}, dst={self.edge_index[1, e]})")
156 |
157 |
158 | # === Main Execution ===
159 |
160 |
161 | def main():
162 |     """Explain the first training sample with Monte Carlo edge attributions.
163 |
164 |     Results are cached as JSON under ``data/<dataset>/mc_graph_shap_results``
165 |     and reused on later runs (the cache is not keyed on checkpoint or sample).
166 |     """
167 |     set_random_seed()
168 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
169 |     # Same data root and station order as main.py; this script lives at the
170 |     # repository root, so the old '../../data' pointed outside the project.
171 |     data_dir = os.path.join('data', args.data)
172 |     stations = sorted(os.path.splitext(name)[0] for name in os.listdir(data_dir) if name.endswith('.xlsx'))
173 |
174 |     features, targets, code_map, full_df = load_and_preprocess(data_dir, stations)
175 |     num_nodes = len(stations)
176 |     train_set = GWLDataset(features, targets, mode='train', full_df=full_df, cfg=args)
177 |     edge_index = build_station_graph(code_map, device)
178 |     model = GraphFourierNet(num_nodes, edge_index, args).to(device)
179 |
180 |     ckpt_path = checkpoint_path()
181 |     if os.path.exists(ckpt_path):
182 |         model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))
183 |     else:
184 |         warnings.warn(f"Checkpoint not found at {ckpt_path}; explaining an untrained model.")
185 |     model.eval()
186 |
187 |     fixed_model = FixedGraphModel(model).to(device)
188 |     loader = DataLoader(train_set, batch_size=8, shuffle=False, collate_fn=collate_fn)
189 |     batch_x, _ = next(iter(loader))
190 |     sample_x = batch_x[:1]
191 |
192 |     save_dir = os.path.join(data_dir, "mc_graph_shap_results")
193 |     os.makedirs(save_dir, exist_ok=True)
194 |     save_path = os.path.join(save_dir, "mc_explain.json")
195 |
196 |     if os.path.exists(save_path):
197 |         with open(save_path, 'r') as fh:
198 |             out_data = json.load(fh)
199 |         edge_index_np = np.array(out_data['edge_index'])
200 |         shap_values = np.array(out_data['shap_values'])
201 |     else:
202 |         mc_explainer = MonteCarloEdgeExplainer(fixed_model, edge_index, device=device, num_samples=50)
203 |         shap_values = mc_explainer.explain(sample_x)
204 |
205 |         edge_index_np = edge_index.cpu().numpy()
206 |         node_contrib = MCGraphResult(edge_index_np, shap_values).compute_node_contrib(num_nodes)
207 |
208 |         out_data = {
209 |             'edge_index': edge_index_np.tolist(),
210 |             'shap_values': shap_values.tolist(),
211 |             'node_contrib': node_contrib.tolist(),
212 |         }
213 |         with open(save_path, 'w') as fh:
214 |             json.dump(out_data, fh, indent=2)
215 |
216 |     MCGraphResult(edge_index_np, shap_values).top_edges(k=10)
217 |
218 |
219 | if __name__ == "__main__":
220 |     main()
--------------------------------------------------------------------------------
/model/GFN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as f
4 | from model.FNN import FNN
5 | from model.GAT import GATLayer
6 |
7 |
8 | class DynamicGraphLearner(nn.Module):
9 |     def __init__(self, num_nodes, hidden_size, cfg):
10 |         super().__init__()
11 |         self.num_nodes = num_nodes
12 |
13 |         # Learnable node embeddings.
14 |         # One vector per node id; refined by static_gat in forward().
15 |         self.node_emb = nn.Embedding(num_nodes, hidden_size)
16 |
17 |         # Static graph encoder using GAT.
18 |         # Single head with out == in, so the layer's gated residual shapes match.
19 |         self.static_gat = GATLayer(hidden_size, hidden_size, heads=1)
20 |
21 |         # Edge similarity scoring network.
22 |         # Maps the concatenated [emb_src || emb_dst] pair to one score.
23 |         self.sim_fc = nn.Sequential(
24 |             nn.Linear(2 * hidden_size, 32),
25 |             nn.Tanh(),
26 |             nn.Linear(32, 1, bias=False)
27 |         )
28 |
29 |         self.cfg = cfg
30 |         self.reset_parameters()
31 |
32 |     def reset_parameters(self):
33 |         # Initialize parameters with Xavier uniform.
34 |         # (The first sim_fc Linear keeps PyTorch's default bias initialization.)
35 |         nn.init.xavier_uniform_(self.node_emb.weight)
36 |         for layer in self.sim_fc:
37 |             if isinstance(layer, nn.Linear):
38 |                 nn.init.xavier_uniform_(layer.weight)
39 |
40 |     def forward(self, x, static_edge_index):
41 |         """
42 |         Build the edge set: static edges plus the top-scoring new node pairs.
43 |         Returns (combined edge index [2, E'], GAT-refined node embeddings).
44 |         """
45 |         node_ids = torch.arange(self.num_nodes, device=x.device)
46 |         emb = self.node_emb(node_ids)
47 |
48 |         # Get node embeddings through static GAT.
49 |         # (The attention weights and edge index it also returns are unused.)
50 |         emb, _, _ = self.static_gat(emb, static_edge_index)
51 |
52 |         # Construct all possible node pairs (excluding self-loops).
53 |         # indexing='ij' yields the pairs in row-major (src, dst) order.
54 |         all_nodes = torch.arange(self.num_nodes, device=x.device)
55 |         candidate_src, candidate_dst = torch.meshgrid(all_nodes, all_nodes, indexing='ij')
56 |         mask = candidate_src != candidate_dst  # Exclude self-loops.
57 |         candidate_src = candidate_src[mask]
58 |         candidate_dst = candidate_dst[mask]
59 |
60 |         # Remove existing static edges from candidates using a dense boolean
61 |         # adjacency lookup (replaces a per-pair Python set check on the CPU).
62 |         static_src, static_dst = static_edge_index
63 |         is_static = torch.zeros(self.num_nodes, self.num_nodes, dtype=torch.bool, device=x.device)
64 |         is_static[static_src, static_dst] = True
65 |
66 |         keep_mask = ~is_static[candidate_src, candidate_dst]
67 |         candidate_src = candidate_src[keep_mask]
68 |         candidate_dst = candidate_dst[keep_mask]
69 |
70 |         # Compute similarity scores for candidate edges.
71 |         # NOTE(review): scores only pick indices via topk, so no gradient reaches sim_fc here.
72 |         src_emb = emb[candidate_src]
73 |         dst_emb = emb[candidate_dst]
74 |         pair_feat = torch.cat([src_emb, dst_emb], dim=-1)
75 |         scores = torch.tanh(self.sim_fc(pair_feat)).squeeze(-1)  # stays 1-D even for one candidate
76 |
77 |         # Select top-k scored edges to form dynamic edges.
78 |         # k is a fraction (edge_keep_ratio) of the remaining candidates.
79 |         k = int(len(scores) * self.cfg.edge_keep_ratio)
80 |         _, topk_indices = torch.topk(scores, k=k)
81 |         dynamic_edges = torch.stack([candidate_src[topk_indices], candidate_dst[topk_indices]])
82 |
83 |         # Combine static and dynamic edges, and remove duplicates.
84 |         # unique(dim=1) also returns the columns in sorted order.
85 |         combined_edges = torch.cat([static_edge_index, dynamic_edges], dim=1)
86 |         combined_edges = combined_edges.unique(dim=1)
87 |
88 |         return combined_edges, emb
89 |
90 |
91 | class GraphFourierNet(nn.Module):
92 |     def __init__(self, num_nodes, edge_index, cfg):
93 |         super().__init__()
94 |         self.num_nodes = num_nodes
95 |
96 |         # Register static edge index as a buffer (not a parameter).
97 |         # 注册静态边索引为 buffer(非参数)。
98 |         self.register_buffer('edge_index', edge_index)
99 |
100 |         self.cfg = cfg
101 |         self.edge_index = edge_index
102 |
103 |         # Module to learn dynamic edges.
104 |         # 用于学习动态边结构的模块。
105 |         self.dynamic_learner = DynamicGraphLearner(
106 |             num_nodes=num_nodes,
107 |             hidden_size=cfg.gat_hidden,
108 |             cfg=cfg
109 |         )
110 |
111 |         # Fourier Graph Network for initial encoding.
112 |         # FGN 模块用于初始时序编码。
113 |         self.fnn = FNN(
114 |             pre_length=cfg.pred_len,
115 |             embed_size=cfg.fnn_embed_size,
116 |             feature_size=cfg.feat_dim,
117 |             seq_length=cfg.seq_len,
118 |             hidden_size=cfg.gat_hidden
119 |         )
120 |
121 |         # Multi-layer GAT encoder.
122 |         # 多层 GAT 编码器。
123 |         self.gat = nn.ModuleList()
124 |         for _ in range(cfg.gat_layers):
125 |             self.gat.append(GATLayer(cfg.gat_hidden,
126 |                                      cfg.gat_hidden // cfg.gat_heads,
127 |                                      heads=cfg.gat_heads))
128 |
129 |         # Final prediction decoder.
130 | # 最后的预测解码器。 131 | self.decoder = nn.Sequential( 132 | nn.Linear(cfg.gat_hidden, 64), 133 | nn.LayerNorm(64), 134 | nn.ELU(), 135 | nn.Dropout(cfg.dropout), 136 | nn.Linear(64, cfg.pred_len) 137 | ) 138 | 139 | def forward(self, x): 140 | """ 141 | Full forward pass through GraphFourierNet. 142 | 执行 GraphFourierNet 的完整前向传播。 143 | :param x: Input tensor of shape [B, N, T, F] (输入张量:[批次, 节点, 时间步, 特征]) 144 | :return: Prediction, attention weights, and final node embeddings. 145 | 返回:预测结果,注意力分数,节点嵌入。 146 | """ 147 | batch_size, num_nodes, seq_len, feat_dim = x.size() 148 | 149 | # Flatten for FGN input. 150 | # 调整形状用于 FGN 输入。 151 | x = x.view(-1, seq_len, feat_dim) 152 | gat_input = self.fnn(x) 153 | gat_input = gat_input.view(batch_size, num_nodes, -1) 154 | 155 | # Learn dynamic edge structure. 156 | # 学习动态边结构。 157 | dynamic_edge_index, node_emb = self.dynamic_learner(gat_input, self.edge_index) 158 | 159 | # Combine static and dynamic edges for current batch. 160 | # 当前 batch 合并静态与动态边。 161 | combined_edges = torch.cat([self.edge_index, dynamic_edge_index], dim=1) 162 | batch_edge_index = torch.cat( 163 | [combined_edges for _ in range(batch_size)], 164 | dim=1 165 | ) 166 | 167 | # GAT propagation over dynamic graph. 168 | # 在动态图上进行 GAT 传播。 169 | all_alphas = [] 170 | x_gat = gat_input.view(-1, self.cfg.gat_hidden) 171 | 172 | for i, gat_layer in enumerate(self.gat): 173 | residual = x_gat 174 | x_gat, alpha, edges = gat_layer(x_gat, batch_edge_index) 175 | all_alphas.append((alpha.cpu(), edges.cpu())) 176 | 177 | # Add residual connection. 178 | # 添加残差连接。 179 | x_gat = x_gat + residual 180 | 181 | if i != self.cfg.gat_layers - 1: 182 | x_gat = f.elu(x_gat) 183 | x_gat = f.dropout(x_gat, p=self.cfg.dropout, training=self.training) 184 | 185 | # Reshape and decode. 
186 | # 重塑输出并进行解码。 187 | x_out = x_gat.view(batch_size, num_nodes, -1) 188 | output = self.decoder(x_out) 189 | 190 | return output.view(batch_size, -1), all_alphas, node_emb 191 | --------------------------------------------------------------------------------