├── .gitignore ├── LICENSE ├── README.md ├── asset └── architecture.png ├── config.yaml ├── data.py ├── dynamics.py ├── main.py ├── metrics.py ├── model.py ├── requirements.txt ├── simluator.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # *.png 2 | # *.pdf 3 | # *.gif 4 | # *.jpg 5 | # *.svg 6 | 7 | *.npz 8 | *.npy 9 | *.pkl 10 | *.pth 11 | *.h5 12 | *.ckpt 13 | *.pt 14 | 15 | *.csv 16 | 17 | *.json 18 | *.conf 19 | 20 | # mercator 21 | *.inf* 22 | *.obs* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 FIB LAB, Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiskNet 2 | 3 | This repo is the official implementation of our paper: "Predicting Long-term Dynamics of Complex Networks via Identifying Skeleton in Hyperbolic Space" (KDD 2024). 4 | 5 | 6 | 7 | ## Overall Architecture 8 | 9 | DiskNet consists of three modules: (1) the Hyperbolic Renormalization Group, which identifies the representation and skeleton of network dynamics; (2) Neural Dynamics on Skeleton, which models the dynamics of super-nodes on the skeleton; and (3) Degree-based Super-Resolution, which lifts the predicted values of super-nodes to the original nodes. 10 | 11 | ![architecture](./asset/architecture.png) 12 | 13 | 14 | 15 | ## Environment Setup 16 | 17 | ``` 18 | conda create --name <env_name> --file requirements.txt 19 | ``` 20 | 21 | 22 | 23 | ## Usage 24 | 25 | **Config:** 26 | 27 | graph_type: `BA`, `WS`, `Drosophila`, `Social`, `Web`, `PowerGrid` or `Airport`; 28 | 29 | dynamics: `HindmarshRose`, `FitzHughNagumo` or `CoupledRossler`. 30 | 31 | **Run:** 32 | 33 | ```shell 34 | python main.py 35 | ``` 36 | 37 | 38 | 
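For example, to switch to a Watts–Strogatz graph with FitzHugh–Nagumo dynamics, change the two keys in `config.yaml` and rerun (a minimal sketch; the key names and values come from the `config.yaml` shipped with this repo, and all other keys keep their defaults):

```yaml
# config.yaml (excerpt)
graph_type: WS
dynamics: FitzHughNagumo
```

```shell
python main.py
```

`main.py` then builds (or loads) the graph, simulates the chosen dynamics, and trains and tests DiskNet end-to-end, writing outputs under `data/` and `logs/`.

## Citation

If you find this repo helpful, please cite our paper. 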
42 | 43 | ``` 44 | @inproceedings{li2024predicting, 45 | title={Predicting Long-term Dynamics of Complex Networks via Identifying Skeleton in Hyperbolic Space}, 46 | author={Li, Ruikun and Wang, Huandong and Piao, Jinghua and Liao, Qingmin and Li, Yong}, 47 | booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, 48 | pages={1655--1666}, 49 | year={2024} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /asset/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsinghua-fib-lab/DiskNet/b53b41f372b5764e4ad6edeb080f1bdd1cefcad4/asset/architecture.png -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # Key 2 | graph_type: BA # BA, WS, Drosophila, Social, Web, Airport, PowerGrid 3 | model: DiskNet 4 | dynamics: HindmarshRose # HindmarshRose, FitzHughNagumo, CoupledRossler 5 | 6 | # Common 7 | cpu_num: 0 8 | seed: 612 9 | device: cuda:5 10 | log_root: logs/ 11 | data_root: data/ 12 | 13 | # NetworkSimulator 14 | node_num: 5000 15 | layout: random 16 | edge_num: 3 # BA 17 | ring_lattice_k: 5 # WS 18 | rewiring_prob: 0.1 # WS 19 | 20 | # Mercator 21 | quiet_mode: True 22 | fast_mode: True 23 | validation_mode: True 24 | post_kappa: True 25 | refine: False 26 | 27 | # Model 28 | lr: 0.001 29 | max_epoch: 50 30 | lr_step: 50 31 | lr_decay: 0.9 32 | val_interval: 1 33 | 34 | DiskNet: 35 | n_dim: 2 36 | ratio: 0.02 37 | ag_hid_dim: 8 38 | sr_hid_dim: 16 39 | ode_hid_dim: 16 40 | k: 5 41 | log: True 42 | method: euler 43 | pretrain_epoch: 3000 44 | prior_init: True 45 | 46 | # Dataset 47 | lookback: 12 48 | horizon: 120 49 | train_ratio: 0.6 50 | val_ratio: 0.2 51 | test_ratio: 0.2 52 | batch_size: 8 53 | 54 | # HindmarshRose 55 | HindmarshRose: 56 | dim: 3 57 | total_t: 20.0 58 | sim_dt: 0.01 59 | dt: 0.04 60 | epsilon: 0.15 61 | a: 1.0 62 | b: 3.0 63 | c: 1.0 64 | u: 5.0 65 | s: 4.0 66 | r: 0.005 67 | I: 3.24 68 | v: 2.0 69 | lam: 10.0 70 | omega: 1.0 71 | x0: -1.6 72 | z_min: [-1, -7, 2.5] 73 | z_max: [2, 0.5, 4] 74 | # FitzHughNagumo 75 | FitzHughNagumo: 76 | dim: 2 77 | total_t: 50.0 78 | sim_dt: 0.01 79 | dt: 0.1 80 | a: 0.28 81 | b: 0.5 82 | c: -0.04 83 | epsilon: 1.0 84 | z_min: [-1.6, -0.5] 85 | z_max: [1.5, 6.5] 86 | # CoupledRossler 87 | CoupledRossler: 88 | dim: 3 89 | total_t: 50.0 90 | sim_dt: 0.005 91 | dt: 0.1 92 | epsilon: 0.15 93 | a: 0.2 94 | b: 0.2 95 | c: -6.0 96 | delta: 0.2 97 | z_min: [-10, -10, 0] 98 | z_max: [10, 10, 5] -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | 7 | class Dataset(Dataset): 8 | 9 | def __init__(self, args, mode='train'): 10 | super().__init__() 11 | self.args = args 12 | self.mode = mode 13 | self.lookback = args.lookback 14 | self.horizon = args.horizon 15 | 16 | try: 17 | processed_data = np.load(f'{args.data_dir}/{args.dynamics}/dataset/{mode}_{self.args.lookback}_{self.args.horizon}.npz') 18 | self.X = torch.from_numpy(processed_data['X']).float().to(self.args.device) 19 | self.Y = torch.from_numpy(processed_data['Y']).float().to(self.args.device) 20 | self.mean_y, self.std_y = processed_data['mean_y'], 
processed_data['std_y'] 21 | except (FileNotFoundError, KeyError): 22 | self.process() 23 | 24 | def __getitem__(self, index): 25 | return self.X[index], self.Y[index] 26 | 27 | def __len__(self): 28 | return len(self.X) 29 | 30 | def process(self): 31 | # original data: (sequence_length, node_num, feature_dim) 32 | # return: X: (sequence_length-lookback-horizon+1, lookback, node_num, feature_dim), Y: (sequence_length-lookback-horizon+1, horizon, node_num, feature_dim) 33 | simulation = np.load(f'{self.args.data_dir}/{self.args.dynamics}/dynamics.npz')['X'] 34 | 35 | lookback = self.args.lookback 36 | horizon = self.args.horizon 37 | 38 | # Sliding window 39 | idx = np.arange(0, simulation.shape[0]-lookback-horizon+1) 40 | X = np.stack([simulation[i:i+lookback] for i in idx], axis=0) 41 | Y = np.stack([simulation[i+lookback:i+lookback+horizon] for i in idx], axis=0) 42 | 43 | # Normalize 44 | self.mean_x, self.mean_y = X.mean(axis=(0, 1, 2), keepdims=True), Y.mean(axis=(0, 1, 2), keepdims=True) 45 | self.std_x, self.std_y = X.std(axis=(0, 1, 2), keepdims=True), Y.std(axis=(0, 1, 2), keepdims=True) 46 | X = (X - self.mean_x) / self.std_x 47 | Y = (Y - self.mean_y) / self.std_y 48 | 49 | # Split train, val and test 50 | train_ratio, val_ratio = self.args.train_ratio, self.args.val_ratio 51 | train_size, val_size = int(X.shape[0] * train_ratio), int(X.shape[0] * val_ratio) 52 | train_idx = np.random.choice(np.arange(X.shape[0]), train_size, replace=False) 53 | val_idx = np.random.choice(np.setdiff1d(np.arange(X.shape[0]), train_idx), val_size, replace=False) 54 | test_idx = np.setdiff1d(np.arange(X.shape[0]), np.concatenate([train_idx, val_idx])) 55 | X_train, Y_train = X[train_idx], Y[train_idx] 56 | X_val, Y_val = X[val_idx], Y[val_idx] 57 | X_test, Y_test = X[test_idx], Y[test_idx] 58 | 59 | # Save 60 | os.makedirs(f'{self.args.data_dir}/{self.args.dynamics}/dataset', exist_ok=True) 61 | np.savez(f'{self.args.data_dir}/{self.args.dynamics}/dataset/train_{self.args.lookback}_{self.args.horizon}.npz', X=X_train, Y=Y_train, mean_y=self.mean_y, std_y=self.std_y) 62 | np.savez(f'{self.args.data_dir}/{self.args.dynamics}/dataset/val_{self.args.lookback}_{self.args.horizon}.npz', X=X_val, Y=Y_val, mean_y=self.mean_y, std_y=self.std_y) 63 | np.savez(f'{self.args.data_dir}/{self.args.dynamics}/dataset/test_{self.args.lookback}_{self.args.horizon}.npz', X=X_test, Y=Y_test, mean_y=self.mean_y, std_y=self.std_y) 64 | 65 | if self.mode=='train': 66 | self.X, self.Y = X_train, Y_train 67 | elif self.mode=='val': 68 | self.X, self.Y = X_val, Y_val 69 | elif self.mode=='test': 70 | self.X, self.Y = X_test, Y_test 71 | 72 | # Convert to torch tensor 73 | self.X = torch.from_numpy(self.X).float().to(self.args.device) 74 | self.Y = torch.from_numpy(self.Y).float().to(self.args.device) 75 | 76 | def getLoader(self): 77 | return DataLoader(self, batch_size=self.args.batch_size, shuffle=True, drop_last=True) --------------------------------------------------------------------------------
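`process` materializes sliding windows over the simulated trajectory: with lookback L and horizon H, window i yields `X = sim[i:i+L]` and `Y = sim[i+L:i+L+H]`, so a trajectory of length T produces `T - L - H + 1` samples. A minimal NumPy sketch of the same windowing (the shapes are illustrative and mirror the `lookback: 12` / `horizon: 120` defaults in `config.yaml`):

```python
import numpy as np

T, N, D = 200, 50, 3            # time steps, nodes, state dimension (illustrative)
lookback, horizon = 12, 120
sim = np.random.randn(T, N, D)  # stands in for dynamics.npz['X']

idx = np.arange(T - lookback - horizon + 1)
X = np.stack([sim[i:i + lookback] for i in idx], axis=0)                       # (69, 12, 50, 3)
Y = np.stack([sim[i + lookback:i + lookback + horizon] for i in idx], axis=0)  # (69, 120, 50, 3)
```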
/dynamics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class FitzHughNagumo: 5 | def __init__(self, args, A): 6 | self.L = A - np.diag(np.sum(A, axis=1)) # Diffusion matrix: x_j - x_i 7 | 8 | param = args[args.dynamics] 9 | self.a = param.a 10 | self.b = param.b 11 | self.c = param.c 12 | self.epsilon = param.epsilon 13 | self.k_in = np.sum(A, axis=1) # in-degree 14 | 15 | def f(self, x, t): 16 | 17 | node_num = x.shape[0] // 2 18 | x1, x2 = x[:node_num], x[node_num:] 19 | 20 | # x1 21 | f_x1 = x1 - (x1 ** 3)/3 - x2 22 | outer_x1 = self.epsilon * np.dot(self.L, x1) / self.k_in # epsilon * sum_j Aij * ((x1_j - x1_i) / k_in_i) 23 | dx1dt = f_x1 + outer_x1 24 | 25 | # x2 26 | f_x2 = self.a + self.b * x1 + self.c * x2 27 | outer_x2 = 0.0 28 | dx2dt = f_x2 + outer_x2 29 | 30 | if np.isnan(dx1dt).any(): 31 | print('nan during simulation!') 32 | exit() 33 | 34 | dxdt = np.concatenate([dx1dt, dx2dt], axis=0) 35 | return dxdt 36 | 37 | def g(self, x, t): 38 | """inherent noise""" 39 | return np.diag([0.0] * x.shape[0]) 40 | 41 | 42 | class HindmarshRose: 43 | def __init__(self, args, A): 44 | self.A = A # Adjacency matrix 45 | 46 | param = args[args.dynamics] 47 | self.a = param.a 48 | self.b = param.b 49 | self.c = param.c 50 | self.u = param.u 51 | self.s = param.s 52 | self.r = param.r 53 | self.epsilon = param.epsilon 54 | self.v = param.v 55 | self.lam = param.lam 56 | self.I = param.I 57 | self.omega = param.omega 58 | self.x0 = param.x0 59 | 60 | def f(self, x, t): 61 | 62 | node_num = x.shape[0] // 3 63 | x1, x2, x3 = x[:node_num], x[node_num:2*node_num], x[2*node_num:] 64 | mu_xj = 1 / (1 + np.exp(-self.lam * (x1 - self.omega))) 65 | 66 | # x1 67 | f_x1 = x2 - self.a * x1 ** 3 + self.b * x1 ** 2 - x3 + self.I 68 | outer_x1 = self.epsilon * (self.v - x1) * np.dot(self.A, mu_xj) 69 | dx1dt = f_x1 + outer_x1 70 | 71 | # x2 72 | f_x2 = self.c - self.u * x1 ** 2 - x2 73 | outer_x2 = 0.0 74 | dx2dt = f_x2 + outer_x2 75 | 76 | # x3 77 | f_x3 = self.r * (self.s * (x1 - self.x0) - x3) 78 | outer_x3 = 0.0 79 | dx3dt = f_x3 + outer_x3 80 | 81 | if np.isnan(dx1dt).any(): 82 | print('nan during simulation!') 83 | exit() 84 | 85 | dxdt = np.concatenate([dx1dt, dx2dt, dx3dt], axis=0) 86 | return dxdt 87 | 88 | def g(self, x, t): 89 | """inherent noise""" 90 | return np.diag([0.0] * x.shape[0]) 91 | 92 | 93 | class CoupledRossler: 94 | def __init__(self, args, A): 95 | self.L = A - np.diag(np.sum(A, axis=1)) # Diffusion matrix: x_j - x_i 96 | 97 | param = args[args.dynamics] 98 | self.a = param.a 99 | self.b = param.b 100 | self.c = param.c 101 | self.epsilon = param.epsilon 102 | self.delta = param.delta 103 | 104 | def f(self, x, t): 105 | 106 | node_num = x.shape[0] // 3 107 | x1, x2, x3 = x[:node_num], x[node_num:2*node_num], x[2*node_num:] 108 | omega = np.random.normal(1, self.delta, size=node_num) 109 | 110 | # x1 111 | f_x1 = - omega * x2 - x3 112 | outer_x1 = self.epsilon * np.dot(self.L, x1) 113 | dx1dt = f_x1 + outer_x1 114 | 115 | # x2 116 | f_x2 = omega * x1 + self.a * x2 117 | outer_x2 = 0.0 118 | dx2dt = f_x2 + outer_x2 119 | 120 | # x3 121 | f_x3 = self.b + x3 * (x1 + self.c) 122 | outer_x3 = 0.0 123 | dx3dt = f_x3 + outer_x3 124 | 125 | if np.isnan(dx1dt).any(): 126 | print('nan during simulation!') 127 | exit() 128 | 129 | dxdt = np.concatenate([dx1dt, dx2dt, dx3dt], axis=0) 130 | return dxdt 131 | 132 | def g(self, x, t): 133 | """inherent noise""" 134 | return np.diag([0.0] * x.shape[0]) --------------------------------------------------------------------------------
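All three dynamics classes expose the same interface — `f(x, t)` returns the drift for the stacked state vector and `g(x, t)` the (here zero) diffusion matrix — so they can be passed straight to `sdeint.itoEuler`, which is how `simluator.py` generates trajectories. A minimal sketch on a small synthetic graph (the graph size, time span and initial condition are illustrative; the parameter values mirror the `FitzHughNagumo` block in `config.yaml`):

```python
import networkx as nx
import numpy as np
from omegaconf import OmegaConf
from sdeint import itoEuler

from dynamics import FitzHughNagumo

# Minimal config mirroring the FitzHughNagumo block of config.yaml
conf = OmegaConf.create({
    'dynamics': 'FitzHughNagumo',
    'FitzHughNagumo': {'a': 0.28, 'b': 0.5, 'c': -0.04, 'epsilon': 1.0},
})

G = nx.barabasi_albert_graph(100, 3, seed=0)
sde = FitzHughNagumo(args=conf, A=nx.to_numpy_array(G))

x0 = np.random.uniform(-1, 1, size=100 * 2)  # stacked (x1, x2) initial state
tspan = np.arange(0, 5.0, 0.01)
sol = itoEuler(sde.f, sde.g, x0, tspan)      # (len(tspan), 2 * node_num)
```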
/main.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import warnings; warnings.filterwarnings('ignore') 3 | 4 | from model import DiskNet 5 | from data import Dataset 6 | from simluator import NetworkSimulator 7 | from utils import * 8 | 9 | 10 | def main(): 11 | # Read config 12 | conf = OmegaConf.load('config.yaml') 13 | 14 | # Set random seed and cpu number 15 | set_cpu_num(conf.cpu_num) 16 | seed_everything(conf.seed) 17 | 18 | # Set data and log directory 19 | if conf.graph_type in ['BA', 'WS']: 20 | conf.data_dir = f'data/{conf.graph_type}_n{conf.node_num}_{conf.seed}' 21 | conf.log_dir = f'logs/{conf.graph_type}_n{conf.node_num}_{conf.seed}/l{conf.lookback}_h{conf.horizon}' 22 | else: 23 | conf.data_dir = f'data/{conf.graph_type}' 24 | conf.log_dir = f'logs/{conf.graph_type}/l{conf.lookback}_h{conf.horizon}' 25 | 26 | # Create graph 27 | simulator = NetworkSimulator(args=conf) 28 | network, adj_matrix = simulator.buildNetwork() 29 | 30 | # Draw graph 31 | conf.node_num = network.number_of_nodes() 32 | drawGraph(network, layout='spring', filter='random', threshold=0.5, out_path=f'{conf.data_dir}/graph.png') 33 | 34 | # Simulate network dynamics 35 | simulator.getSimTraj() 36 | 37 | # Dataset 38 | train_dataset = Dataset(conf, mode='train') 39 | val_dataset = Dataset(conf, mode='val') 40 | test_dataset = Dataset(conf, mode='test') 41 | 42 | # Model 43 | model = DiskNet(conf, adj_matrix) 44 | print_model_summary(model) 45 | 46 | # Train 47 | model.fit(train_dataset.getLoader(), val_dataset.getLoader()) 48 | 49 | # Test 50 | model.test(test_dataset.getLoader()) 51 | 52 | 53 | if __name__ == '__main__': 54 | 55 | main() -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def MAE(y_true, y_pred, keep_step=False, keep_node=False): 5 | if not keep_step and not keep_node: 6 | return np.mean(np.abs(y_true - y_pred)) 7 | elif keep_step and not keep_node: 8 | return np.mean(np.abs(y_true - y_pred), axis=(0,2,3)) 9 | elif not keep_step and keep_node: 10 | return np.mean(np.abs(y_true - y_pred), axis=(0,1,3)) 11 | 12 | def MSE(y_true, y_pred, keep_step=False, keep_node=False): 13 | if not keep_step and not keep_node: 14 | return np.mean(np.square(y_true - y_pred)) 15 | elif keep_step and not keep_node: 16 | return np.mean(np.square(y_true - y_pred), axis=(0,2,3)) 17 | elif not keep_step and keep_node: 18 | return np.mean(np.square(y_true - y_pred), axis=(0,1,3)) 19 | 20 | def RMSE(y_true, y_pred, keep_step=False): 21 | if not keep_step: 22 | return np.sqrt(np.mean(np.square(y_true - y_pred))) 23 | else: 24 | return np.sqrt(np.mean(np.square(y_true - y_pred), axis=(0,2,3))) 25 | 26 | def NMSE(y_true, y_pred, keep_step=False): 27 | if not keep_step: 28 | return np.mean(np.square(y_true - y_pred)) / (np.mean(np.square(y_true)) + 1e-7) 29 | else: 30 | return np.mean(np.square(y_true - y_pred), axis=(0,2,3)) / (np.mean(np.square(y_true), axis=(0,2,3)) + 1e-7) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | from torch.optim.lr_scheduler import StepLR 8 | import mercator 9 | import matplotlib.pyplot as plt 10 | import sklearn.cluster as cluster 11 | import torchdiffeq as ode 12 | 13 | from utils import drawTraj, draw_embedding 14 | 15 | 16 | def normalized_laplacian(A: torch.Tensor): 17 | """Symmetrically Normalized Laplacian: I - D^-1/2 * ( A ) * D^-1/2""" 18 | out_degree = torch.sum(A, dim=1) 19 | int_degree = torch.sum(A, dim=0) 20 | 21 | out_degree_sqrt_inv = torch.pow(out_degree, -0.5) 22 | int_degree_sqrt_inv = torch.pow(int_degree, -0.5) 23 | mx_operator = torch.eye(A.shape[0], device=A.device) - 
torch.diag(out_degree_sqrt_inv) @ A @ torch.diag(int_degree_sqrt_inv) 24 | 25 | return mx_operator 26 | 27 | 28 | class HyperbolicEmbedding: 29 | def __init__(self, args: dict): 30 | self.args = args 31 | 32 | def fit_transform(self): 33 | if not os.path.exists(f'{self.args.log_dir}/HE/he.inf_coord'): 34 | os.makedirs(f'{self.args.log_dir}/HE', exist_ok=True) 35 | mercator.embed( 36 | edgelist_filename=f'{self.args.data_dir}/graph.txt', 37 | quiet_mode=self.args.quiet_mode, 38 | fast_mode=self.args.fast_mode, 39 | output_name=f'{self.args.log_dir}/HE/he', 40 | validation_mode=self.args.validation_mode, 41 | post_kappa=self.args.post_kappa 42 | ) 43 | 44 | if self.args.refine: 45 | mercator.embed( 46 | edgelist_filename=f'{self.args.data_dir}/graph.txt', 47 | quiet_mode=self.args.quiet_mode, 48 | fast_mode=self.args.fast_mode, 49 | output_name=f'{self.args.log_dir}/HE/he', 50 | validation_mode=self.args.validation_mode, 51 | post_kappa=self.args.post_kappa, 52 | inf_coord=f'{self.args.log_dir}/HE/he.inf_coord' 53 | ) 54 | 55 | return self._parse_mercator_output() 56 | 57 | def _parse_mercator_output(self): 58 | with open(f'{self.args.log_dir}/HE/he.inf_coord', 'r') as f: 59 | lines = f.readlines() 60 | 61 | # parse node_num, dim, coords 62 | node_num = int(lines[7].split()[-1]) 63 | beta = float(lines[8].split()[-1]) 64 | mu = float(lines[9].split()[-1]) 65 | 66 | kappa = np.zeros(node_num) 67 | angular = np.zeros(node_num) 68 | radius = np.zeros(node_num) 69 | for i in range(15, 15+node_num): 70 | kappa[i-15] = float(lines[i].split()[1]) 71 | angular[i-15] = float(lines[i].split()[2]) 72 | radius[i-15] = float(lines[i].split()[3]) 73 | 74 | return kappa, angular, radius 75 | 76 | 77 | def atanh(x, eps=1e-5): 78 | x = torch.clamp(x, max=1. - eps) 79 | return .5 * (torch.log(1 + x) - torch.log(1 - x)) 80 | 81 | class PoincareManifold: 82 | 83 | @staticmethod 84 | def poincare_grad(euclidean_grad, x, c=-1, eps=1e-5): 85 | """ 86 | Compute the gradient of the Poincare distance with respect to x. 87 | """ 88 | sqnormx = torch.sum(x * x, dim=-1, keepdim=True) 89 | result = ((1 + c*sqnormx) / 2) ** 2 * euclidean_grad 90 | return result 91 | 92 | @staticmethod 93 | def log_map_zero(x, c=-1, eps=1e-5): 94 | """ 95 | Log map from Poincare ball space to tangent space of zero. 96 | Ref: 97 | 1. https://github.com/cll27/pvae/tree/7abbb4604a1acec2332b1b4dfe21267834b505cc 98 | 2. https://github.com/facebookresearch/hgnn/blob/master/manifold/PoincareManifold.py 99 | """ 100 | norm_diff = torch.norm(x, 2, dim=1, keepdim=True) 101 | atanh_x = atanh(np.sqrt(np.abs(c)) * norm_diff) 102 | lam_zero = 2. # lambda = 2 / (1 + ||zero||) = 2 103 | return 2. 
/ (np.sqrt(np.abs(c)) * lam_zero) * atanh_x * (x + eps) / norm_diff 104 | 105 | 106 | class GNN(nn.Module): 107 | def __init__(self, feature_dim, ode_hid_dim): 108 | super(GNN, self).__init__() 109 | self.f1 = nn.Sequential( 110 | nn.Linear(feature_dim, ode_hid_dim, bias=True), 111 | nn.ReLU(), 112 | nn.Linear(ode_hid_dim, ode_hid_dim, bias=True), 113 | ) 114 | self.f2 = nn.Sequential( 115 | nn.Linear(ode_hid_dim, ode_hid_dim, bias=True), 116 | nn.ReLU(), 117 | nn.Linear(ode_hid_dim, feature_dim, bias=True), 118 | ) 119 | 120 | self.adj = None 121 | 122 | def forward(self, x): 123 | x = self.f1(x) 124 | x = self.adj @ x 125 | x = self.f2(x) 126 | return x 127 | 128 | 129 | class BackboneODE(nn.Module): 130 | """dXdt = f(X) + g(X, A)""" 131 | def __init__(self, lookback, feature_dim, ode_hid_dim, method): 132 | super(BackboneODE, self).__init__() 133 | 134 | self.method = method 135 | self.feature_dim = feature_dim 136 | self.init_enc = nn.Sequential( 137 | nn.Linear(lookback, ode_hid_dim, bias=True), 138 | nn.ReLU(), 139 | nn.Linear(ode_hid_dim, 1, bias=True) 140 | ) 141 | self.f = nn.Sequential( 142 | nn.Linear(feature_dim, ode_hid_dim, bias=True), 143 | nn.ReLU(), 144 | nn.Linear(ode_hid_dim, feature_dim, bias=True), 145 | ) 146 | self.g = GNN(feature_dim, ode_hid_dim) 147 | 148 | def dxdt(self, t, x): 149 | x_self = self.f(x) 150 | x_neigh = self.g(x) 151 | dxdt = x_self + x_neigh 152 | return dxdt 153 | 154 | def forward(self, tspan, x, adj_w): 155 | # batch_size, lookback, node_num, feature_dim 156 | self.g.adj = adj_w 157 | 158 | x = x.permute(0, 2, 3, 1) # batch_size, node_num, feature_dim, lookback 159 | x = self.init_enc(x) # batch_size, node_num, feature_dim, 1 160 | x = x.squeeze(-1) # batch_size, node_num, feature_dim 161 | out = ode.odeint(self.dxdt, x, tspan, method=self.method) # horizon, batch_size, node_num, feature_dim 162 | out = out.permute(1, 0, 2, 3) # batch_size, horizon, node_num, feature_dim 163 | return out 164 | 165 | 166 | class Refiner(nn.Module): 167 | def __init__(self, lookback, horizon, feature_dim, hid_dim): 168 | super(Refiner, self).__init__() 169 | 170 | self.feature_dim = feature_dim 171 | self.mlp_X = nn.Sequential( 172 | nn.Linear(lookback*feature_dim, hid_dim), 173 | nn.Tanh(), 174 | ) 175 | self.mlp_Y = nn.Sequential( 176 | nn.Linear(horizon*feature_dim, hid_dim), 177 | nn.Tanh(), 178 | ) 179 | self.mlp_out = nn.Sequential( 180 | nn.Linear(hid_dim*2, hid_dim), 181 | nn.Tanh(), 182 | nn.Linear(hid_dim, horizon*feature_dim), 183 | ) 184 | 185 | def forward(self, X, Y): 186 | X = X.permute(0, 2, 1, 3) # batch_size, node_num, lookback, feature_dim 187 | X = X.reshape(X.shape[0], X.shape[1], -1) # batch_size, node_num, lookback*feature_dim 188 | Y = Y.permute(0, 2, 1, 3) # batch_size, node_num, lookback, feature_dim 189 | Y = Y.reshape(Y.shape[0], Y.shape[1], -1) # batch_size, node_num, lookback*feature_dim 190 | 191 | X = self.mlp_X(X) 192 | Y = self.mlp_Y(Y) 193 | output = torch.cat([X, Y], dim=-1) 194 | refined_Y = self.mlp_out(output) 195 | 196 | refined_Y = refined_Y.reshape(refined_Y.shape[0], refined_Y.shape[1], -1, self.feature_dim) # batch_size, node_num, horizon, feature_dim 197 | refined_Y = refined_Y.permute(0, 2, 1, 3) # batch_size, horizon, node_num, feature_dim 198 | return refined_Y 199 | 200 | 201 | class DiskNet(nn.Module): 202 | 203 | def __init__(self, args, adj): 204 | super(DiskNet, self).__init__() 205 | self.args = args 206 | self.model_args = args['DiskNet'] 207 | 208 | self.adj = torch.from_numpy(adj).float().to(args.device) 
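# norm_lap below is the symmetrically normalized Laplacian I - D^{-1/2} A D^{-1/2}
# built by normalized_laplacian above; _agc_state applies it to the raw node states
# so that neighborhood information is degree-normalized before being pooled onto
# super-nodes by the assignment matrix.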
209 | self.norm_lap = normalized_laplacian(self.adj) 210 | self.feature_dim = args[args.dynamics].dim 211 | 212 | # Identity Backbone 213 | self.repr_net1 = nn.Sequential( 214 | nn.Linear(self.model_args.n_dim, self.model_args.ag_hid_dim), 215 | nn.Tanh(), 216 | nn.Linear(self.model_args.ag_hid_dim, self.model_args.ag_hid_dim), 217 | nn.ReLU(), 218 | nn.LayerNorm(self.model_args.ag_hid_dim), 219 | ) 220 | self.repr_net2 = nn.Sequential( 221 | nn.Linear(self.model_args.n_dim, self.model_args.ag_hid_dim), 222 | nn.Tanh(), 223 | nn.Linear(self.model_args.ag_hid_dim, self.model_args.ag_hid_dim), 224 | nn.ReLU(), 225 | nn.LayerNorm(self.model_args.ag_hid_dim), 226 | ) 227 | self.softmax = nn.Softmax(dim=-1) 228 | 229 | # State aggregation 230 | self.agc_mlp = nn.Sequential( 231 | nn.Linear(self.feature_dim, self.model_args.ag_hid_dim), 232 | nn.ReLU(), 233 | nn.Linear(self.model_args.ag_hid_dim, self.feature_dim), 234 | ) 235 | self.tanh = nn.Tanh() 236 | 237 | # Backbone Dynamics 238 | self.BackboneODE = BackboneODE(args.lookback, self.feature_dim, self.model_args.ode_hid_dim, self.model_args.method) 239 | 240 | # K-means 241 | self.cluster_idx, self.cluster_centers = self._kmeans(adj, self.model_args.k, self.model_args.log) 242 | 243 | # Refine 244 | self.refiners = nn.ModuleList([Refiner(args.lookback, args.horizon, self.feature_dim, self.model_args.sr_hid_dim) for _ in range(self.model_args.k)]) 245 | 246 | # Device 247 | self.to(args.device) 248 | 249 | # Init 250 | for m in self.modules(): 251 | if isinstance(m, nn.Linear): 252 | init.normal_(m.weight, mean=0, std=0.1) 253 | if m.bias is not None: 254 | init.constant_(m.bias, 0) 255 | 256 | # Init hyperbolic embedding 257 | self.node_embedding, angular = self._init_poincare() 258 | self.supernode_embedding, backbone, assignment_matrix = self._init_super_node(angular) 259 | draw_embedding(self.adj, self.node_embedding, f'{self.args.log_dir}/HE/init_node_poincare.png') 260 | draw_embedding(backbone, self.supernode_embedding, f'{self.args.log_dir}/HE/init_supernode_poincare.png') 261 | 262 | if self.model_args.prior_init: 263 | # Pretrain for Identity Backbone 264 | self._pretrain_identity_backbone(assignment_matrix) 265 | draw_embedding(self.backbone, self.supernode_embedding, f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/pretrain_supernode_poincare.png') 266 | 267 | def _init_poincare(self): 268 | print('Initializing poincare embedding...') 269 | 270 | _, angular, radius = HyperbolicEmbedding(self.args).fit_transform() 271 | 272 | # Poincaré Disk to Euclidean 273 | radius /= radius.max() # enhance hierarchical structure 274 | r = np.tanh(radius / 2) 275 | x, y = r * np.cos(angular), r * np.sin(angular) 276 | poincare_embedding = torch.from_numpy(np.stack([x, y], axis=1)).float().to(self.args.device) 277 | 278 | print('Done.') 279 | return self._check_norm(poincare_embedding), angular 280 | 281 | def _init_super_node(self, angular): 282 | print('Initializing super node embedding...') 283 | num = int(self.model_args.ratio * self.args.node_num) 284 | 285 | # init super node embedding by angular 286 | idx = np.argsort(angular) 287 | assignment_matrix = torch.zeros(num, self.args.node_num).to(self.args.device) 288 | size = int(1/self.model_args.ratio) 289 | for i in range(num): 290 | assignment_matrix[i, idx[i*size:(i+1)*size]] = 1 291 | 292 | degree = self.adj.sum(axis=1, keepdims=True) 293 | super_node_embedding = assignment_matrix @ (self.node_embedding * degree) / (assignment_matrix @ degree).sum(dim=-1, keepdim=True) 294 | 
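# Initialization: each super-node starts at the degree-weighted centroid of its
# members' Poincaré coordinates, so high-degree nodes dominate the placement.
# Wrapping it in nn.Parameter below lets training move it with the rescaled
# Poincaré gradients (see _update_supernode_embedding and poincare_grad).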
super_node_embedding = nn.Parameter(self._check_norm(super_node_embedding)) 295 | backbone = assignment_matrix @ self.adj @ assignment_matrix.T 296 | 297 | print('Done.') 298 | return super_node_embedding, backbone, assignment_matrix 299 | 300 | def _pretrain_identity_backbone(self, prior_assignment_matrix): 301 | 302 | optimizer = torch.optim.Adam( 303 | [ 304 | {'params': self.repr_net1.parameters(), 'lr': self.args.lr}, 305 | {'params': self.repr_net2.parameters(), 'lr': self.args.lr}, 306 | {'params': self.agc_mlp.parameters(), 'lr': self.args.lr}, 307 | ], 308 | lr=self.args.lr) 309 | loss_fn = nn.L1Loss() 310 | 311 | for epoch in range(self.model_args.pretrain_epoch): 312 | # 1. map to euclidean space from poincare space 313 | node_euclidean_embedding = PoincareManifold.log_map_zero(self.node_embedding) 314 | supernode_euclidean_embedding = PoincareManifold.log_map_zero(self.supernode_embedding) 315 | # 2. topology-aware representation 316 | node_repr = self.repr_net1(node_euclidean_embedding) 317 | supernode_repr = self.repr_net2(supernode_euclidean_embedding) 318 | # 3. assignment matrix 319 | assignment_prob = self.softmax(supernode_repr @ node_repr.T) 320 | assignment_matrix = assignment_prob 321 | # 4. loss 322 | loss = loss_fn(assignment_matrix, prior_assignment_matrix) 323 | # 5. update 324 | optimizer.zero_grad() 325 | loss.backward() 326 | optimizer.step() 327 | 328 | print(f'\rPretrain identity backbone[{epoch}]: {loss.item():.4f}', end='') 329 | 330 | print() 331 | del optimizer, loss_fn 332 | 333 | @property 334 | def assignment_matrix(self): 335 | # 1. map to euclidean space from poincare space 336 | node_euclidean_embedding = PoincareManifold.log_map_zero(self.node_embedding) 337 | supernode_euclidean_embedding = PoincareManifold.log_map_zero(self.supernode_embedding) 338 | # 2. topology-aware representation 339 | node_repr = self.repr_net1(node_euclidean_embedding) 340 | supernode_repr = self.repr_net2(supernode_euclidean_embedding) 341 | # 3. 
assignment matrix 342 | assignment_prob = self.softmax(supernode_repr @ node_repr.T) 343 | assignment_matrix = assignment_prob 344 | return assignment_matrix 345 | 346 | @property 347 | def backbone(self): 348 | assignment_matrix = self.assignment_matrix 349 | idx = torch.argmax(assignment_matrix, dim=0) 350 | assignment_matrix = torch.zeros_like(assignment_matrix, device=self.args.device) 351 | assignment_matrix[idx, torch.arange(idx.shape[0])] = 1 352 | backbone = assignment_matrix @ self.adj @ assignment_matrix.T 353 | backbone[backbone > 0] = 1 354 | return backbone 355 | 356 | def _update_supernode_embedding(self, lr): 357 | # Update supernode embedding by backbone 358 | euclidean_grad = self.supernode_embedding.grad 359 | poincare_grad = PoincareManifold.poincare_grad(euclidean_grad, self.supernode_embedding) 360 | self.supernode_embedding.data -= lr * poincare_grad 361 | self.supernode_embedding.data = self._check_norm(self.supernode_embedding.data) 362 | self.supernode_embedding.grad.zero_() 363 | 364 | 365 | def _check_norm(self, embedding, eps=1e-5): 366 | norm = torch.norm(embedding, dim=-1) 367 | 368 | # Keep the norm of embedding less than 1 369 | idx = norm > 1 370 | if idx.sum() > 0: 371 | embedding[idx] = embedding[idx] / norm[idx].unsqueeze(-1) - eps 372 | return embedding 373 | 374 | def _kmeans(self, adj, k, log=True): 375 | assert k >= 1, "k must be greater than 1" 376 | 377 | degree = adj.sum(axis=1) 378 | if log: 379 | log_degree = np.log(degree) 380 | 381 | model = cluster.KMeans(n_clusters=k, n_init='auto', max_iter=1000, random_state=0) 382 | model.fit(np.array(log_degree).reshape(-1, 1)) 383 | labels = model.labels_ 384 | log_centers = model.cluster_centers_ 385 | 386 | if log: 387 | centers = np.exp(log_centers) 388 | 389 | cluster_ids = [[] for _ in range(k)] 390 | for i, label in enumerate(labels): 391 | cluster_ids[label].append(i) 392 | 393 | return cluster_ids, centers 394 | 395 | def forward(self, tspan, X, isolate=False): 396 | # X: (batch_size, lookback, node_num, feature_dim) 397 | 398 | ################### 399 | # Identity Backbone 400 | ################### 401 | # 1. map to euclidean space from poincare space 402 | node_euclidean_embedding = PoincareManifold.log_map_zero(self.node_embedding) 403 | supernode_euclidean_embedding = PoincareManifold.log_map_zero(self.supernode_embedding) 404 | # 2. topology-aware representation 405 | node_repr = self.repr_net1(node_euclidean_embedding) 406 | supernode_repr = self.repr_net2(supernode_euclidean_embedding) 407 | # 3. assignment matrix 408 | assignment_prob = self.softmax(supernode_repr @ node_repr.T) 409 | assignment_matrix = assignment_prob 410 | # 4. backbone 411 | backbone = assignment_matrix @ self.adj @ assignment_matrix.T 412 | 413 | 414 | ################### 415 | # State aggregation 416 | ################### 417 | # 1. dynamics-aware representation 418 | agc_repr = self.tanh(self.agc_mlp(self.norm_lap @ X)) 419 | # 2. state aggregation 420 | X_supernode = assignment_matrix @ agc_repr # batch_size, lookback, supernode_num, feature_dim 421 | 422 | 423 | ################### 424 | # Backbone Dynamics 425 | ################### 426 | # 1. predict supernode trajectory by graph neural ode 427 | Y_supernode = self.BackboneODE(tspan, X_supernode, backbone) # batch_size, horizon, supernode_num, feature_dim 428 | # 2. 
copy supernode trajectory to original nodes 429 | Y_coarse = assignment_matrix.T @ Y_supernode # batch_size, horizon, node_num, feature_dim 430 | 431 | 432 | ################### 433 | # Refine 434 | ################### 435 | Y_refine = torch.zeros_like(Y_coarse) 436 | if isolate: 437 | Y_coarse = Y_coarse.detach() 438 | 439 | for k in range(len(self.refiners)): 440 | cluster_X = X[:, :, self.cluster_idx[k]] 441 | cluster_Y_coarse = Y_coarse[:, :, self.cluster_idx[k]] 442 | 443 | if len(self.cluster_idx[k]) == 0: 444 | continue 445 | else: 446 | Y_refine[:, :, self.cluster_idx[k]] = self.refiners[k](cluster_X, cluster_Y_coarse) 447 | 448 | return assignment_matrix, Y_refine, Y_supernode, (Y_coarse, X, X_supernode) 449 | 450 | def _agc_state(self, X, assignment_matrix): 451 | agc_repr = self.tanh(self.agc_mlp(self.norm_lap @ X)) 452 | X_supernode = assignment_matrix @ agc_repr # batch_size, lookback, supernode_num, feature_dim 453 | return X_supernode 454 | 455 | def _rg_loss(self, y_rg, Y, assignment_matrix, dim=None): 456 | 457 | # Averaging Y by RG mapping M 458 | with torch.no_grad(): 459 | Y_supernode = self._agc_state(Y, assignment_matrix) 460 | 461 | # MSE Loss 462 | if dim is None: 463 | rg_loss = torch.mean((y_rg - Y_supernode) ** 2) 464 | else: 465 | rg_loss = torch.mean((y_rg - Y_supernode) ** 2, dim=dim) 466 | 467 | return rg_loss, Y_supernode 468 | 469 | def _onehot_loss(self, assignment_matrix): 470 | entropy = -torch.sum(assignment_matrix * torch.log2(assignment_matrix + 1e-5), dim=0) 471 | onehot_loss = torch.mean(entropy) 472 | return onehot_loss 473 | 474 | def _uniform_loss(self, assignment_matrix): 475 | supernode_strength = torch.sum(assignment_matrix, dim=1) 476 | prob = supernode_strength / torch.sum(supernode_strength) 477 | entropy = -torch.sum(prob * torch.log2(prob + 1e-5), dim=0) 478 | uniform_loss = -torch.mean(entropy) # maximize entropy 479 | return uniform_loss 480 | 481 | def _recons_loss(self, assignment_matrix, adj): 482 | surrogate_adj = assignment_matrix.T @ assignment_matrix 483 | recons_loss = torch.norm(adj - surrogate_adj, p='fro') 484 | return recons_loss 485 | 486 | def _refine_loss(self, y_refine, Y, dim=None): 487 | 488 | # MSE Loss 489 | if dim is None: 490 | refine_loss = torch.mean((y_refine - Y) ** 2) 491 | else: 492 | refine_loss = torch.mean((y_refine - Y) ** 2, dim=dim) 493 | 494 | return refine_loss, Y 495 | 496 | def fit(self, train_dataloader, val_dataloader): 497 | 498 | # if os.path.exists(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/model.pt'): 499 | # print('Model exists, skip training') 500 | # return 501 | # else: 502 | # os.makedirs(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}', exist_ok=True) 503 | # print(f'Training {self.args.model} model') 504 | 505 | optimizer = torch.optim.Adam( 506 | [ 507 | {'params': self.repr_net1.parameters(), 'lr': self.args.lr}, 508 | {'params': self.repr_net2.parameters(), 'lr': self.args.lr}, 509 | {'params': self.agc_mlp.parameters(), 'lr': self.args.lr}, 510 | {'params': self.BackboneODE.parameters(), 'lr': self.args.lr}, 511 | {'params': self.refiners.parameters(), 'lr': self.args.lr}, 512 | {'params': self.supernode_embedding, 'lr': self.args.lr}, 513 | ] 514 | , lr=self.args.lr) 515 | scheduler = StepLR(optimizer, step_size=self.args.lr_step, gamma=self.args.lr_decay) 516 | 517 | dt = self.args[self.args.dynamics].dt 518 | start_t = (self.args.lookback) * dt 519 | end_t = (self.args.lookback + self.args.horizon - 1) * dt 520 | tspan = torch.linspace(start_t, 
end_t, self.args.horizon).to(self.args.device) 521 | 522 | train_loss_list, val_loss_list = [], [] 523 | for epoch in range(1, self.args.max_epoch+1): 524 | train_loss = 0.0 525 | self.train() 526 | for i, (X, Y) in enumerate(train_dataloader): 527 | assignment_matrix, y_refine, y_rg, _ = self(tspan, X) 528 | rg_loss, _ = self._rg_loss(y_rg, Y, assignment_matrix) 529 | refine_loss, _ = self._refine_loss(y_refine, Y) 530 | onehot_loss = self._onehot_loss(assignment_matrix) 531 | uniform_loss = self._uniform_loss(assignment_matrix) 532 | recons_loss = self._recons_loss(assignment_matrix, self.adj) 533 | loss = refine_loss + rg_loss + onehot_loss + recons_loss + uniform_loss 534 | 535 | optimizer.zero_grad() 536 | loss.backward() 537 | self.supernode_embedding.grad = PoincareManifold.poincare_grad(self.supernode_embedding.grad, self.supernode_embedding) # rescale euclidean grad to poincare grad 538 | optimizer.step() 539 | 540 | with torch.no_grad(): 541 | self.supernode_embedding.data = self._check_norm(self.supernode_embedding.data) 542 | 543 | train_loss += loss.item() 544 | print(f'\rEpoch[{epoch}/{self.args.max_epoch}] train backbone: {rg_loss.item():.4f}, refine: {refine_loss.item():.4f}, onehot: {onehot_loss.item():.4f}, recons: {recons_loss.item():.4f}, uniform: {uniform_loss.item():.4f}', end='') 545 | train_loss_list.append([epoch, train_loss / len(train_dataloader)]) 546 | 547 | scheduler.step() 548 | if epoch % self.args.val_interval == 0: 549 | self.eval() 550 | val_loss = 0 551 | for i, (X, Y) in enumerate(val_dataloader): 552 | assignment_matrix, y_refine, y_rg, info = self(tspan, X) # info: (Y_coarse, X_reindex, Y_coarse, X_rg, kappa_reindex) 553 | rg_loss, Y_coarse = self._rg_loss(y_rg, Y, assignment_matrix) 554 | refine_loss, Y_reindex = self._refine_loss(y_refine, Y) 555 | onehot_loss = self._onehot_loss(assignment_matrix) 556 | uniform_loss = self._uniform_loss(assignment_matrix) 557 | recons_loss = self._recons_loss(assignment_matrix, self.adj) 558 | loss = refine_loss 559 | val_loss += loss.item() 560 | 561 | if i == 0: 562 | os.makedirs(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}', exist_ok=True) 563 | drawTraj(y_rg[:,:,:100], Y_coarse[:,:,:100], 'pred', 'true', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/rg_pred.png') 564 | drawTraj(info[2][:,:,:10], Y_coarse[:,:12,:10], 'rg_x', 'rg_y', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/rg_traj.png') 565 | drawTraj(info[1][:,:,:20], info[2][:,:,:10], 'x', 'x_rg', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/x_rg.png') 566 | drawTraj(y_refine[:,:,:100], info[0][:,:,:100], 'refined', 'coarse', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/refine.png') 567 | drawTraj(y_refine[:,:,:100], Y_reindex[:,:,:100], 'pred', 'true', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/result.png') 568 | drawTraj(Y_reindex[:,:,:100], Y_coarse[:,:,:50], 'Y', 'Y_coarse', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/y_rg.png') 569 | 570 | # Draw the backbone 571 | idx = torch.argmax(assignment_matrix, dim=0) 572 | assignment_matrix = torch.zeros_like(assignment_matrix, device=self.args.device) 573 | assignment_matrix[idx, torch.arange(idx.shape[0])] = 1 574 | backbone = assignment_matrix @ self.adj @ assignment_matrix.T 575 | # backbone[backbone > 
0] = 1 576 | draw_embedding(backbone, self.supernode_embedding, f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/supernode_poincare.png') 577 | 578 | # Assignment distribution 579 | count = torch.sum(assignment_matrix, dim=1) 580 | valid_num = len(count[count > 0]) 581 | 582 | print(f'\nEpoch[{epoch}/{self.args.max_epoch}] val backbone: {rg_loss.item():.4f}, refine: {refine_loss.item():.4f}, onehot: {onehot_loss.item():.4f}, recons: {recons_loss.item():.4f}, uniform: {uniform_loss.item():.4f} | assignment: {valid_num}/{self.supernode_embedding.shape[0]}') 583 | val_loss_list.append([epoch, val_loss / len(val_dataloader)]) 584 | 585 | # Save model 586 | torch.save(self.state_dict(), f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/model_{epoch}.pt') 587 | 588 | # Draw loss curve 589 | train_loss_list = np.array(train_loss_list) 590 | val_loss_list = np.array(val_loss_list) 591 | plt.figure(figsize=(5, 4)) 592 | plt.plot(train_loss_list[:, 0], train_loss_list[:, 1], label='train') 593 | plt.plot(val_loss_list[:, 0], val_loss_list[:, 1], label='val') 594 | plt.xlabel('Epoch') 595 | plt.ylabel('Loss') 596 | plt.legend(frameon=False) 597 | plt.tight_layout() 598 | plt.savefig(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/loss.png', dpi=300) 599 | 600 | # Save model 601 | torch.save(self.state_dict(), f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/model.pt') 602 | 603 | # Fine-tune refiner 604 | self.refine(train_dataloader, val_dataloader) 605 | 606 | # Release memory 607 | del train_dataloader, val_dataloader, optimizer, scheduler 608 | 609 | def refine(self, train_dataloader, val_dataloader): 610 | 611 | dt = self.args[self.args.dynamics].dt 612 | start_t = (self.args.lookback) * dt 613 | end_t = (self.args.lookback + self.args.horizon - 1) * dt 614 | tspan = torch.linspace(start_t, end_t, self.args.horizon).to(self.args.device) 615 | 616 | optimizer = torch.optim.Adam(self.refiners.parameters(), lr=0.001) 617 | scheduler = StepLR(optimizer, step_size=self.args.lr_step, gamma=self.args.lr_decay) 618 | 619 | for epoch in range(1, 10+1): 620 | train_loss = 0.0 621 | self.train() 622 | for i, (X, Y) in enumerate(train_dataloader): 623 | optimizer.zero_grad() 624 | _, y_refine, y_rg, _ = self(tspan, X, isolate=True) 625 | refine_loss, _ = self._refine_loss(y_refine, Y) 626 | loss = refine_loss 627 | loss.backward() 628 | optimizer.step() 629 | train_loss += loss.item() 630 | print(f'\rEpoch[{epoch}] train refine loss: {refine_loss.item():.4f}', end='') 631 | 632 | scheduler.step() 633 | if epoch % self.args.val_interval == 0: 634 | self.eval() 635 | val_loss = 0 636 | for i, (X, Y) in enumerate(val_dataloader): 637 | _, y_refine, y_rg, info = self(tspan, X, isolate=True) # info: (Y_coarse, X_reindex, Y_coarse, X_rg, kappa_reindex) 638 | refine_loss, Y_reindex = self._refine_loss(y_refine, Y) 639 | loss = refine_loss 640 | val_loss += loss.item() 641 | if i == 0: 642 | os.makedirs(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/refine', exist_ok=True) 643 | drawTraj(info[1][:,:,:20], info[2][:,:,:10], 'x', 'x_rg', out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/refine/x_rg.png') 644 | drawTraj(y_refine[:,:,:100], info[0][:,:,:100], 'refined', 'coarse', out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/refine/refine.png') 645 | drawTraj(y_refine[:,:,:100], Y_reindex[:,:,:100], 'pred', 'true', 
out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/epoch_{epoch}/refine/result.png') 646 | print(f'\nEpoch[{epoch}/10] val refine loss: {refine_loss.item():.4f}') 647 | 648 | # Save model 649 | torch.save(self.state_dict(), f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/model_refine.pt') 650 | 651 | def test(self, test_dataloader): 652 | 653 | # Load model 654 | try: 655 | self.load_state_dict(torch.load(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/model_refine.pt')) 656 | except: 657 | self.load_state_dict(torch.load(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/model.pt')) 658 | self.to(self.args.device) 659 | 660 | dt = self.args[self.args.dynamics].dt 661 | start_t = (self.args.lookback) * dt 662 | end_t = (self.args.lookback + self.args.horizon - 1) * dt 663 | tspan = torch.linspace(start_t, end_t, self.args.horizon).to(self.args.device) 664 | 665 | # Test 666 | self.eval() 667 | print('Testing...') 668 | ground_truth = np.zeros((len(test_dataloader), self.args.batch_size, self.args.horizon, self.args.node_num, self.feature_dim)) 669 | predict = np.zeros((len(test_dataloader), self.args.batch_size, self.args.horizon, self.args.node_num, self.feature_dim)) 670 | for i, (X, Y) in enumerate(test_dataloader): 671 | assignment_matrix, y_refine, y_rg, info = self(tspan, X) 672 | 673 | if i == len(test_dataloader)-1: 674 | Y_coarse = self._agc_state(Y, assignment_matrix) 675 | os.makedirs(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test', exist_ok=True) 676 | drawTraj(y_rg[:,:,:100], Y_coarse[:,:,:100], 'pred', 'true', out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test/rg_pred.png', num=3) 677 | drawTraj(info[2][:,:,:10], Y_coarse[:,:12,:10], 'rg_x', 'rg_y', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test/rg_traj.png') 678 | drawTraj(info[1][:,:,:20], info[2][:,:,:10], 'x', 'x_rg', dim=0, out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test/x_rg.png') 679 | drawTraj(y_refine[:,:,:200], info[0][:,:,:200], 'refined', 'coarse', out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test/refine.png') 680 | drawTraj(y_refine[:,:,:200], Y[:,:,:200], 'pred', 'true', out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test/result.png', num=3) 681 | drawTraj(Y[:,:,:200], Y_coarse[:,:,:100], 'Y', 'Y_coarse', out_path=f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test/y_rg.png') 682 | 683 | ground_truth[i, :, :, :, :] = Y.cpu().detach().numpy() 684 | predict[i, :, :, :, :] = y_refine.cpu().detach().numpy() 685 | 686 | backbone_pred = y_rg.cpu().detach().numpy() 687 | backbone_true = Y_coarse.cpu().detach().numpy() 688 | 689 | # Draw the backbone 690 | idx = torch.argmax(assignment_matrix, dim=-1) 691 | assignment_matrix = torch.zeros_like(assignment_matrix) 692 | assignment_matrix[torch.arange(idx.shape[0]), idx] = 1 693 | backbone = assignment_matrix @ self.adj @ assignment_matrix.T 694 | # backbone[backbone > 0] = 1 695 | draw_embedding(backbone, self.supernode_embedding, f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/test/supernode_poincare.png') 696 | 697 | # Save result 698 | time_cost = 0.0 699 | ground_truth = ground_truth.reshape(-1, self.args.horizon, self.args.node_num, self.feature_dim) 700 | predict = predict.reshape(-1, self.args.horizon, self.args.node_num, self.feature_dim) 701 | 
np.savez(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/backbone_result.npz', backbone_pred=backbone_pred, backbone_true=backbone_true) 702 | np.savez(f'{self.args.log_dir}/{self.args.dynamics}/{self.args.model}/result.npz', ground_truth=ground_truth, predict=predict, time_cost=time_cost) 703 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # SimBackbone 2 | torch 3 | omegaconf 4 | torchdiffeq 5 | mercator # https://github.com/networkgeometry/mercator?tab=readme-ov-file 6 | sdeint 7 | scienceplots 8 | matplotlib 9 | scikit-learn -------------------------------------------------------------------------------- /simluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import urllib.request 5 | import zipfile 6 | import networkx as nx 7 | from sdeint import itoEuler 8 | 9 | from utils import HyperbolicEmbedding 10 | from dynamics import * 11 | 12 | 13 | sites = { 14 | 'Drosophila': 'https://nrvis.com/download/data/bn/bn-fly-drosophila_medulla_1.zip', 15 | 'PowerGrid': 'https://nrvis.com/download/data/power/power-bcspwr10.zip', 16 | 'Social': 'https://nrvis.com/download/data/soc/fb-pages-tvshow.zip', 17 | 'Web': 'https://nrvis.com/download/data/web/web-EPA.zip', 18 | 'Airport': 'https://nrvis.com/download/data/inf/inf-openflights.zip', 19 | } 20 | 21 | 22 | class NetworkSimulator(object): 23 | 24 | def __init__(self, args: dict): 25 | self.args = args 26 | self.G = None 27 | self.HyperbolicG = None 28 | 29 | def buildNetwork(self): 30 | try: 31 | with open(f'{self.args.data_dir}/graph.pkl', 'rb') as file: 32 | graph = pickle.load(file) 33 | self.G = graph.G 34 | print(f'Load {self.args.graph_type} graph with {self.G.number_of_nodes()} nodes and {self.G.number_of_edges()} edges') 35 | 36 | except: 37 | os.makedirs(self.args.data_dir, exist_ok=True) 38 | 39 | # Create graph 40 | if self.args.graph_type == 'BA': 41 | self.G = nx.barabasi_albert_graph(self.args.node_num, self.args.edge_num, seed=self.args.seed) 42 | elif self.args.graph_type == 'WS': 43 | self.G = nx.watts_strogatz_graph(self.args.node_num, self.args.ring_lattice_k, self.args.rewiring_prob, seed=self.args.seed) 44 | elif self.args.graph_type == 'Drosophila': 45 | self._downloadNetwork() 46 | self.G = nx.read_edgelist(f'{self.args.data_dir}/bn/bn-fly-drosophila_medulla_1.edges', create_using=nx.DiGraph) 47 | self.G = self.G.to_undirected() 48 | elif self.args.graph_type == 'PowerGrid': 49 | self._downloadNetwork() 50 | self.G = nx.read_edgelist(f'{self.args.data_dir}/power-bcspwr10.mtx', create_using=nx.DiGraph) 51 | self.G = self.G.to_undirected() 52 | elif self.args.graph_type == 'Social': 53 | self._downloadNetwork() 54 | self.G = nx.read_edgelist(f'{self.args.data_dir}/fb-pages-tvshow.edges', create_using=nx.DiGraph) 55 | self.G = self.G.to_undirected() 56 | elif self.args.graph_type == 'Web': 57 | self._downloadNetwork() 58 | self.G = nx.read_edgelist(f'{self.args.data_dir}/web-EPA.edges', create_using=nx.DiGraph) 59 | self.G = self.G.to_undirected() 60 | elif self.args.graph_type == 'Airport': 61 | self._downloadNetwork() 62 | self.G = nx.read_edgelist(f'{self.args.data_dir}/inf-openflights.edges', create_using=nx.DiGraph) 63 | self.G = self.G.to_undirected() 64 | else: 65 | raise Exception(f'Invalid graph type: {self.args.graph_type}') 66 | 67 | # delete isolated nodes 68 | 
self.G.remove_nodes_from(list(nx.isolates(self.G))) 69 | # delete self-loop edges 70 | self.G.remove_edges_from(nx.selfloop_edges(self.G)) 71 | # only keep the largest connected component 72 | self.G = self.G.subgraph(max(nx.connected_components(self.G), key=len)) 73 | # relabel nodes 74 | self.G = nx.convert_node_labels_to_integers(self.G, first_label=0) 75 | # keep even number of nodes for static RG model 76 | if self.G.number_of_nodes() % 2 == 1: 77 | print(self.G.number_of_nodes()) 78 | self.G.remove_node(np.random.choice(list(self.G.nodes))) 79 | print(self.G.number_of_nodes()) 80 | # relabel nodes 81 | self.G = nx.convert_node_labels_to_integers(self.G, first_label=0) 82 | 83 | with open(f'{self.args.data_dir}/graph.pkl', 'wb') as file: 84 | pickle.dump(self, file) 85 | 86 | nx.write_edgelist(self.G, f'{self.args.data_dir}/graph.txt', data=False) 87 | 88 | print(f'Save {self.args.graph_type} graph with {self.G.number_of_nodes()} nodes and {self.G.number_of_edges()} edges') 89 | 90 | return self.G, nx.to_numpy_array(self.G) 91 | 92 | def getHyperbolicEmbedding(self): 93 | he = HyperbolicEmbedding(args=self.args) 94 | s1_kappa, s1_angular, h1_radius, mu, beta, radius_s1 = he.fit_transform() 95 | 96 | exp_degree, he_degree = np.mean([d for n, d in nx.degree(self.G)]), 0.1 97 | # while abs(exp_degree - he_degree) > 0.1: 98 | p_matrix = self._connectivity_probability_matrix(s1_kappa, s1_angular, mu, beta, radius_s1) 99 | sampled_A = np.random.binomial(1, p_matrix) 100 | self.HyperbolicG = nx.from_numpy_array(sampled_A) 101 | he_degree = np.mean([d for n, d in nx.degree(self.HyperbolicG)]) 102 | mu = mu * (exp_degree / he_degree) 103 | 104 | nx.write_edgelist(self.HyperbolicG, f'{self.args.log_dir}/HE/hyperbolic_graph.txt', data=False) 105 | 106 | print(f'clustering coeficient: {nx.average_clustering(self.G):.3f}-->{nx.average_clustering(self.HyperbolicG):.3f}') 107 | print(f'degree: {np.mean([d for n, d in nx.degree(self.G)]):.3f}-->{np.mean([d for n, d in nx.degree(self.HyperbolicG)]):.3f}') 108 | 109 | return s1_kappa, s1_angular, h1_radius, mu, beta, radius_s1 110 | 111 | def getSimTraj(self): 112 | 113 | try: 114 | with np.load(f'{self.args.data_dir}/{self.args.dynamics}/dynamics.npz') as data: 115 | X = data['X'] 116 | print(f'Load {self.args.dynamics} dynamics with {self.args.node_num} nodes and {self.args[self.args.dynamics].total_t} time steps') 117 | 118 | except: 119 | dim = self.args[self.args.dynamics].dim 120 | 121 | if self.args.dynamics == 'HindmarshRose': 122 | sde = HindmarshRose(args=self.args, A=nx.to_numpy_array(self.G)) 123 | x0_1 = np.random.uniform(-1, 0, size=self.args.node_num) 124 | x0_2 = np.random.uniform(-5, 0, size=self.args.node_num) 125 | x0_3 = np.random.uniform(3, 3.5, size=self.args.node_num) 126 | x0 = np.concatenate((x0_1, x0_2, x0_3)) 127 | elif self.args.dynamics == 'FitzHughNagumo': 128 | sde = FitzHughNagumo(args=self.args, A=nx.to_numpy_array(self.G)) 129 | x0 = np.random.uniform(-1, 1, size=self.args.node_num*dim) 130 | elif self.args.dynamics == 'CoupledRossler': 131 | sde = CoupledRossler(args=self.args, A=nx.to_numpy_array(self.G)) 132 | x0 = np.random.uniform(-0.05, 0.05, size=self.args.node_num*dim) 133 | 134 | tspan = np.arange(0, self.args[self.args.dynamics].total_t, self.args[self.args.dynamics].sim_dt) 135 | sol = itoEuler(sde.f, sde.g, x0, tspan) # (total_t, node_num*feature_dim) 136 | 137 | # downsample 138 | ratio = int(self.args[self.args.dynamics].dt / self.args[self.args.dynamics].sim_dt) 139 | sol = sol[::ratio] 140 | 141 | X = 
np.zeros((sol.shape[0], self.args.node_num, dim)) 142 | for i in range(dim): 143 | X[:, :, i] = sol[:, i*self.args.node_num:(i+1)*self.args.node_num] 144 | 145 | if self.args.dynamics == 'CoupledKuramoto': 146 | X = np.sin(X) 147 | 148 | os.makedirs(f'{self.args.data_dir}/{self.args.dynamics}', exist_ok=True) 149 | np.savez(f'{self.args.data_dir}/{self.args.dynamics}/dynamics.npz', X=X) 150 | print(f'Save {self.args.dynamics} dynamics with {self.args.node_num} nodes and {self.args[self.args.dynamics].total_t} time steps') 151 | 152 | return X 153 | 154 | def _downloadNetwork(self): 155 | url = sites[self.args.graph_type] 156 | file_name = f"{self.args.data_dir}/download.zip" 157 | urllib.request.urlretrieve(url, file_name) 158 | 159 | with zipfile.ZipFile(file_name, "r") as zip_ref: 160 | zip_ref.extractall(self.args.data_dir) 161 | 162 | os.remove(file_name) 163 | 164 | if self.args.graph_type == 'PowerGrid': 165 | # delete header 166 | with open(f'{self.args.data_dir}/power-bcspwr10.mtx', 'r') as f: 167 | lines = f.readlines() 168 | with open(f'{self.args.data_dir}/power-bcspwr10.mtx', 'w') as f: 169 | f.writelines(lines[14:]) 170 | elif self.args.graph_type == 'Social': 171 | # replace ',' with ' ' 172 | with open(f'{self.args.data_dir}/fb-pages-tvshow.edges', 'r') as f: 173 | lines = f.readlines() 174 | with open(f'{self.args.data_dir}/fb-pages-tvshow.edges', 'w') as f: 175 | for line in lines: 176 | f.write(line.replace(',', ' ')) 177 | elif self.args.graph_type == 'Airport': 178 | # delete header 179 | with open(f'{self.args.data_dir}/inf-openflights.edges', 'r') as f: 180 | lines = f.readlines() 181 | with open(f'{self.args.data_dir}/inf-openflights.edges', 'w') as f: 182 | f.writelines(lines[2:]) 183 | 184 | def _connectivity_probability_matrix(self, kappa, angular, mu, beta, radius): 185 | """p_ij = 1 / (1 + (radius*(delta_angular) / (mu*kappa_i*kappa_j))**beta)""" 186 | 187 | abs_delta_angular = np.abs(angular.reshape(-1, 1) - angular.reshape(1, -1)) 188 | delta_angular = np.minimum(abs_delta_angular, 2 * np.pi - abs_delta_angular) 189 | kappa_mul = kappa.reshape(-1, 1) * kappa.reshape(1, -1) 190 | rescaled_dist = radius * delta_angular / (mu * kappa_mul) 191 | p_matrix = 1 / (1 + rescaled_dist ** beta) - np.eye(self.args.node_num) 192 | 193 | return p_matrix 194 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import mercator 5 | import numpy as np 6 | import networkx as nx 7 | import scienceplots 8 | import matplotlib.pyplot as plt 9 | 10 | from metrics import * 11 | 12 | 13 | plt.style.use(['ieee', 'science', 'no-latex']) 14 | plt.rcParams['legend.fontsize'] = 12 15 | plt.rcParams['xtick.labelsize'] = 12 16 | plt.rcParams['ytick.labelsize'] = 12 17 | plt.rcParams['axes.titlesize'] = 12 18 | plt.rcParams['axes.labelsize'] = 12 19 | plt.rcParams['font.sans-serif'] = 'Arial' 20 | plt.rcParams["font.family"] = 'Arial' 21 | 22 | 23 | def set_cpu_num(cpu_num: int = 1): 24 | if cpu_num <= 0: return 25 | 26 | os.environ ['OMP_NUM_THREADS'] = str(cpu_num) 27 | os.environ ['OPENBLAS_NUM_THREADS'] = str(cpu_num) 28 | os.environ ['MKL_NUM_THREADS'] = str(cpu_num) 29 | os.environ ['VECLIB_MAXIMUM_THREADS'] = str(cpu_num) 30 | os.environ ['NUMEXPR_NUM_THREADS'] = str(cpu_num) 31 | torch.set_num_threads(cpu_num) 32 | 33 | 34 | def seed_everything(seed: int = 42): 35 | # Set the random seed for Python's built-in 
34 | def seed_everything(seed: int = 42):
35 |     # Set the random seed for Python's built-in random module
36 |     random.seed(seed)
37 | 
38 |     # Set the random seed for NumPy
39 |     np.random.seed(seed)
40 | 
41 |     # Set the random seed for torch operations
42 |     torch.manual_seed(seed)
43 |     torch.cuda.manual_seed_all(seed)
44 |     torch.backends.cudnn.deterministic = True
45 |     torch.backends.cudnn.benchmark = False
46 | 
47 | 
48 | def count_parameters(model):
49 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
50 | 
51 | def print_model_summary(model):
52 |     print(model)
53 |     print("Number of parameters: {:,}".format(count_parameters(model)))
54 | 
55 |     print("\nParameter details:")
56 |     for name, parameter in model.named_parameters():
57 |         if parameter.requires_grad:
58 |             print(name, parameter.shape, parameter.device, parameter.dtype, parameter.numel())
59 | 
60 | 
61 | class HyperbolicEmbedding:
62 |     def __init__(self, args: dict):
63 |         self.args = args
64 | 
65 |     def fit_transform(self):
66 |         if not os.path.exists(f'{self.args.log_dir}/HE/he.inf_coord'):
67 |             os.makedirs(f'{self.args.log_dir}/HE', exist_ok=True)
68 |             mercator.embed(
69 |                 edgelist_filename=f'{self.args.data_dir}/graph.txt',
70 |                 quiet_mode=self.args.quiet_mode,
71 |                 fast_mode=self.args.fast_mode,
72 |                 output_name=f'{self.args.log_dir}/HE/he',
73 |                 validation_mode=self.args.validation_mode,
74 |                 post_kappa=self.args.post_kappa
75 |             )
76 | 
77 |         if self.args.refine:
78 |             mercator.embed(
79 |                 edgelist_filename=f'{self.args.data_dir}/graph.txt',
80 |                 quiet_mode=self.args.quiet_mode,
81 |                 fast_mode=self.args.fast_mode,
82 |                 output_name=f'{self.args.log_dir}/HE/he',
83 |                 validation_mode=self.args.validation_mode,
84 |                 post_kappa=self.args.post_kappa,
85 |                 inf_coord=f'{self.args.log_dir}/HE/he.inf_coord'
86 |             )
87 | 
88 |         return self._parse_mercator_output()
89 | 
90 |     def _parse_mercator_output(self):
91 |         with open(f'{self.args.log_dir}/HE/he.inf_coord', 'r') as f:
92 |             lines = f.readlines()
93 | 
94 |         # parse the header: node_num, beta, mu, and the S1/H2 radii
95 |         node_num = int(lines[7].split()[-1])
96 |         beta = float(lines[8].split()[-1])
97 |         mu = float(lines[9].split()[-1])
98 |         radius_s1 = float(lines[10].split()[-1])
99 |         radius_h2 = float(lines[11].split()[-1])
100 | 
101 |         s1_kappa = np.zeros(node_num)
102 |         s1_angular = np.zeros(node_num)
103 |         h2_radius = np.zeros(node_num)
104 |         for i in range(15, 15+node_num):  # per-node coordinates start at line 15 of the mercator output
105 |             s1_kappa[i-15] = float(lines[i].split()[1])
106 |             s1_angular[i-15] = float(lines[i].split()[2])
107 |             h2_radius[i-15] = float(lines[i].split()[3])
108 | 
109 |         return s1_kappa, s1_angular, h2_radius, mu, beta, radius_s1
110 | 
111 | 
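# Editorial sketch (illustrative, not from the original file): typical use of the
# class above, assuming `args` carries the data_dir/log_dir paths and the mercator
# flags from config.yaml. The S1 and H2 views are equivalent: mercator infers an
# expected degree kappa and an angle theta on the S1 circle, and also reports the
# corresponding hyperbolic radius, where larger kappa maps to smaller radius.
#
#   he = HyperbolicEmbedding(args=args)
#   s1_kappa, s1_angular, h2_radius, mu, beta, radius_s1 = he.fit_transform()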
112 | def drawGraph(G: nx.Graph, layout: str = 'random', filter: str = 'none', threshold = 0.3, out_path: str = 'graph.png'):
113 |     """
114 |     param G: networkX graph
115 |     param layout: layout type, options: random, circular, spring, spectral, shell, kamada_kawai (default: random)
116 |     param filter: node filter type, options: none, degree, betweenness, random (default: none)
117 |     param threshold: node filter threshold (default: 0.3)
118 |     """
119 | 
120 |     # node filter: keep only high-centrality (or randomly surviving) nodes
121 |     if filter == 'degree':
122 |         degree_centrality = nx.degree_centrality(G)
123 |         core_nodes = [node for node, centrality in degree_centrality.items() if centrality >= threshold]
124 |         G = G.subgraph(core_nodes)
125 |     elif filter == 'betweenness':
126 |         betweenness_centrality = nx.betweenness_centrality(G)
127 |         core_nodes = [node for node, centrality in betweenness_centrality.items() if centrality >= threshold]
128 |         G = G.subgraph(core_nodes)
129 |     elif filter == 'random':
130 |         core_nodes = [node for node in G.nodes() if np.random.rand() >= threshold]
131 |         G = G.subgraph(core_nodes)
132 |     else:
133 |         pass
134 | 
135 |     # layout
136 |     if layout == 'random':
137 |         pos = nx.random_layout(G)
138 |     elif layout == 'circular':
139 |         pos = nx.circular_layout(G)
140 |     elif layout == 'spring':
141 |         pos = nx.spring_layout(G)
142 |     elif layout == 'spectral':
143 |         pos = nx.spectral_layout(G)
144 |     elif layout == 'shell':
145 |         pos = nx.shell_layout(G)
146 |     elif layout == 'kamada_kawai':
147 |         pos = nx.kamada_kawai_layout(G)
148 |     else:
149 |         pos = nx.random_layout(G)
150 | 
151 |     # node color and size scale with degree
152 |     node_color = [G.degree(v) for v in G]
153 |     node_size = [v * 10 for v in node_color]
154 | 
155 |     plt.figure(figsize=(8, 8))
156 |     nx.draw(G, pos, with_labels=False, node_color=node_color, node_size=node_size, edge_color=(0, 0, 0, 0.25))
157 |     plt.savefig(out_path, dpi=300)
158 | 
159 | 
160 | def drawTraj(X1, X2, title1, title2, yticks1=None, yticks2=None, dim=0, out_path='traj.png', num=2):
161 |     # X1, X2: (Batch, Time, Node Num, Dim)
162 |     if isinstance(X1, torch.Tensor):
163 |         data1 = X1[0,:,:,dim].detach().cpu().numpy().T
164 |         data2 = X2[0,:,:,dim].detach().cpu().numpy().T
165 |     else:
166 |         data1 = X1[0,:,:,dim].T
167 |         data2 = X2[0,:,:,dim].T
168 |     if yticks1 is not None:
169 |         yticks1 = yticks1.detach().cpu().numpy()[:,0]
170 |         yticks2 = yticks2.detach().cpu().numpy()[:,0]
171 | 
172 |     zmax = max(np.max(data1), np.max(data2))
173 |     zmin = min(np.min(data1), np.min(data2))
174 | 
175 |     # Heatmap
176 |     plt.figure(figsize=(num*4, 4))
177 |     plt.subplot(1, num, 1)
178 |     plt.imshow(data1, cmap='hot', interpolation='nearest', vmin=zmin, vmax=zmax)
179 |     plt.ylabel('Node')
180 |     plt.xlabel('Time Step')
181 |     if yticks1 is not None:
182 |         plt.yticks(np.arange(0, len(yticks1), 1), yticks1)
183 |     plt.title(title1)
184 |     plt.subplot(1, num, 2)
185 |     plt.imshow(data2, cmap='hot', interpolation='nearest', vmin=zmin, vmax=zmax)
186 |     plt.ylabel('Node')
187 |     plt.xlabel('Time Step')
188 |     if yticks1 is not None:
189 |         plt.yticks(np.arange(0, len(yticks2), 1), yticks2)
190 |     plt.title(title2)
191 |     if num == 3:
192 |         plt.subplot(1, num, 3)
193 |         plt.imshow(np.abs(data2-data1), cmap='hot', interpolation='nearest', vmin=zmin, vmax=zmax)
194 |         plt.ylabel('Node')
195 |         plt.xlabel('Time Step')
196 |         if yticks1 is not None:
197 |             plt.yticks(np.arange(0, len(yticks2), 1), yticks2)
198 |         plt.title('Absolute Error')
199 |     # plt.colorbar()
200 |     os.makedirs(os.path.dirname(out_path), exist_ok=True)
201 |     plt.savefig(out_path)
202 |     plt.close()
203 | 
204 | 
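# Editorial sketch (illustrative only): drawTraj compares two trajectory tensors for
# the first batch element and a single feature dimension, e.g.
#
#   # y_true, y_pred: (batch, time, node, dim) tensors -- hypothetical names
#   drawTraj(y_true, y_pred, 'Ground Truth', 'Prediction', dim=0,
#            out_path='logs/traj.png', num=3)  # num=3 adds the absolute-error panel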
205 | def draw_embedding(adj, embedding, out_path, label=None, mask=None):
206 |     """
207 |     Draw the 2-D node embedding, with edges and (optionally) supernode assignments.
208 |     """
209 |     if os.path.exists(out_path):
210 |         return
211 | 
212 |     embedding = embedding.cpu().detach().numpy()
213 |     size = adj.diagonal().cpu().detach().numpy() + 1
214 | 
215 |     plt.figure(figsize=(4, 4))
216 | 
217 |     # edge
218 |     tmp = adj >= 1
219 |     idx = np.where(tmp.cpu().detach().numpy())
220 |     for i in range(len(idx[0])):
221 |         plt.plot([embedding[idx[0][i], 0], embedding[idx[1][i], 0]], [embedding[idx[0][i], 1], embedding[idx[1][i], 1]], color="#555555", alpha=0.03, linewidth=0.5)
222 | 
223 |     # node
224 |     if label is None:
225 |         plt.scatter(embedding[:, 0], embedding[:, 1], s=8, alpha=1.0)
226 |     else:
227 |         if mask is None:
228 |             mask = np.ones_like(label, dtype=bool)
229 |         else:
230 |             mask = mask.cpu().detach().numpy()
231 | 
232 |         for i in np.unique(label):
233 |             mask_ = (label == i) & mask
234 |             plt.scatter(embedding[mask_, 0], embedding[mask_, 1], s=size[mask_]*1.2, alpha=1.0, label=f'Supernode {i}')
235 | 
236 |     plt.xlim(-1, 1)
237 |     plt.ylim(-1, 1)
238 |     plt.xticks([])
239 |     plt.yticks([])
240 |     # plt.legend(loc='upper right', fontsize=8, frameon=False)
241 |     os.makedirs(os.path.dirname(out_path), exist_ok=True)
242 |     plt.savefig(out_path, dpi=300)
243 |     plt.close()
244 | 
245 | 
246 | def StatisticGraph():
247 | 
248 |     import pickle
249 |     import networkx as nx
250 |     for graph_type in ['PowerGrid', 'Drosophila', 'Social', 'Web', 'Airport', 'BA_n5000_612', 'WS_n5000_612']:
251 |         with open(f'data/{graph_type}/graph.pkl', 'rb') as file:
252 |             graph = pickle.load(file).G  # nx.Graph
253 | 
254 |         node_num = graph.number_of_nodes()
255 |         edge_num = graph.number_of_edges()
256 |         avg_degree = sum(dict(graph.degree()).values()) / node_num
257 |         avg_clustering = nx.average_clustering(graph)
258 |         density = nx.density(graph)
259 |         print(f'{graph_type.rjust(12)}: Node Num={node_num}, Edge Num={edge_num}, Avg Degree={avg_degree:.4f}, Avg Clustering={avg_clustering:.4f}, Density={density:.4f}')
260 | 
261 | 
262 | # if __name__ == '__main__':
263 | #     StatisticGraph()
--------------------------------------------------------------------------------