├── pprgo ├── __init__.py ├── dataset.py ├── pprgo.py ├── predict.py ├── pytorch_utils.py ├── utils.py ├── ppr.py ├── train.py └── sparsegraph.py ├── data ├── pubmed.npz ├── cora_full.npz └── get_reddit.md ├── .gitignore ├── setup.py ├── config_demo.yaml ├── README.md ├── config_seml.yaml ├── run_seml.py ├── LICENSE.md └── demo.ipynb /pprgo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/pubmed.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TUM-DAML/pprgo_pytorch/HEAD/data/pubmed.npz -------------------------------------------------------------------------------- /data/cora_full.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TUM-DAML/pprgo_pytorch/HEAD/data/cora_full.npz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | pprgo.egg-info 4 | __pycache__ 5 | .ipynb_checkpoints 6 | 7 | data/reddit.npz 8 | *.egg-info 9 | -------------------------------------------------------------------------------- /data/get_reddit.md: -------------------------------------------------------------------------------- 1 | You can download the Reddit dataset in sparse matrix format at https://figshare.com/articles/dataset/Reddit_graph_dataset_in_sparse_matrix_format/12624146. 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | install_requires = [ 4 | "numpy", 5 | "scipy>=1.3", 6 | "numba>=0.49", 7 | "torch", 8 | "torch-scatter", 9 | "torch-sparse", 10 | "scikit-learn", 11 | "sacred", 12 | "seml" 13 | ] 14 | 15 | setup( 16 | name='pprgo_pytorch', 17 | version='1.0', 18 | description='PPRGo model in PyTorch, from "Scaling Graph Neural Networks with Approximate PageRank"', 19 | author='Aleksandar Bojchevski, Johannes Gasteiger, Bryan Perozzi, Amol Kapoor, Martin Blais, Benedek Rózemberczki, Michal Lukasik, Stephan Günnemann', 20 | author_email='a.bojchevski@in.tum.de, j.gasteiger@in.tum.de', 21 | packages=['pprgo'], 22 | install_requires=install_requires, 23 | zip_safe=False 24 | ) 25 | -------------------------------------------------------------------------------- /pprgo/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .pytorch_utils import matrix_to_torch 4 | 5 | 6 | class PPRDataset(torch.utils.data.Dataset): 7 | def __init__(self, attr_matrix_all, ppr_matrix, indices, labels_all=None): 8 | self.attr_matrix_all = attr_matrix_all 9 | self.ppr_matrix = ppr_matrix 10 | self.indices = indices 11 | self.labels_all = torch.tensor(labels_all) 12 | self.cached = {} 13 | 14 | def __len__(self): 15 | return self.indices.shape[0] 16 | 17 | def __getitem__(self, idx): 18 | # idx is a list of indices 19 | key = idx[0] 20 | if key not in self.cached: 21 | ppr_matrix = self.ppr_matrix[idx] 22 | source_idx, neighbor_idx = ppr_matrix.nonzero() 23 | ppr_scores = ppr_matrix.data 24 | 25 | attr_matrix = matrix_to_torch(self.attr_matrix_all[neighbor_idx]) 26 | ppr_scores = torch.tensor(ppr_scores, dtype=torch.float32) 27 | source_idx = torch.tensor(source_idx, 
dtype=torch.long) 28 | 29 | if self.labels_all is None: 30 | labels = None 31 | else: 32 | labels = self.labels_all[self.indices[idx]] 33 | self.cached[key] = ((attr_matrix, ppr_scores, source_idx), labels) 34 | return self.cached[key] 35 | -------------------------------------------------------------------------------- /pprgo/pprgo.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch_scatter import scatter 4 | 5 | from .pytorch_utils import MixedDropout, MixedLinear 6 | 7 | 8 | class PPRGoMLP(nn.Module): 9 | def __init__(self, num_features, num_classes, hidden_size, nlayers, dropout): 10 | super().__init__() 11 | 12 | fcs = [MixedLinear(num_features, hidden_size, bias=False)] 13 | for i in range(nlayers - 2): 14 | fcs.append(nn.Linear(hidden_size, hidden_size, bias=False)) 15 | fcs.append(nn.Linear(hidden_size, num_classes, bias=False)) 16 | self.fcs = nn.ModuleList(fcs) 17 | 18 | self.drop = MixedDropout(dropout) 19 | 20 | def forward(self, X): 21 | embs = self.drop(X) 22 | embs = self.fcs[0](embs) 23 | for fc in self.fcs[1:]: 24 | embs = fc(self.drop(F.relu(embs))) 25 | return embs 26 | 27 | 28 | class PPRGo(nn.Module): 29 | def __init__(self, num_features, num_classes, hidden_size, nlayers, dropout): 30 | super().__init__() 31 | self.mlp = PPRGoMLP(num_features, num_classes, hidden_size, nlayers, dropout) 32 | 33 | def forward(self, X, ppr_scores, ppr_idx): 34 | logits = self.mlp(X) 35 | propagated_logits = scatter(logits * ppr_scores[:, None], ppr_idx[:, None], 36 | dim=0, dim_size=ppr_idx[-1] + 1, reduce='sum') 37 | return propagated_logits 38 | -------------------------------------------------------------------------------- /config_demo.yaml: -------------------------------------------------------------------------------- 1 | data_file: data/reddit.npz # Path to the .npz data file 2 | split_seed: 0 # Seed for splitting the dataset into train/val/test 3 | ntrain_div_classes: 20 # Number of training nodes divided by number of classes 4 | attr_normalization: None # Attribute normalization. 
Not used in the paper 5 | 6 | alpha: 0.5 # PPR teleport probability 7 | eps: 1e-4 # Stopping threshold for ACL's ApproximatePR 8 | topk: 32 # Number of PPR neighbors for each node 9 | ppr_normalization: 'sym' # Adjacency matrix normalization for weighting neighbors 10 | 11 | hidden_size: 32 # Size of the MLP's hidden layer 12 | nlayers: 2 # Number of MLP layers 13 | weight_decay: 1e-4 # Weight decay used for training the MLP 14 | dropout: 0.1 # Dropout used for training 15 | 16 | lr: 5e-3 # Learning rate 17 | max_epochs: 200 # Maximum number of epochs (exact number if no early stopping) 18 | batch_size: 512 # Batch size for training 19 | batch_mult_val: 4 # Multiplier for validation batch size 20 | 21 | eval_step: 20 # Accuracy is evaluated after every this number of steps 22 | run_val: False # Evaluate accuracy on validation set during training 23 | 24 | early_stop: False # Use early stopping 25 | patience: 50 # Patience for early stopping 26 | 27 | nprop_inference: 2 # Number of propagation steps during inference 28 | inf_fraction: 1.0 # Fraction of nodes for which local predictions are computed during inference 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PPRGo (PyTorch) 2 | 3 | This repository provides a PyTorch implementation of PPRGo for a single machine. You can find the [original TensorFlow 1 implementation in another repository](https://github.com/TUM-DAML/pprgo). PPRGo is a fast GNN able to scale to massive graphs in both single-machine and distributed setups. It was proposed in our paper 4 | 5 | **[Scaling Graph Neural Networks with Approximate PageRank](https://www.daml.in.tum.de/pprgo)** 6 | by Aleksandar Bojchevski\*, Johannes Gasteiger\*, Bryan Perozzi, Amol Kapoor, Martin Blais, Benedek Rózemberczki, Michal Lukasik, Stephan Günnemann 7 | Published at ACM SIGKDD 2020. 8 | 9 | \*Both authors contributed equally to this research. Note that the author's name has changed from Johannes Klicpera to Johannes Gasteiger. 10 | 11 | ## Demonstration 12 | To see for yourself how fast PPRGo runs even on a large dataset we've set up a [Google Colab notebook](https://colab.research.google.com/drive/1nw3MIpXPK_n6IZvKcLxkOuOi9_1i6IzA?usp=sharing), which trains and generates predictions for the Reddit dataset, as described in the paper. 13 | 14 | ## Installation 15 | You can install the repository using `pip install -e .`. Since CUDA 10.0 includes a bug that affects PPRGo we strongly recommend using e.g. 10.1. 16 | 17 | ## Run the code 18 | This repository contains a demo notebook for running training and inference (`demo.ipynb`) and a script for running the model on a cluster with [SEML](https://github.com/TUM-DAML/seml) (`run_seml.py`). 19 | 20 | ## Contact 21 | Please contact a.bojchevski@in.tum.de or j.gasteiger@in.tum.de if you have any questions. 
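As a quick reference for the "Run the code" section above, the pipeline from `demo.ipynb` condenses to roughly the sketch below. This is a minimal example rather than a supported entry point: it assumes you run it from the repository root, uses the bundled `data/cora_full.npz`, and borrows its hyperparameters from `config_demo.yaml` and the Cora-Full settings in `config_seml.yaml` (e.g. `alpha=0.25`).

```python
import torch
from sklearn.metrics import accuracy_score

from pprgo import utils, ppr
from pprgo.pprgo import PPRGo
from pprgo.train import train
from pprgo.predict import predict
from pprgo.dataset import PPRDataset

# Load the graph and create a random split (20 training nodes per class)
adj_matrix, attr_matrix, labels, train_idx, val_idx, test_idx = utils.get_data(
    'data/cora_full.npz', seed=0, ntrain_div_classes=20, normalize_attr=None)
num_features, num_classes = attr_matrix.shape[1], labels.max() + 1

# Approximate top-k PPR neighborhoods for the training nodes (ACL's ApproximatePR)
topk_train = ppr.topk_ppr_matrix(adj_matrix, alpha=0.25, eps=1e-4, idx=train_idx,
                                 topk=32, normalization='sym')
train_set = PPRDataset(attr_matrix_all=attr_matrix, ppr_matrix=topk_train,
                       indices=train_idx, labels_all=labels)

# Train the model on the precomputed PPR neighborhoods
model = PPRGo(num_features, num_classes, hidden_size=32, nlayers=2, dropout=0.1)
model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
nepochs, _, _ = train(model=model, train_set=train_set, val_set=None,
                      lr=5e-3, weight_decay=1e-4, max_epochs=200,
                      batch_size=512, eval_step=20)

# Inference: local logits for all nodes, then a few propagation (power-iteration) steps
predictions, _, _ = predict(model=model, adj_matrix=adj_matrix, attr_matrix=attr_matrix,
                            alpha=0.25, nprop=2, inf_fraction=1.0, ppr_normalization='sym')
print(f"Test accuracy: {100 * accuracy_score(labels[test_idx], predictions[test_idx]):.1f}%")
```

For the full-scale Reddit run shown in the Colab notebook, download `data/reddit.npz` as described in `data/get_reddit.md` and use `alpha=0.5`, as in `config_demo.yaml`.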
22 | 23 | ## Cite 24 | Please cite our paper if you use the model or this code in your own work: 25 | 26 | ``` 27 | @inproceedings{bojchevski2020pprgo, 28 | title={Scaling Graph Neural Networks with Approximate PageRank}, 29 | author={Bojchevski, Aleksandar and Gasteiger, Johannes and Perozzi, Bryan and Kapoor, Amol and Blais, Martin and R{\'o}zemberczki, Benedek and Lukasik, Michal and G{\"u}nnemann, Stephan}, 30 | booktitle = {Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining}, 31 | year={2020}, 32 | publisher = {ACM}, 33 | address = {New York, NY, USA}, 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /config_seml.yaml: -------------------------------------------------------------------------------- 1 | seml: 2 | name: pprgo 3 | executable: run_seml.py 4 | output_dir: '~/slurm_out' 5 | project_root_dir: . 6 | 7 | slurm: 8 | experiments_per_job: 1 9 | sbatch_options: 10 | gres: 'gpu:"GeForce GTX 1080 Ti":1' 11 | mem: 32G # Main memory 12 | cpus-per-task: 5 # CPU cores 13 | time: 1-00:00 # Maximum runtime, D-HH:MM 14 | partition: [gpu_all] 15 | 16 | fixed: 17 | data_dir: /nfs/shared/data/ # Directory containing .npz data files 18 | ntrain_div_classes: 20 # Number of training nodes divided by number of classes 19 | attr_normalization: None # Attribute normalization. Not used in the paper 20 | topk: 32 # Number of PPR neighbors for each node 21 | ppr_normalization: 'sym' # Adjacency matrix normalization for weighting neighbors 22 | hidden_size: 32 # Size of the MLP's hidden layer 23 | nlayers: 2 # Number of MLP layers 24 | weight_decay: 1e-4 # Weight decay used for training the MLP 25 | dropout: 0.1 # Dropout used for training 26 | lr: 5e-3 # Learning rate 27 | max_epochs: 200 # Maximum number of epochs (exact number if no early stopping) 28 | batch_size: 512 # Batch size for training 29 | batch_mult_val: 4 # Multiplier for validation batch size 30 | eval_step: 1 # Accuracy is evaluated after every this number of steps 31 | run_val: False # Evaluate accuracy on validation set during training 32 | early_stop: False # Use early stopping 33 | patience: 50 # Patience for early stopping 34 | nprop_inference: 2 # Number of propagation steps during inference 35 | inf_fraction: 1.0 # Fraction of nodes for which local predictions are computed during inference 36 | 37 | grid: 38 | split_seed: # Seed for splitting the dataset into train/val/test 39 | type: 'range' 40 | min: 0 41 | max: 5 42 | step: 1 43 | eps: # Stopping threshold for ACL's ApproximatePR 44 | type: choice 45 | options: [1e-2, 1e-4] 46 | 47 | reddit: 48 | fixed: 49 | alpha: 0.5 # PPR teleport probability 50 | data_fname: 'reddit.npz' # Name of .npz data file 51 | 52 | remaining: 53 | fixed: 54 | alpha: 0.25 # PPR teleport probability 55 | grid: 56 | data_fname: # Name of .npz data file 57 | type: choice 58 | options: 59 | - 'cora_full.npz' 60 | - 'pubmed.npz' 61 | - 'mag_large_filtered_06_09_coarse_standardized.npz' 62 | # - 'mag_large_filtered_06_09_fine_standardized.npz' 63 | -------------------------------------------------------------------------------- /pprgo/predict.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | 5 | from .pytorch_utils import matrix_to_torch 6 | 7 | 8 | def get_local_logits(model, attr_matrix, batch_size=10000): 9 | device = next(model.parameters()).device 10 | 11 | nnodes = attr_matrix.shape[0] 12 | logits = [] 13 | 
with torch.set_grad_enabled(False): 14 | for i in range(0, nnodes, batch_size): 15 | batch_attr = matrix_to_torch(attr_matrix[i:i + batch_size]).to(device) 16 | logits.append(model(batch_attr).to('cpu').numpy()) 17 | logits = np.row_stack(logits) 18 | return logits 19 | 20 | 21 | def predict(model, adj_matrix, attr_matrix, alpha, 22 | nprop=2, inf_fraction=1.0, ppr_normalization='sym', batch_size_logits=10000): 23 | 24 | model.eval() 25 | 26 | start = time.time() 27 | if inf_fraction < 1.0: 28 | idx_sub = np.random.choice(adj_matrix.shape[0], int(inf_fraction * adj_matrix.shape[0]), replace=False) 29 | idx_sub.sort() 30 | attr_sub = attr_matrix[idx_sub] 31 | logits_sub = get_local_logits(model.mlp, attr_sub, batch_size_logits) 32 | local_logits = np.zeros([adj_matrix.shape[0], logits_sub.shape[1]], dtype=np.float32) 33 | local_logits[idx_sub] = logits_sub 34 | else: 35 | local_logits = get_local_logits(model.mlp, attr_matrix, batch_size_logits) 36 | time_logits = time.time() - start 37 | 38 | start = time.time() 39 | row, col = adj_matrix.nonzero() 40 | logits = local_logits.copy() 41 | 42 | if ppr_normalization == 'sym': 43 | # Assume undirected (symmetric) adjacency matrix 44 | deg = adj_matrix.sum(1).A1 45 | deg_sqrt_inv = 1. / np.sqrt(np.maximum(deg, 1e-12)) 46 | for _ in range(nprop): 47 | logits = (1 - alpha) * deg_sqrt_inv[:, None] * (adj_matrix @ (deg_sqrt_inv[:, None] * logits)) + alpha * local_logits 48 | elif ppr_normalization == 'col': 49 | deg_col = adj_matrix.sum(0).A1 50 | deg_col_inv = 1. / np.maximum(deg_col, 1e-12) 51 | for _ in range(nprop): 52 | logits = (1 - alpha) * (adj_matrix @ (deg_col_inv[:, None] * logits)) + alpha * local_logits 53 | elif ppr_normalization == 'row': 54 | deg_row = adj_matrix.sum(1).A1 55 | deg_row_inv_alpha = (1 - alpha) / np.maximum(deg_row, 1e-12) 56 | for _ in range(nprop): 57 | logits = deg_row_inv_alpha[:, None] * (adj_matrix @ logits) + alpha * local_logits 58 | else: 59 | raise ValueError(f"Unknown PPR normalization: {ppr_normalization}") 60 | predictions = logits.argmax(1) 61 | time_propagation = time.time() - start 62 | 63 | return predictions, time_logits, time_propagation 64 | -------------------------------------------------------------------------------- /pprgo/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import scipy.sparse as sp 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch_sparse 7 | 8 | 9 | class SparseDropout(nn.Module): 10 | def __init__(self, p): 11 | super().__init__() 12 | self.p = p 13 | 14 | def forward(self, input): 15 | value_dropped = F.dropout(input.storage.value(), self.p, self.training) 16 | return torch_sparse.SparseTensor( 17 | row=input.storage.row(), rowptr=input.storage.rowptr(), col=input.storage.col(), 18 | value=value_dropped, sparse_sizes=input.sparse_sizes(), is_sorted=True) 19 | 20 | 21 | class MixedDropout(nn.Module): 22 | def __init__(self, p): 23 | super().__init__() 24 | self.dense_dropout = nn.Dropout(p) 25 | self.sparse_dropout = SparseDropout(p) 26 | 27 | def forward(self, input): 28 | if isinstance(input, torch_sparse.SparseTensor): 29 | return self.sparse_dropout(input) 30 | else: 31 | return self.dense_dropout(input) 32 | 33 | 34 | class MixedLinear(nn.Module): 35 | def __init__(self, in_features, out_features, bias=True): 36 | super().__init__() 37 | self.in_features = in_features 38 | self.out_features = out_features 39 | self.weight = 
nn.Parameter(torch.Tensor(in_features, out_features)) 40 | if bias: 41 | self.bias = nn.Parameter(torch.Tensor(out_features)) 42 | else: 43 | self.register_parameter('bias', None) 44 | self.reset_parameters() 45 | 46 | def reset_parameters(self): 47 | # Our fan_in is interpreted by PyTorch as fan_out (swapped dimensions) 48 | nn.init.kaiming_uniform_(self.weight, mode='fan_out', a=math.sqrt(5)) 49 | if self.bias is not None: 50 | _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight) 51 | bound = 1 / math.sqrt(fan_out) 52 | nn.init.uniform_(self.bias, -bound, bound) 53 | 54 | def forward(self, input): 55 | if isinstance(input, torch_sparse.SparseTensor): 56 | res = input.matmul(self.weight) 57 | if self.bias is not None: 58 | res += self.bias[None, :] 59 | else: 60 | if self.bias is not None: 61 | res = torch.addmm(self.bias, input, self.weight) 62 | else: 63 | res = input.matmul(self.weight) 64 | return res 65 | 66 | def extra_repr(self): 67 | return 'in_features={}, out_features={}, bias={}'.format( 68 | self.in_features, self.out_features, self.bias is not None) 69 | 70 | 71 | def matrix_to_torch(X): 72 | if sp.issparse(X): 73 | return torch_sparse.SparseTensor.from_scipy(X) 74 | else: 75 | return torch.FloatTensor(X) 76 | -------------------------------------------------------------------------------- /pprgo/utils.py: -------------------------------------------------------------------------------- 1 | import resource 2 | import numpy as np 3 | import scipy.sparse as sp 4 | import sklearn 5 | 6 | from .sparsegraph import load_from_npz 7 | 8 | 9 | class SparseRowIndexer: 10 | def __init__(self, csr_matrix): 11 | data = [] 12 | indices = [] 13 | indptr = [] 14 | 15 | # Iterating over the rows this way is significantly more efficient 16 | # than csr_matrix[row_index,:] and csr_matrix.getrow(row_index) 17 | for row_start, row_end in zip(csr_matrix.indptr[:-1], csr_matrix.indptr[1:]): 18 | data.append(csr_matrix.data[row_start:row_end]) 19 | indices.append(csr_matrix.indices[row_start:row_end]) 20 | indptr.append(row_end - row_start) # nnz of the row 21 | 22 | self.data = np.array(data, dtype=object) # rows have different lengths, so store as object arrays 23 | self.indices = np.array(indices, dtype=object) 24 | self.indptr = np.array(indptr) 25 | self.shape = csr_matrix.shape 26 | 27 | def __getitem__(self, row_selector): 28 | data = np.concatenate(self.data[row_selector]) 29 | indices = np.concatenate(self.indices[row_selector]) 30 | indptr = np.append(0, np.cumsum(self.indptr[row_selector])) 31 | 32 | shape = [indptr.shape[0] - 1, self.shape[1]] 33 | 34 | return sp.csr_matrix((data, indices, indptr), shape=shape) 35 | 36 | 37 | def split_random(seed, n, n_train, n_val): 38 | np.random.seed(seed) 39 | rnd = np.random.permutation(n) 40 | 41 | train_idx = np.sort(rnd[:n_train]) 42 | val_idx = np.sort(rnd[n_train:n_train + n_val]) 43 | 44 | train_val_idx = np.concatenate((train_idx, val_idx)) 45 | test_idx = np.sort(np.setdiff1d(np.arange(n), train_val_idx)) 46 | 47 | return train_idx, val_idx, test_idx 48 | 49 | 50 | def get_data(dataset_path, seed, ntrain_div_classes, normalize_attr=None): 51 | ''' 52 | Get data from a .npz-file. 53 | 54 | Parameters 55 | ---------- 56 | dataset_path 57 | Path to the dataset .npz file 58 | seed 59 | Random seed for dataset splitting 60 | ntrain_div_classes 61 | Number of training nodes divided by number of classes 62 | normalize_attr 63 | Normalization scheme for attributes. By default (and in the paper) no normalization is used.
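Returns ------- (adj_matrix, attr_matrix, labels, train_idx, val_idx, test_idx) Sparse adjacency matrix, attribute matrix (wrapped in a SparseRowIndexer when sparse), label array, and the random train/val/test node index splits.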
64 | 65 | ''' 66 | g = load_from_npz(dataset_path) 67 | 68 | if dataset_path.split('/')[-1] in ['cora_full.npz']: 69 | g.standardize() 70 | 71 | # number of nodes and attributes 72 | n, d = g.attr_matrix.shape 73 | 74 | # optional attribute normalization 75 | if normalize_attr == 'per_feature': 76 | if sp.issparse(g.attr_matrix): 77 | scaler = sklearn.preprocessing.StandardScaler(with_mean=False) 78 | else: 79 | scaler = sklearn.preprocessing.StandardScaler() 80 | attr_matrix = scaler.fit_transform(g.attr_matrix) 81 | elif normalize_attr == 'per_node': 82 | if sp.issparse(g.attr_matrix): 83 | attr_norms = sp.linalg.norm(g.attr_matrix, ord=1, axis=1) 84 | attr_invnorms = 1 / np.maximum(attr_norms, 1e-12) 85 | attr_matrix = g.attr_matrix.multiply(attr_invnorms[:, np.newaxis]).tocsr() 86 | else: 87 | attr_norms = np.linalg.norm(g.attr_matrix, ord=1, axis=1) 88 | attr_invnorms = 1 / np.maximum(attr_norms, 1e-12) 89 | attr_matrix = g.attr_matrix * attr_invnorms[:, np.newaxis] 90 | else: 91 | attr_matrix = g.attr_matrix 92 | 93 | # helper that speeds up row indexing 94 | if sp.issparse(attr_matrix): 95 | attr_matrix = SparseRowIndexer(attr_matrix) 96 | else: 97 | attr_matrix = attr_matrix 98 | 99 | # split the data into train/val/test 100 | num_classes = g.labels.max() + 1 101 | n_train = num_classes * ntrain_div_classes 102 | n_val = n_train * 10 103 | train_idx, val_idx, test_idx = split_random(seed, n, n_train, n_val) 104 | 105 | 106 | return g.adj_matrix, attr_matrix, g.labels, train_idx, val_idx, test_idx 107 | 108 | 109 | def get_max_memory_bytes(): 110 | return 1024 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss 111 | -------------------------------------------------------------------------------- /pprgo/ppr.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | import scipy.sparse as sp 4 | 5 | 6 | @numba.njit(cache=True, locals={'_val': numba.float32, 'res': numba.float32, 'res_vnode': numba.float32}) 7 | def _calc_ppr_node(inode, indptr, indices, deg, alpha, epsilon): 8 | alpha_eps = alpha * epsilon 9 | f32_0 = numba.float32(0) 10 | p = {inode: f32_0} 11 | r = {} 12 | r[inode] = alpha 13 | q = [inode] 14 | while len(q) > 0: 15 | unode = q.pop() 16 | 17 | res = r[unode] if unode in r else f32_0 18 | if unode in p: 19 | p[unode] += res 20 | else: 21 | p[unode] = res 22 | r[unode] = f32_0 23 | for vnode in indices[indptr[unode]:indptr[unode + 1]]: 24 | _val = (1 - alpha) * res / deg[unode] 25 | if vnode in r: 26 | r[vnode] += _val 27 | else: 28 | r[vnode] = _val 29 | 30 | res_vnode = r[vnode] if vnode in r else f32_0 31 | if res_vnode >= alpha_eps * deg[vnode]: 32 | if vnode not in q: 33 | q.append(vnode) 34 | 35 | return list(p.keys()), list(p.values()) 36 | 37 | 38 | @numba.njit(cache=True) 39 | def calc_ppr(indptr, indices, deg, alpha, epsilon, nodes): 40 | js = [] 41 | vals = [] 42 | for i, node in enumerate(nodes): 43 | j, val = _calc_ppr_node(node, indptr, indices, deg, alpha, epsilon) 44 | js.append(j) 45 | vals.append(val) 46 | return js, vals 47 | 48 | 49 | @numba.njit(cache=True, parallel=True) 50 | def calc_ppr_topk_parallel(indptr, indices, deg, alpha, epsilon, nodes, topk): 51 | js = [np.zeros(0, dtype=np.int64)] * len(nodes) 52 | vals = [np.zeros(0, dtype=np.float32)] * len(nodes) 53 | for i in numba.prange(len(nodes)): 54 | j, val = _calc_ppr_node(nodes[i], indptr, indices, deg, alpha, epsilon) 55 | j_np, val_np = np.array(j), np.array(val) 56 | idx_topk = np.argsort(val_np)[-topk:] 57 | js[i] = 
j_np[idx_topk] 58 | vals[i] = val_np[idx_topk] 59 | return js, vals 60 | 61 | 62 | def ppr_topk(adj_matrix, alpha, epsilon, nodes, topk): 63 | """Calculate the PPR matrix approximately using Andersen et al.'s ApproximatePR (ACL).""" 64 | 65 | out_degree = np.sum(adj_matrix > 0, axis=1).A1 66 | nnodes = adj_matrix.shape[0] 67 | 68 | neighbors, weights = calc_ppr_topk_parallel(adj_matrix.indptr, adj_matrix.indices, out_degree, 69 | numba.float32(alpha), numba.float32(epsilon), nodes, topk) 70 | 71 | return construct_sparse(neighbors, weights, (len(nodes), nnodes)) 72 | 73 | 74 | def construct_sparse(neighbors, weights, shape): 75 | i = np.repeat(np.arange(len(neighbors)), np.fromiter(map(len, neighbors), dtype=np.int64)) 76 | j = np.concatenate(neighbors) 77 | return sp.coo_matrix((np.concatenate(weights), (i, j)), shape) 78 | 79 | 80 | def topk_ppr_matrix(adj_matrix, alpha, eps, idx, topk, normalization='row'): 81 | """Create a sparse matrix where each node has up to the topk PPR neighbors and their weights.""" 82 | 83 | topk_matrix = ppr_topk(adj_matrix, alpha, eps, idx, topk).tocsr() 84 | 85 | if normalization == 'sym': 86 | # Assume undirected (symmetric) adjacency matrix 87 | deg = adj_matrix.sum(1).A1 88 | deg_sqrt = np.sqrt(np.maximum(deg, 1e-12)) 89 | deg_inv_sqrt = 1. / deg_sqrt 90 | 91 | row, col = topk_matrix.nonzero() 92 | # assert np.all(deg[idx[row]] > 0) 93 | # assert np.all(deg[col] > 0) 94 | topk_matrix.data = deg_sqrt[idx[row]] * topk_matrix.data * deg_inv_sqrt[col] 95 | elif normalization == 'col': 96 | # Assume undirected (symmetric) adjacency matrix 97 | deg = adj_matrix.sum(1).A1 98 | deg_inv = 1. / np.maximum(deg, 1e-12) 99 | 100 | row, col = topk_matrix.nonzero() 101 | # assert np.all(deg[idx[row]] > 0) 102 | # assert np.all(deg[col] > 0) 103 | topk_matrix.data = deg[idx[row]] * topk_matrix.data * deg_inv[col] 104 | elif normalization == 'row': 105 | pass 106 | else: 107 | raise ValueError(f"Unknown PPR normalization: {normalization}") 108 | 109 | return topk_matrix 110 | -------------------------------------------------------------------------------- /pprgo/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def run_batch(model, xbs, yb, optimizer, train): 8 | 9 | # Set model to training mode 10 | if train: 11 | model.train() 12 | else: 13 | model.eval() 14 | 15 | # zero the parameter gradients 16 | if train: 17 | optimizer.zero_grad() 18 | 19 | # forward 20 | with torch.set_grad_enabled(train): 21 | pred = model(*xbs) 22 | loss = F.cross_entropy(pred, yb) 23 | top1 = torch.argmax(pred, dim=1) 24 | ncorrect = torch.sum(top1 == yb) 25 | 26 | # backward + optimize only if in training phase 27 | if train: 28 | loss.backward() 29 | optimizer.step() 30 | 31 | return loss, ncorrect 32 | 33 | 34 | def train(model, train_set, val_set, lr, weight_decay, 35 | max_epochs=200, batch_size=512, batch_mult_val=4, 36 | eval_step=1, early_stop=False, patience=50, ex=None): 37 | device = next(model.parameters()).device 38 | 39 | train_loader = torch.utils.data.DataLoader( 40 | dataset=train_set, 41 | sampler=torch.utils.data.BatchSampler( 42 | torch.utils.data.SequentialSampler(train_set), 43 | batch_size=batch_size, drop_last=False 44 | ), 45 | batch_size=None, 46 | num_workers=0, 47 | ) 48 | step = 0 49 | best_loss = np.inf 50 | 51 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 52 | 53 | loss_hist = {'train': [], 'val': []} 54 |
acc_hist = {'train': [], 'val': []} 55 | if ex is not None: 56 | ex.current_run.info['train'] = {'loss': [], 'acc': []} 57 | ex.current_run.info['val'] = {'loss': [], 'acc': []} 58 | 59 | loss = 0 60 | ncorrect = 0 61 | nsamples = 0 62 | for epoch in range(max_epochs): 63 | for xbs, yb in train_loader: 64 | xbs, yb = [xb.to(device) for xb in xbs], yb.to(device) 65 | 66 | loss_batch, ncorr_batch = run_batch(model, xbs, yb, optimizer, train=True) 67 | loss += loss_batch 68 | ncorrect += ncorr_batch 69 | nsamples += yb.shape[0] 70 | 71 | step += 1 72 | if step % eval_step == 0: 73 | # update train stats 74 | train_loss = loss / nsamples 75 | train_acc = ncorrect / nsamples 76 | 77 | loss_hist['train'].append(train_loss) 78 | acc_hist['train'].append(train_acc) 79 | if ex is not None: 80 | ex.current_run.info['train']['loss'].append(train_loss) 81 | ex.current_run.info['train']['acc'].append(train_acc) 82 | 83 | if val_set is not None: 84 | # update val stats 85 | rnd_idx = np.random.choice(len(val_set), size=batch_mult_val * batch_size, replace=False) 86 | xbs, yb = val_set[rnd_idx] 87 | xbs, yb = [xb.to(device) for xb in xbs], yb.to(device) 88 | val_loss, val_ncorr = run_batch(model, xbs, yb, None, train=False) 89 | val_acc = val_ncorr / (batch_mult_val * batch_size) 90 | 91 | loss_hist['val'].append(val_loss) 92 | acc_hist['val'].append(val_acc) 93 | if ex is not None: 94 | ex.current_run.info['val']['loss'].append(val_loss) 95 | ex.current_run.info['val']['acc'].append(val_acc) 96 | 97 | logging.info(f"Epoch {epoch}, step {step}: train {train_loss:.5f}, val {val_loss:.5f}") 98 | 99 | if val_loss < best_loss: 100 | best_loss = val_loss 101 | best_epoch = epoch 102 | best_state = { 103 | key: value.cpu() for key, value 104 | in model.state_dict().items() 105 | } 106 | # early stop only if this variable is set to True 107 | elif early_stop and epoch >= best_epoch + patience: 108 | model.load_state_dict(best_state) 109 | return epoch + 1, loss_hist, acc_hist 110 | else: 111 | logging.info(f"Epoch {epoch}, step {step}: train {train_loss:.5f}") 112 | if val_set is not None: 113 | model.load_state_dict(best_state) 114 | return epoch + 1, loss_hist, acc_hist 115 | -------------------------------------------------------------------------------- /run_seml.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from sklearn.metrics import accuracy_score, f1_score 4 | import torch 5 | from sacred import Experiment 6 | import seml 7 | 8 | from pprgo import utils, ppr 9 | from pprgo.pprgo import PPRGo 10 | from pprgo.train import train 11 | from pprgo.predict import predict 12 | from pprgo.dataset import PPRDataset 13 | 14 | ex = Experiment() 15 | seml.setup_logger(ex) 16 | 17 | 18 | @ex.config 19 | def config(): 20 | overwrite = None 21 | db_collection = None 22 | if db_collection is not None: 23 | ex.observers.append(seml.create_mongodb_observer(db_collection, overwrite=overwrite)) 24 | 25 | 26 | @ex.automain 27 | def run(data_dir, data_fname, split_seed, ntrain_div_classes, attr_normalization, 28 | alpha, eps, topk, ppr_normalization, 29 | hidden_size, nlayers, weight_decay, dropout, 30 | lr, max_epochs, batch_size, batch_mult_val, 31 | eval_step, run_val, 32 | early_stop, patience, 33 | nprop_inference, inf_fraction): 34 | ''' 35 | Run training and inference. 36 | 37 | Parameters 38 | ---------- 39 | data_dir: 40 | Directory containing .npz data files. 41 | data_fname: 42 | Name of .npz data file. 
43 | split_seed: 44 | Seed for splitting the dataset into train/val/test. 45 | ntrain_div_classes: 46 | Number of training nodes divided by number of classes. 47 | attr_normalization: 48 | Attribute normalization. Not used in the paper. 49 | alpha: 50 | PPR teleport probability. 51 | eps: 52 | Stopping threshold for ACL's ApproximatePR. 53 | topk: 54 | Number of PPR neighbors for each node. 55 | ppr_normalization: 56 | Adjacency matrix normalization for weighting neighbors. 57 | hidden_size: 58 | Size of the MLP's hidden layer. 59 | nlayers: 60 | Number of MLP layers. 61 | weight_decay: 62 | Weight decay used for training the MLP. 63 | dropout: 64 | Dropout used for training. 65 | lr: 66 | Learning rate. 67 | max_epochs: 68 | Maximum number of epochs (exact number if no early stopping). 69 | batch_size: 70 | Batch size for training. 71 | batch_mult_val: 72 | Multiplier for validation batch size. 73 | eval_step: 74 | Accuracy is evaluated after every this number of steps. 75 | run_val: 76 | Evaluate accuracy on validation set during training. 77 | early_stop: 78 | Use early stopping. 79 | patience: 80 | Patience for early stopping. 81 | nprop_inference: 82 | Number of propagation steps during inference 83 | inf_fraction: 84 | Fraction of nodes for which local predictions are computed during inference. 85 | ''' 86 | torch.manual_seed(0) 87 | 88 | start = time.time() 89 | (adj_matrix, attr_matrix, labels, 90 | train_idx, val_idx, test_idx) = utils.get_data( 91 | f"{data_dir}/{data_fname}", 92 | seed=split_seed, 93 | ntrain_div_classes=ntrain_div_classes, 94 | normalize_attr=attr_normalization 95 | ) 96 | try: 97 | d = attr_matrix.n_columns 98 | except AttributeError: 99 | d = attr_matrix.shape[1] 100 | nc = labels.max() + 1 101 | time_loading = time.time() - start 102 | logging.info('Loading done.') 103 | 104 | # compute the ppr vectors for train/val nodes using ACL's ApproximatePR 105 | start = time.time() 106 | topk_train = ppr.topk_ppr_matrix(adj_matrix, alpha, eps, train_idx, topk, 107 | normalization=ppr_normalization) 108 | train_set = PPRDataset(attr_matrix_all=attr_matrix, ppr_matrix=topk_train, indices=train_idx, labels_all=labels) 109 | if run_val: 110 | topk_val = ppr.topk_ppr_matrix(adj_matrix, alpha, eps, val_idx, topk, 111 | normalization=ppr_normalization) 112 | val_set = PPRDataset(attr_matrix_all=attr_matrix, ppr_matrix=topk_val, indices=val_idx, labels_all=labels) 113 | else: 114 | val_set = None 115 | time_preprocessing = time.time() - start 116 | logging.info('Preprocessing done.') 117 | 118 | start = time.time() 119 | model = PPRGo(d, nc, hidden_size, nlayers, dropout) 120 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 121 | model.to(device) 122 | 123 | nepochs, _, _ = train( 124 | model=model, train_set=train_set, val_set=val_set, 125 | lr=lr, weight_decay=weight_decay, 126 | max_epochs=max_epochs, batch_size=batch_size, batch_mult_val=batch_mult_val, 127 | eval_step=eval_step, early_stop=early_stop, patience=patience, 128 | ex=ex) 129 | time_training = time.time() - start 130 | logging.info('Training done.') 131 | 132 | start = time.time() 133 | predictions, time_logits, time_propagation = predict( 134 | model=model, adj_matrix=adj_matrix, attr_matrix=attr_matrix, alpha=alpha, 135 | nprop=nprop_inference, inf_fraction=inf_fraction, 136 | ppr_normalization=ppr_normalization) 137 | time_inference = time.time() - start 138 | logging.info('Inference done.') 139 | 140 | results = { 141 | 'accuracy_train': 100 * 
accuracy_score(labels[train_idx], predictions[train_idx]), 142 | 'accuracy_val': 100 * accuracy_score(labels[val_idx], predictions[val_idx]), 143 | 'accuracy_test': 100 * accuracy_score(labels[test_idx], predictions[test_idx]), 144 | 'f1_train': f1_score(labels[train_idx], predictions[train_idx], average='macro'), 145 | 'f1_val': f1_score(labels[val_idx], predictions[val_idx], average='macro'), 146 | 'f1_test': f1_score(labels[test_idx], predictions[test_idx], average='macro'), 147 | } 148 | 149 | results.update({ 150 | 'time_loading': time_loading, 151 | 'time_preprocessing': time_preprocessing, 152 | 'time_training': time_training, 153 | 'time_inference': time_inference, 154 | 'time_logits': time_logits, 155 | 'time_propagation': time_propagation, 156 | 'gpu_memory': torch.cuda.max_memory_allocated(), 157 | 'memory': utils.get_max_memory_bytes(), 158 | 'nepochs': nepochs, 159 | }) 160 | return results 161 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | PPRGo Copyright 2020 Aleksandar Bojchevski, Johannes Gasteiger (“Licensor”) 2 | 3 | Hippocratic License Version Number: 2.1. 4 | 5 | Purpose. The purpose of this License is for the Licensor named above to permit the Licensee (as defined below) broad permission, if consistent with Human Rights Laws and Human Rights Principles (as each is defined below), to use and work with the Software (as defined below) within the full scope of Licensor’s copyright and patent rights, if any, in the Software, while ensuring attribution and protecting the Licensor from liability. 6 | 7 | Permission and Conditions. The Licensor grants permission by this license (“License”), free of charge, to the extent of Licensor’s rights under applicable copyright and patent law, to any person or entity (the “Licensee”) obtaining a copy of this software and associated documentation files (the “Software”), to do everything with the Software that would otherwise infringe (i) the Licensor’s copyright in the Software or (ii) any patent claims to the Software that the Licensor can license or becomes able to license, subject to all of the following terms and conditions: 8 | 9 | * Acceptance. This License is automatically offered to every person and entity subject to its terms and conditions. Licensee accepts this License and agrees to its terms and conditions by taking any action with the Software that, absent this License, would infringe any intellectual property right held by Licensor. 10 | 11 | * Notice. Licensee must ensure that everyone who gets a copy of any part of this Software from Licensee, with or without changes, also receives the License and the above copyright notice (and if included by the Licensor, patent, trademark and attribution notice). Licensee must cause any modified versions of the Software to carry prominent notices stating that Licensee changed the Software. For clarity, although Licensee is free to create modifications of the Software and distribute only the modified portion created by Licensee with additional or different terms, the portion of the Software not modified must be distributed pursuant to this License. If anyone notifies Licensee in writing that Licensee has not complied with this Notice section, Licensee can keep this License by taking all practical steps to comply within 30 days after the notice. If Licensee does not do so, Licensee’s License (and all rights licensed hereunder) shall end immediately. 
12 | 13 | * Compliance with Human Rights Principles and Human Rights Laws. 14 | 15 | 1. Human Rights Principles. 16 | 17 | (a) Licensee is advised to consult the articles of the United Nations Universal Declaration of Human Rights and the United Nations Global Compact that define recognized principles of international human rights (the “Human Rights Principles”). Licensee shall use the Software in a manner consistent with Human Rights Principles. 18 | 19 | (b) Unless the Licensor and Licensee agree otherwise, any dispute, controversy, or claim arising out of or relating to (i) Section 1(a) regarding Human Rights Principles, including the breach of Section 1(a), termination of this License for breach of the Human Rights Principles, or invalidity of Section 1(a) or (ii) a determination of whether any Law is consistent or in conflict with Human Rights Principles pursuant to Section 2, below, shall be settled by arbitration in accordance with the Hague Rules on Business and Human Rights Arbitration (the “Rules”); provided, however, that Licensee may elect not to participate in such arbitration, in which event this License (and all rights licensed hereunder) shall end immediately. The number of arbitrators shall be one unless the Rules require otherwise. 20 | 21 | Unless both the Licensor and Licensee agree to the contrary: (1) All documents and information concerning the arbitration shall be public and may be disclosed by any party; (2) The repository referred to under Article 43 of the Rules shall make available to the public in a timely manner all documents concerning the arbitration which are communicated to it, including all submissions of the parties, all evidence admitted into the record of the proceedings, all transcripts or other recordings of hearings and all orders, decisions and awards of the arbitral tribunal, subject only to the arbitral tribunal's powers to take such measures as may be necessary to safeguard the integrity of the arbitral process pursuant to Articles 18, 33, 41 and 42 of the Rules; and (3) Article 26(6) of the Rules shall not apply. 22 | 23 | 2. Human Rights Laws. The Software shall not be used by any person or entity for any systems, activities, or other uses that violate any Human Rights Laws. “Human Rights Laws” means any applicable laws, regulations, or rules (collectively, “Laws”) that protect human, civil, labor, privacy, political, environmental, security, economic, due process, or similar rights; provided, however, that such Laws are consistent and not in conflict with Human Rights Principles (a dispute over the consistency or a conflict between Laws and Human Rights Principles shall be determined by arbitration as stated above). Where the Human Rights Laws of more than one jurisdiction are applicable or in conflict with respect to the use of the Software, the Human Rights Laws that are most protective of the individuals or groups harmed shall apply. 24 | 25 | 3. Indemnity. Licensee shall hold harmless and indemnify Licensor (and any other contributor) against all losses, damages, liabilities, deficiencies, claims, actions, judgments, settlements, interest, awards, penalties, fines, costs, or expenses of whatever kind, including Licensor’s reasonable attorneys’ fees, arising out of or relating to Licensee’s use of the Software in violation of Human Rights Laws or Human Rights Principles. 26 | 27 | * Failure to Comply. 
Any failure of Licensee to act according to the terms and conditions of this License is both a breach of the License and an infringement of the intellectual property rights of the Licensor (subject to exceptions under Laws, e.g., fair use). In the event of a breach or infringement, the terms and conditions of this License may be enforced by Licensor under the Laws of any jurisdiction to which Licensee is subject. Licensee also agrees that the Licensor may enforce the terms and conditions of this License against Licensee through specific performance (or similar remedy under Laws) to the extent permitted by Laws. For clarity, except in the event of a breach of this License, infringement, or as otherwise stated in this License, Licensor may not terminate this License with Licensee. 28 | 29 | * Enforceability and Interpretation. If any term or provision of this License is determined to be invalid, illegal, or unenforceable by a court of competent jurisdiction, then such invalidity, illegality, or unenforceability shall not affect any other term or provision of this License or invalidate or render unenforceable such term or provision in any other jurisdiction; provided, however, subject to a court modification pursuant to the immediately following sentence, if any term or provision of this License pertaining to Human Rights Laws or Human Rights Principles is deemed invalid, illegal, or unenforceable against Licensee by a court of competent jurisdiction, all rights in the Software granted to Licensee shall be deemed null and void as between Licensor and Licensee. Upon a determination that any term or provision is invalid, illegal, or unenforceable, to the extent permitted by Laws, the court may modify this License to affect the original purpose that the Software be used in compliance with Human Rights Principles and Human Rights Laws as closely as possible. The language in this License shall be interpreted as to its fair meaning and not strictly for or against any party. 30 | 31 | * Disclaimer. TO THE FULL EXTENT ALLOWED BY LAW, THIS SOFTWARE COMES “AS IS,” WITHOUT ANY WARRANTY, EXPRESS OR IMPLIED, AND LICENSOR AND ANY OTHER CONTRIBUTOR SHALL NOT BE LIABLE TO ANYONE FOR ANY DAMAGES OR OTHER LIABILITY ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THIS LICENSE, UNDER ANY KIND OF LEGAL CLAIM. 32 | 33 | This Hippocratic License is an Ethical Source license (https://ethicalsource.dev) and is offered for use by licensors and licensees at their own risk, on an “AS IS” basis, and with no warranties express or implied, to the maximum extent permitted by Laws. 
34 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import time\n", 11 | "import logging\n", 12 | "import yaml\n", 13 | "import ast\n", 14 | "import numpy as np\n", 15 | "from sklearn.metrics import accuracy_score, f1_score, confusion_matrix\n", 16 | "import torch\n", 17 | "\n", 18 | "from pprgo import utils, ppr\n", 19 | "from pprgo.pprgo import PPRGo\n", 20 | "from pprgo.train import train\n", 21 | "from pprgo.predict import predict\n", 22 | "from pprgo.dataset import PPRDataset" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Set up logging\n", 32 | "logger = logging.getLogger()\n", 33 | "logger.handlers = []\n", 34 | "ch = logging.StreamHandler()\n", 35 | "formatter = logging.Formatter(\n", 36 | " fmt='%(asctime)s (%(levelname)s): %(message)s',\n", 37 | " datefmt='%Y-%m-%d %H:%M:%S')\n", 38 | "ch.setFormatter(formatter)\n", 39 | "logger.addHandler(ch)\n", 40 | "logger.setLevel('INFO')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "# Download dataset" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "--2020-07-16 17:44:26-- https://ndownloader.figshare.com/files/23742119\n", 60 | "Resolving proxy.in.tum.de (proxy.in.tum.de)... 131.159.0.2\n", 61 | "Connecting to proxy.in.tum.de (proxy.in.tum.de)|131.159.0.2|:8080... connected.\n", 62 | "Proxy request sent, awaiting response... 302 Found\n", 63 | "Location: https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/23742119/reddit.npz [following]\n", 64 | "--2020-07-16 17:44:26-- https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/23742119/reddit.npz\n", 65 | "Connecting to proxy.in.tum.de (proxy.in.tum.de)|131.159.0.2|:8080... connected.\n", 66 | "Proxy request sent, awaiting response... 200 OK\n", 67 | "Length: 1480703860 (1,4G) [application/octet-stream]\n", 68 | "Saving to: ‘data/reddit.npz’\n", 69 | "\n", 70 | "data/reddit.npz 100%[===================>] 1,38G 73,0MB/s in 19s \n", 71 | "\n", 72 | "2020-07-16 17:44:45 (74,2 MB/s) - ‘data/reddit.npz’ saved [1480703860/1480703860]\n", 73 | "\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "!wget --show-progress -O data/reddit.npz https://ndownloader.figshare.com/files/23742119" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "# Load config" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "with open('config_demo.yaml', 'r') as c:\n", 95 | " config = yaml.safe_load(c)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 5, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# For strings that yaml doesn't parse (e.g. 
None)\n", 105 | "for key, val in config.items():\n", 106 | " if type(val) is str:\n", 107 | " try:\n", 108 | " config[key] = ast.literal_eval(val)\n", 109 | " except (ValueError, SyntaxError):\n", 110 | " pass" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 6, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "data_file = config['data_file'] # Path to the .npz data file\n", 120 | "split_seed = config['split_seed'] # Seed for splitting the dataset into train/val/test\n", 121 | "ntrain_div_classes = config['ntrain_div_classes'] # Number of training nodes divided by number of classes\n", 122 | "attr_normalization = config['attr_normalization'] # Attribute normalization. Not used in the paper\n", 123 | "\n", 124 | "alpha = config['alpha'] # PPR teleport probability\n", 125 | "eps = config['eps'] # Stopping threshold for ACL's ApproximatePR\n", 126 | "topk = config['topk'] # Number of PPR neighbors for each node\n", 127 | "ppr_normalization = config['ppr_normalization'] # Adjacency matrix normalization for weighting neighbors\n", 128 | "\n", 129 | "hidden_size = config['hidden_size'] # Size of the MLP's hidden layer\n", 130 | "nlayers = config['nlayers'] # Number of MLP layers\n", 131 | "weight_decay = config['weight_decay'] # Weight decay used for training the MLP\n", 132 | "dropout = config['dropout'] # Dropout used for training\n", 133 | "\n", 134 | "lr = config['lr'] # Learning rate\n", 135 | "max_epochs = config['max_epochs'] # Maximum number of epochs (exact number if no early stopping)\n", 136 | "batch_size = config['batch_size'] # Batch size for training\n", 137 | "batch_mult_val = config['batch_mult_val'] # Multiplier for validation batch size\n", 138 | "\n", 139 | "eval_step = config['eval_step'] # Accuracy is evaluated after every this number of steps\n", 140 | "run_val = config['run_val'] # Evaluate accuracy on validation set during training\n", 141 | "\n", 142 | "early_stop = config['early_stop'] # Use early stopping\n", 143 | "patience = config['patience'] # Patience for early stopping\n", 144 | "\n", 145 | "nprop_inference = config['nprop_inference'] # Number of propagation steps during inference\n", 146 | "inf_fraction = config['inf_fraction'] # Fraction of nodes for which local predictions are computed during inference" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "# Load the data" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 7, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "Runtime: 8.37s\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "start = time.time()\n", 171 | "(adj_matrix, attr_matrix, labels,\n", 172 | " train_idx, val_idx, test_idx) = utils.get_data(\n", 173 | " f\"{data_file}\",\n", 174 | " seed=split_seed,\n", 175 | " ntrain_div_classes=ntrain_div_classes,\n", 176 | " normalize_attr=attr_normalization\n", 177 | ")\n", 178 | "try:\n", 179 | " d = attr_matrix.n_columns\n", 180 | "except AttributeError:\n", 181 | " d = attr_matrix.shape[1]\n", 182 | "nc = labels.max() + 1\n", 183 | "time_loading = time.time() - start\n", 184 | "print(f\"Runtime: {time_loading:.2f}s\")" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "# Preprocessing: Calculate PPR scores" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 8, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "name": 
"stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "Runtime: 0.92s\n" 204 | ] 205 | } 206 | ], 207 | "source": [ 208 | "# compute the ppr vectors for train/val nodes using ACL's ApproximatePR\n", 209 | "start = time.time()\n", 210 | "topk_train = ppr.topk_ppr_matrix(adj_matrix, alpha, eps, train_idx, topk,\n", 211 | " normalization=ppr_normalization)\n", 212 | "train_set = PPRDataset(attr_matrix_all=attr_matrix, ppr_matrix=topk_train, indices=train_idx, labels_all=labels)\n", 213 | "if run_val:\n", 214 | " topk_val = ppr.topk_ppr_matrix(adj_matrix, alpha, eps, val_idx, topk,\n", 215 | " normalization=ppr_normalization)\n", 216 | " val_set = PPRDataset(attr_matrix_all=attr_matrix, ppr_matrix=topk_val, indices=val_idx, labels_all=labels)\n", 217 | "else:\n", 218 | " val_set = None\n", 219 | "time_preprocessing = time.time() - start\n", 220 | "print(f\"Runtime: {time_preprocessing:.2f}s\")" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "# Training: Set up model and train" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stderr", 237 | "output_type": "stream", 238 | "text": [ 239 | "2020-07-16 17:45:00 (INFO): Epoch 9, step 20: train 0.00916\n", 240 | "2020-07-16 17:45:00 (INFO): Epoch 19, step 40: train 0.00872\n", 241 | "2020-07-16 17:45:00 (INFO): Epoch 29, step 60: train 0.00842\n", 242 | "2020-07-16 17:45:00 (INFO): Epoch 39, step 80: train 0.00821\n", 243 | "2020-07-16 17:45:00 (INFO): Epoch 49, step 100: train 0.00803\n", 244 | "2020-07-16 17:45:00 (INFO): Epoch 59, step 120: train 0.00787\n", 245 | "2020-07-16 17:45:00 (INFO): Epoch 69, step 140: train 0.00773\n", 246 | "2020-07-16 17:45:00 (INFO): Epoch 79, step 160: train 0.00759\n", 247 | "2020-07-16 17:45:01 (INFO): Epoch 89, step 180: train 0.00746\n", 248 | "2020-07-16 17:45:01 (INFO): Epoch 99, step 200: train 0.00734\n", 249 | "2020-07-16 17:45:01 (INFO): Epoch 109, step 220: train 0.00723\n", 250 | "2020-07-16 17:45:01 (INFO): Epoch 119, step 240: train 0.00712\n", 251 | "2020-07-16 17:45:01 (INFO): Epoch 129, step 260: train 0.00701\n", 252 | "2020-07-16 17:45:01 (INFO): Epoch 139, step 280: train 0.00691\n", 253 | "2020-07-16 17:45:01 (INFO): Epoch 149, step 300: train 0.00682\n", 254 | "2020-07-16 17:45:01 (INFO): Epoch 159, step 320: train 0.00673\n", 255 | "2020-07-16 17:45:01 (INFO): Epoch 169, step 340: train 0.00664\n", 256 | "2020-07-16 17:45:01 (INFO): Epoch 179, step 360: train 0.00656\n", 257 | "2020-07-16 17:45:01 (INFO): Epoch 189, step 380: train 0.00648\n", 258 | "2020-07-16 17:45:01 (INFO): Epoch 199, step 400: train 0.00640\n", 259 | "2020-07-16 17:45:01 (INFO): Training done.\n" 260 | ] 261 | }, 262 | { 263 | "name": "stdout", 264 | "output_type": "stream", 265 | "text": [ 266 | "Runtime: 2.87s\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "start = time.time()\n", 272 | "model = PPRGo(d, nc, hidden_size, nlayers, dropout)\n", 273 | "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", 274 | "model.to(device)\n", 275 | "\n", 276 | "nepochs, _, _ = train(\n", 277 | " model=model, train_set=train_set, val_set=val_set,\n", 278 | " lr=lr, weight_decay=weight_decay,\n", 279 | " max_epochs=max_epochs, batch_size=batch_size, batch_mult_val=batch_mult_val,\n", 280 | " eval_step=eval_step, early_stop=early_stop, patience=patience)\n", 281 | "time_training = time.time() - start\n", 282 | "logging.info('Training 
done.')\n", 283 | "print(f\"Runtime: {time_training:.2f}s\")" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "metadata": {}, 289 | "source": [ 290 | "# Inference (val and test)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 10, 296 | "metadata": {}, 297 | "outputs": [ 298 | { 299 | "name": "stdout", 300 | "output_type": "stream", 301 | "text": [ 302 | "Runtime: 8.00s\n" 303 | ] 304 | } 305 | ], 306 | "source": [ 307 | "start = time.time()\n", 308 | "predictions, time_logits, time_propagation = predict(\n", 309 | " model=model, adj_matrix=adj_matrix, attr_matrix=attr_matrix, alpha=alpha,\n", 310 | " nprop=nprop_inference, inf_fraction=inf_fraction,\n", 311 | " ppr_normalization=ppr_normalization)\n", 312 | "time_inference = time.time() - start\n", 313 | "print(f\"Runtime: {time_inference:.2f}s\")" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "# Collect and print results" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 11, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "acc_train = 100 * accuracy_score(labels[train_idx], predictions[train_idx])\n", 330 | "acc_val = 100 * accuracy_score(labels[val_idx], predictions[val_idx])\n", 331 | "acc_test = 100 * accuracy_score(labels[test_idx], predictions[test_idx])\n", 332 | "f1_train = f1_score(labels[train_idx], predictions[train_idx], average='macro')\n", 333 | "f1_val = f1_score(labels[val_idx], predictions[val_idx], average='macro')\n", 334 | "f1_test = f1_score(labels[test_idx], predictions[test_idx], average='macro')\n", 335 | "\n", 336 | "gpu_memory = torch.cuda.max_memory_allocated()\n", 337 | "memory = utils.get_max_memory_bytes()\n", 338 | "\n", 339 | "time_total = time_preprocessing + time_training + time_inference" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 12, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "\n", 352 | "Accuracy: Train: 37.2%, val: 27.0%, test: 27.2%\n", 353 | "F1 score: Train: 0.188, val: 0.112, test: 0.112\n", 354 | "\n", 355 | "Runtime: Preprocessing: 0.92s, training: 2.87s, inference: 8.00s -> total: 11.78s\n", 356 | "Memory: Main: 5.38GB, GPU: 0.045GB\n", 357 | "\n" 358 | ] 359 | } 360 | ], 361 | "source": [ 362 | "print(f'''\n", 363 | "Accuracy: Train: {acc_train:.1f}%, val: {acc_val:.1f}%, test: {acc_test:.1f}%\n", 364 | "F1 score: Train: {f1_train:.3f}, val: {f1_val:.3f}, test: {f1_test:.3f}\n", 365 | "\n", 366 | "Runtime: Preprocessing: {time_preprocessing:.2f}s, training: {time_training:.2f}s, inference: {time_inference:.2f}s -> total: {time_total:.2f}s\n", 367 | "Memory: Main: {memory / 2**30:.2f}GB, GPU: {gpu_memory / 2**30:.3f}GB\n", 368 | "''')" 369 | ] 370 | } 371 | ], 372 | "metadata": { 373 | "kernelspec": { 374 | "display_name": "Python [conda env:pytorch]", 375 | "language": "python", 376 | "name": "conda-env-pytorch-py" 377 | }, 378 | "language_info": { 379 | "codemirror_mode": { 380 | "name": "ipython", 381 | "version": 3 382 | }, 383 | "file_extension": ".py", 384 | "mimetype": "text/x-python", 385 | "name": "python", 386 | "nbconvert_exporter": "python", 387 | "pygments_lexer": "ipython3", 388 | "version": "3.7.6" 389 | } 390 | }, 391 | "nbformat": 4, 392 | "nbformat_minor": 4 393 | } 394 | -------------------------------------------------------------------------------- /pprgo/sparsegraph.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Union, Tuple, Any 3 | import numpy as np 4 | import scipy.sparse as sp 5 | 6 | __all__ = ['SparseGraph'] 7 | 8 | sparse_graph_properties = [ 9 | 'adj_matrix', 'attr_matrix', 'edge_attr_matrix', 'labels', 10 | 'node_names', 'attr_names', 'edge_attr_names', 'class_names', 11 | 'metadata'] 12 | 13 | 14 | class SparseGraph: 15 | """Attributed labeled graph stored in sparse matrix form. 16 | 17 | All properties are immutable so users don't mess up the 18 | data format's assumptions (e.g. of edge_attr_matrix). 19 | Be careful when circumventing this and changing the internal matrices 20 | regardless (e.g. by exchanging the data array of a sparse matrix). 21 | 22 | Parameters 23 | ---------- 24 | adj_matrix 25 | Adjacency matrix in CSR format. Shape [num_nodes, num_nodes] 26 | attr_matrix 27 | Attribute matrix in CSR or numpy format. Shape [num_nodes, num_attr] 28 | edge_attr_matrix 29 | Edge attribute matrix in CSR or numpy format. Shape [num_edges, num_edge_attr] 30 | labels 31 | Array, where each entry represents respective node's label(s). Shape [num_nodes] 32 | Alternatively, CSR matrix with labels in one-hot format. Shape [num_nodes, num_classes] 33 | node_names 34 | Names of nodes (as strings). Shape [num_nodes] 35 | attr_names 36 | Names of the attributes (as strings). Shape [num_attr] 37 | edge_attr_names 38 | Names of the edge attributes (as strings). Shape [num_edge_attr] 39 | class_names 40 | Names of the class labels (as strings). Shape [num_classes] 41 | metadata 42 | Additional metadata such as text. 43 | 44 | """ 45 | def __init__( 46 | self, adj_matrix: sp.spmatrix, 47 | attr_matrix: Union[np.ndarray, sp.spmatrix] = None, 48 | edge_attr_matrix: Union[np.ndarray, sp.spmatrix] = None, 49 | labels: Union[np.ndarray, sp.spmatrix] = None, 50 | node_names: np.ndarray = None, 51 | attr_names: np.ndarray = None, 52 | edge_attr_names: np.ndarray = None, 53 | class_names: np.ndarray = None, 54 | metadata: Any = None): 55 | # Make sure that the dimensions of matrices / arrays all agree 56 | if sp.isspmatrix(adj_matrix): 57 | adj_matrix = adj_matrix.tocsr().astype(np.float32) 58 | else: 59 | raise ValueError("Adjacency matrix must be in sparse format (got {0} instead)." 60 | .format(type(adj_matrix))) 61 | 62 | if adj_matrix.shape[0] != adj_matrix.shape[1]: 63 | raise ValueError("Dimensions of the adjacency matrix don't agree.") 64 | 65 | if attr_matrix is not None: 66 | if sp.isspmatrix(attr_matrix): 67 | attr_matrix = attr_matrix.tocsr().astype(np.float32) 68 | elif isinstance(attr_matrix, np.ndarray): 69 | attr_matrix = attr_matrix.astype(np.float32) 70 | else: 71 | raise ValueError("Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)." 72 | .format(type(attr_matrix))) 73 | 74 | if attr_matrix.shape[0] != adj_matrix.shape[0]: 75 | raise ValueError("Dimensions of the adjacency and attribute matrices don't agree.") 76 | 77 | if edge_attr_matrix is not None: 78 | if sp.isspmatrix(edge_attr_matrix): 79 | edge_attr_matrix = edge_attr_matrix.tocsr().astype(np.float32) 80 | elif isinstance(edge_attr_matrix, np.ndarray): 81 | edge_attr_matrix = edge_attr_matrix.astype(np.float32) 82 | else: 83 | raise ValueError("Edge attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)." 
84 | .format(type(edge_attr_matrix))) 85 | 86 | if edge_attr_matrix.shape[0] != adj_matrix.count_nonzero(): 87 | raise ValueError("Number of edges and dimension of the edge attribute matrix don't agree.") 88 | 89 | if labels is not None: 90 | if labels.shape[0] != adj_matrix.shape[0]: 91 | raise ValueError("Dimensions of the adjacency matrix and the label vector don't agree.") 92 | 93 | if node_names is not None: 94 | if len(node_names) != adj_matrix.shape[0]: 95 | raise ValueError("Dimensions of the adjacency matrix and the node names don't agree.") 96 | 97 | if attr_names is not None: 98 | if len(attr_names) != attr_matrix.shape[1]: 99 | raise ValueError("Dimensions of the attribute matrix and the attribute names don't agree.") 100 | 101 | if edge_attr_names is not None: 102 | if len(edge_attr_names) != edge_attr_matrix.shape[1]: 103 | raise ValueError("Dimensions of the edge attribute matrix and the edge attribute names don't agree.") 104 | 105 | # TODO: check that class names matches the labels vector (including the multi-label case) 106 | 107 | self._adj_matrix = adj_matrix 108 | self._attr_matrix = attr_matrix 109 | self._edge_attr_matrix = edge_attr_matrix 110 | self._labels = labels 111 | self._node_names = node_names 112 | self._attr_names = attr_names 113 | self._edge_attr_names = edge_attr_names 114 | self._class_names = class_names 115 | self._metadata = metadata 116 | 117 | def num_nodes(self) -> int: 118 | """Get the number of nodes in the graph. 119 | """ 120 | return self.adj_matrix.shape[0] 121 | 122 | def num_edges(self, warn: bool = True) -> int: 123 | """Get the number of edges in the graph. 124 | 125 | For undirected graphs, (i, j) and (j, i) are counted as _two_ edges. 126 | 127 | """ 128 | if warn and not self.is_directed(): 129 | warnings.warn("num_edges always returns the number of directed edges now.", FutureWarning) 130 | return self.adj_matrix.nnz 131 | 132 | def get_neighbors(self, idx: int) -> np.ndarray: 133 | """Get the indices of neighbors of a given node. 134 | 135 | Parameters 136 | ---------- 137 | idx 138 | Index of the node whose neighbors are of interest. 139 | 140 | """ 141 | return self.adj_matrix[idx].indices 142 | 143 | def get_edgeid_to_idx_array(self) -> np.ndarray: 144 | """Return a Numpy Array that maps edgeids to the indices in the adjacency matrix. 145 | 146 | Returns 147 | ------- 148 | np.ndarray 149 | The i'th entry contains the x- and y-coordinates of edge i in the adjacency matrix. 150 | Shape [num_edges, 2] 151 | 152 | """ 153 | return np.transpose(self.adj_matrix.nonzero()) 154 | 155 | def get_idx_to_edgeid_matrix(self) -> sp.csr_matrix: 156 | """Return a sparse matrix that maps indices in the adjacency matrix to edgeids. 157 | 158 | Caution: This contains one explicit 0 (zero stored as a nonzero), 159 | which is the index of the first edge. 160 | 161 | Returns 162 | ------- 163 | sp.csr_matrix 164 | The entry [x, y] contains the edgeid of the corresponding edge (or 0 for non-edges). 165 | Shape [num_nodes, num_nodes] 166 | 167 | """ 168 | return sp.csr_matrix( 169 | (np.arange(self.adj_matrix.nnz), self.adj_matrix.indices, self.adj_matrix.indptr), 170 | shape=self.adj_matrix.shape) 171 | 172 | def is_directed(self) -> bool: 173 | """Check if the graph is directed (adjacency matrix is not symmetric). 174 | """ 175 | return (self.adj_matrix != self.adj_matrix.T).sum() != 0 176 | 177 | def to_undirected(self) -> 'SparseGraph': 178 | """Convert to an undirected graph (make adjacency matrix symmetric). 
179 | """ 180 | # TODO: add warnings / logging 181 | 182 | idx = self.get_edgeid_to_idx_array().T 183 | ridx = np.ravel_multi_index(idx, self.adj_matrix.shape) 184 | ridx_rev = np.ravel_multi_index(idx[::-1], self.adj_matrix.shape) 185 | 186 | # Get duplicate edges (self-loops and opposing edges) 187 | dup_ridx = ridx[np.isin(ridx, ridx_rev)] 188 | dup_idx = np.unravel_index(dup_ridx, self.adj_matrix.shape) 189 | 190 | # Check if the adjacency matrix weights are symmetric (if nonzero) 191 | if len(dup_ridx) > 0 and not np.allclose(self.adj_matrix[dup_idx], self.adj_matrix[dup_idx[::-1]]): 192 | raise ValueError("Adjacency matrix weights of opposing edges differ.") 193 | 194 | # Create symmetric matrix 195 | new_adj_matrix = self.adj_matrix + self.adj_matrix.T 196 | if len(dup_ridx) > 0: 197 | new_adj_matrix[dup_idx] = (new_adj_matrix[dup_idx] - self.adj_matrix[dup_idx]).A1 198 | 199 | if self.edge_attr_matrix is not None: 200 | 201 | # Check if edge attributes are symmetric 202 | edgeid_mat = self.get_idx_to_edgeid_matrix() 203 | if len(dup_ridx) > 0: 204 | dup_edgeids = edgeid_mat[dup_idx].A1 205 | dup_rev_edgeids = edgeid_mat[dup_idx[::-1]].A1 206 | if not np.allclose(self.edge_attr_matrix[dup_edgeids], self.edge_attr_matrix[dup_rev_edgeids]): 207 | raise ValueError("Edge attributes of opposing edges differ.") 208 | 209 | # Adjust edge attributes to new adjacency matrix 210 | edgeid_mat.data += 1 # Add 1 so we don't lose the explicit 0 and change the sparsity structure 211 | new_edgeid_mat = edgeid_mat + edgeid_mat.T 212 | if len(dup_ridx) > 0: 213 | new_edgeid_mat[dup_idx] = (new_edgeid_mat[dup_idx] - edgeid_mat[dup_idx]).A1 214 | new_idx = new_adj_matrix.nonzero() 215 | edgeids_perm = new_edgeid_mat[new_idx].A1 - 1 216 | self._edge_attr_matrix = self.edge_attr_matrix[edgeids_perm] 217 | 218 | self._adj_matrix = new_adj_matrix 219 | return self 220 | 221 | def is_weighted(self) -> bool: 222 | """Check if the graph is weighted (edge weights other than 1). 223 | """ 224 | return np.any(np.unique(self.adj_matrix[self.adj_matrix.nonzero()].A1) != 1) 225 | 226 | def to_unweighted(self) -> 'SparseGraph': 227 | """Convert to an unweighted graph (set all edge weights to 1). 228 | """ 229 | # TODO: add warnings / logging 230 | self._adj_matrix.data = np.ones_like(self._adj_matrix.data) 231 | return self 232 | 233 | def is_connected(self) -> bool: 234 | """Check if the graph is connected. 235 | """ 236 | return sp.csgraph.connected_components(self.adj_matrix, return_labels=False) == 1 237 | 238 | def has_self_loops(self) -> bool: 239 | """Check if the graph has self-loops. 240 | """ 241 | return not np.allclose(self.adj_matrix.diagonal(), 0) 242 | 243 | def __repr__(self) -> str: 244 | props = [] 245 | for prop_name in sparse_graph_properties: 246 | prop = getattr(self, prop_name) 247 | if prop is not None: 248 | if prop_name == 'metadata': 249 | props.append(prop_name) 250 | else: 251 | shape_string = 'x'.join([str(x) for x in prop.shape]) 252 | props.append("{} ({})".format(prop_name, shape_string)) 253 | dir_string = 'Directed' if self.is_directed() else 'Undirected' 254 | weight_string = 'weighted' if self.is_weighted() else 'unweighted' 255 | conn_string = 'connected' if self.is_connected() else 'disconnected' 256 | loop_string = 'has self-loops' if self.has_self_loops() else 'no self-loops' 257 | return ("<{}, {} and {} SparseGraph with {} edges ({}). 
Data: {}>" 258 | .format(dir_string, weight_string, conn_string, 259 | self.num_edges(warn=False), loop_string, 260 | ', '.join(props))) 261 | 262 | # Quality of life (shortcuts) 263 | def standardize( 264 | self, make_unweighted: bool = True, 265 | make_undirected: bool = True, 266 | no_self_loops: bool = True, 267 | select_lcc: bool = True 268 | ) -> 'SparseGraph': 269 | """Perform common preprocessing steps: remove self-loops, make unweighted/undirected, select LCC. 270 | 271 | All changes are done inplace. 272 | 273 | Parameters 274 | ---------- 275 | make_unweighted 276 | Whether to set all edge weights to 1. 277 | make_undirected 278 | Whether to make the adjacency matrix symmetric. Can only be used if make_unweighted is True. 279 | no_self_loops 280 | Whether to remove self loops. 281 | select_lcc 282 | Whether to select the largest connected component of the graph. 283 | 284 | """ 285 | # TODO: add warnings / logging 286 | G = self 287 | if make_unweighted and G.is_weighted(): 288 | G = G.to_unweighted() 289 | if make_undirected and G.is_directed(): 290 | G = G.to_undirected() 291 | if no_self_loops and G.has_self_loops(): 292 | G = remove_self_loops(G) 293 | if select_lcc and not G.is_connected(): 294 | G = largest_connected_components(G, 1) 295 | self._adopt_graph(G) 296 | return G 297 | 298 | def unpack(self) -> Tuple[sp.csr_matrix, 299 | Union[np.ndarray, sp.csr_matrix], 300 | Union[np.ndarray, sp.csr_matrix], 301 | Union[np.ndarray, sp.csr_matrix]]: 302 | """Return the (A, X, E, z) quadruplet. 303 | """ 304 | return self._adj_matrix, self._attr_matrix, self._edge_attr_matrix, self._labels 305 | 306 | def _adopt_graph(self, graph: 'SparseGraph'): 307 | """Copy all properties from the given graph to this graph. 308 | """ 309 | for prop in sparse_graph_properties: 310 | setattr(self, '_{}'.format(prop), getattr(graph, prop)) 311 | 312 | def to_flat_dict(self) -> Dict[str, Any]: 313 | """Return flat dictionary containing all SparseGraph properties. 314 | """ 315 | data_dict = {} 316 | for key in sparse_graph_properties: 317 | val = getattr(self, key) 318 | if sp.isspmatrix(val): 319 | data_dict['{}.data'.format(key)] = val.data 320 | data_dict['{}.indices'.format(key)] = val.indices 321 | data_dict['{}.indptr'.format(key)] = val.indptr 322 | data_dict['{}.shape'.format(key)] = val.shape 323 | else: 324 | data_dict[key] = val 325 | return data_dict 326 | 327 | @staticmethod 328 | def from_flat_dict(data_dict: Dict[str, Any]) -> 'SparseGraph': 329 | """Initialize SparseGraph from a flat dictionary. 330 | """ 331 | init_dict = {} 332 | del_entries = [] 333 | 334 | # Construct sparse matrices 335 | for key in data_dict.keys(): 336 | if key.endswith('_data') or key.endswith('.data'): 337 | if key.endswith('_data'): 338 | sep = '_' 339 | warnings.warn( 340 | "The separator used for sparse matrices during export (for .npz files) " 341 | "is now '.' instead of '_'. Please update (re-save) your stored graphs.", 342 | DeprecationWarning, stacklevel=3) 343 | else: 344 | sep = '.' 345 | matrix_name = key[:-5] 346 | mat_data = key 347 | mat_indices = '{}{}indices'.format(matrix_name, sep) 348 | mat_indptr = '{}{}indptr'.format(matrix_name, sep) 349 | mat_shape = '{}{}shape'.format(matrix_name, sep) 350 | if matrix_name == 'adj' or matrix_name == 'attr': 351 | warnings.warn( 352 | "Matrices are exported (for .npz files) with full names now. 
" 353 | "Please update (re-save) your stored graphs.", 354 | DeprecationWarning, stacklevel=3) 355 | matrix_name += '_matrix' 356 | init_dict[matrix_name] = sp.csr_matrix( 357 | (data_dict[mat_data], 358 | data_dict[mat_indices], 359 | data_dict[mat_indptr]), 360 | shape=data_dict[mat_shape]) 361 | del_entries.extend([mat_data, mat_indices, mat_indptr, mat_shape]) 362 | 363 | # Delete sparse matrix entries 364 | for del_entry in del_entries: 365 | del data_dict[del_entry] 366 | 367 | # Load everything else 368 | for key, val in data_dict.items(): 369 | if ((val is not None) and (None not in val)): 370 | init_dict[key] = val 371 | 372 | # Check if the dictionary contains only entries in sparse_graph_properties 373 | unknown_keys = [key for key in init_dict.keys() if key not in sparse_graph_properties] 374 | if len(unknown_keys) > 0: 375 | raise ValueError("Input dictionary contains keys that are not SparseGraph properties ({})." 376 | .format(unknown_keys)) 377 | 378 | return SparseGraph(**init_dict) 379 | 380 | @property 381 | def adj_matrix(self) -> sp.csr_matrix: 382 | return self._adj_matrix 383 | 384 | @property 385 | def attr_matrix(self) -> Union[np.ndarray, sp.csr_matrix]: 386 | return self._attr_matrix 387 | 388 | @property 389 | def edge_attr_matrix(self) -> Union[np.ndarray, sp.csr_matrix]: 390 | return self._edge_attr_matrix 391 | 392 | @property 393 | def labels(self) -> Union[np.ndarray, sp.csr_matrix]: 394 | return self._labels 395 | 396 | @property 397 | def node_names(self) -> np.ndarray: 398 | return self._node_names 399 | 400 | @property 401 | def attr_names(self) -> np.ndarray: 402 | return self._attr_names 403 | 404 | @property 405 | def edge_attr_names(self) -> np.ndarray: 406 | return self._edge_attr_names 407 | 408 | @property 409 | def class_names(self) -> np.ndarray: 410 | return self._class_names 411 | 412 | @property 413 | def metadata(self) -> Any: 414 | return self._metadata 415 | 416 | 417 | def create_subgraph( 418 | sparse_graph: 'SparseGraph', 419 | _sentinel: None = None, 420 | nodes_to_remove: np.ndarray = None, 421 | nodes_to_keep: np.ndarray = None 422 | ) -> 'SparseGraph': 423 | """Create a graph with the specified subset of nodes. 424 | 425 | Exactly one of (nodes_to_remove, nodes_to_keep) should be provided, while the other stays None. 426 | Note that to avoid confusion, it is required to pass node indices as named arguments to this function. 427 | 428 | The subgraph partially points to the old graph's data. 429 | 430 | Parameters 431 | ---------- 432 | sparse_graph 433 | Input graph. 434 | _sentinel 435 | Internal, to prevent passing positional arguments. Do not use. 436 | nodes_to_remove 437 | Indices of nodes that have to removed. 438 | nodes_to_keep 439 | Indices of nodes that have to be kept. 440 | 441 | Returns 442 | ------- 443 | SparseGraph 444 | Graph with specified nodes removed. 445 | 446 | """ 447 | # Check that arguments are passed correctly 448 | if _sentinel is not None: 449 | raise ValueError("Only call `create_subgraph` with named arguments'," 450 | " (nodes_to_remove=...) 
or (nodes_to_keep=...).") 451 | if nodes_to_remove is None and nodes_to_keep is None: 452 | raise ValueError("Either nodes_to_remove or nodes_to_keep must be provided.") 453 | elif nodes_to_remove is not None and nodes_to_keep is not None: 454 | raise ValueError("Only one of nodes_to_remove or nodes_to_keep must be provided.") 455 | elif nodes_to_remove is not None: 456 | nodes_to_keep = [i for i in range(sparse_graph.num_nodes()) if i not in nodes_to_remove] 457 | elif nodes_to_keep is not None: 458 | nodes_to_keep = sorted(nodes_to_keep) 459 | else: 460 | raise RuntimeError("This should never happen.") 461 | 462 | adj_matrix = sparse_graph.adj_matrix[nodes_to_keep][:, nodes_to_keep] 463 | if sparse_graph.attr_matrix is None: 464 | attr_matrix = None 465 | else: 466 | attr_matrix = sparse_graph.attr_matrix[nodes_to_keep] 467 | if sparse_graph.edge_attr_matrix is None: 468 | edge_attr_matrix = None 469 | else: 470 | old_idx = sparse_graph.get_edgeid_to_idx_array() 471 | keep_edge_idx = np.where(np.all(np.isin(old_idx, nodes_to_keep), axis=1))[0] 472 | edge_attr_matrix = sparse_graph.edge_attr_matrix[keep_edge_idx] 473 | if sparse_graph.labels is None: 474 | labels = None 475 | else: 476 | labels = sparse_graph.labels[nodes_to_keep] 477 | if sparse_graph.node_names is None: 478 | node_names = None 479 | else: 480 | node_names = sparse_graph.node_names[nodes_to_keep] 481 | # TODO: add warnings / logging 482 | # print("Resulting subgraph with N = {0}, E = {1}" 483 | # .format(sparse_graph.num_nodes(), sparse_graph.num_edges())) 484 | return SparseGraph( 485 | adj_matrix, attr_matrix, edge_attr_matrix, labels, node_names, 486 | sparse_graph.attr_names, sparse_graph.edge_attr_names, 487 | sparse_graph.class_names, sparse_graph.metadata) 488 | 489 | 490 | def remove_self_loops(sparse_graph: 'SparseGraph') -> 'SparseGraph': 491 | """Remove self loops (diagonal entries in the adjacency matrix). 492 | 493 | Changes are returned in a partially new SparseGraph. 494 | 495 | """ 496 | num_self_loops = (~np.isclose(sparse_graph.adj_matrix.diagonal(), 0)).sum() 497 | if num_self_loops > 0: 498 | adj_matrix = sparse_graph.adj_matrix.copy().tolil() 499 | adj_matrix.setdiag(0) 500 | adj_matrix = adj_matrix.tocsr() 501 | if sparse_graph.edge_attr_matrix is None: 502 | edge_attr_matrix = None 503 | else: 504 | old_idx = sparse_graph.get_edgeid_to_idx_array() 505 | keep_edge_idx = np.where((old_idx[:, 0] - old_idx[:, 1]) != 0)[0] 506 | edge_attr_matrix = sparse_graph._edge_attr_matrix[keep_edge_idx] 507 | warnings.warn("{0} self loops removed".format(num_self_loops)) 508 | return SparseGraph( 509 | adj_matrix, sparse_graph.attr_matrix, edge_attr_matrix, 510 | sparse_graph.labels, sparse_graph.node_names, 511 | sparse_graph.attr_names, sparse_graph.edge_attr_names, 512 | sparse_graph.class_names, sparse_graph.metadata) 513 | else: 514 | return sparse_graph 515 | 516 | 517 | def largest_connected_components(sparse_graph: 'SparseGraph', n_components: int = 1) -> 'SparseGraph': 518 | """Select the largest connected components in the graph. 519 | 520 | Changes are returned in a partially new SparseGraph. 521 | 522 | Parameters 523 | ---------- 524 | sparse_graph 525 | Input graph. 526 | n_components 527 | Number of largest connected components to keep. 528 | 529 | Returns 530 | ------- 531 | SparseGraph 532 | Subgraph of the input graph where only the nodes in largest n_components are kept. 
533 | 534 | """ 535 | _, component_indices = sp.csgraph.connected_components(sparse_graph.adj_matrix) 536 | component_sizes = np.bincount(component_indices) 537 | components_to_keep = np.argsort(component_sizes)[::-1][:n_components] # reverse order to sort descending 538 | nodes_to_keep = [ 539 | idx for (idx, component) in enumerate(component_indices) if component in components_to_keep 540 | ] 541 | # TODO: add warnings / logging 542 | # print("Selecting {0} largest connected components".format(n_components)) 543 | return create_subgraph(sparse_graph, nodes_to_keep=nodes_to_keep) 544 | 545 | 546 | def load_from_npz(file_name: str) -> SparseGraph: 547 | """Load a SparseGraph from a Numpy binary file. 548 | 549 | Parameters 550 | ---------- 551 | file_name 552 | Name of the file to load. 553 | 554 | Returns 555 | ------- 556 | SparseGraph 557 | Graph in sparse matrix format. 558 | 559 | """ 560 | with np.load(file_name, allow_pickle=True) as loader: 561 | loader = dict(loader) 562 | if 'type' in loader: 563 | del loader['type'] 564 | dataset = SparseGraph.from_flat_dict(loader) 565 | return dataset 566 | --------------------------------------------------------------------------------
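The SparseGraph class above is the container behind the .npz graph files used by this repository. A minimal usage sketch follows; the file path and variable names are illustrative only (cora_full.npz is one of the files shipped in data/):

from pprgo.sparsegraph import load_from_npz

# Load a graph stored in the flat .npz layout (illustrative path).
graph = load_from_npz('data/cora_full.npz')

# Standard preprocessing, done in place: unweighted, undirected,
# no self-loops, restricted to the largest connected component.
graph = graph.standardize(select_lcc=True)

# Unpack the (A, X, E, z) quadruplet: CSR adjacency, node attributes,
# edge attributes (may be None), and labels.
adj_matrix, attr_matrix, edge_attr_matrix, labels = graph.unpack()

print(graph)                                     # summary via __repr__
print(graph.num_nodes(), graph.num_edges(warn=False))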
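create_subgraph deliberately rejects positional node arguments, so the subgraph selection is always explicit. A short sketch under the same assumptions (the node range below is made up for illustration):

import numpy as np

from pprgo.sparsegraph import create_subgraph, largest_connected_components, load_from_npz

graph = load_from_npz('data/cora_full.npz')      # illustrative path

# Keep only the largest connected component.
lcc = largest_connected_components(graph, n_components=1)

# Node indices must be passed by keyword: nodes_to_keep or nodes_to_remove.
sub = create_subgraph(lcc, nodes_to_keep=np.arange(1000))
print(sub.num_nodes(), sub.num_edges(warn=False))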
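to_flat_dict and from_flat_dict define the flat key layout ('<name>.data', '<name>.indices', '<name>.indptr', '<name>.shape' per sparse matrix) that load_from_npz expects inside an .npz file. The round trip can be checked in memory; writing the dict back to disk with np.savez is an assumption about how such a file could be produced, not a helper this module provides:

import numpy as np

from pprgo.sparsegraph import SparseGraph, load_from_npz

graph = load_from_npz('data/cora_full.npz')      # illustrative path

flat = graph.to_flat_dict()                      # sparse matrices split into .data/.indices/.indptr/.shape
restored = SparseGraph.from_flat_dict(flat)      # reconstructs an equivalent SparseGraph
assert restored.num_nodes() == graph.num_nodes()

# Persist the non-None entries (assumption: plain numpy, no extra metadata).
np.savez('cora_full_copy.npz', **{k: v for k, v in flat.items() if v is not None})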