├── dataset
│   └── README.md
├── subgraph
│   ├── README.md
│   └── .ipynb_checkpoints
│       └── README-checkpoint.md
├── __pycache__
│   ├── DGI.cpython-37.pyc
│   ├── metric.cpython-37.pyc
│   ├── model.cpython-36.pyc
│   ├── model.cpython-37.pyc
│   ├── utils.cpython-37.pyc
│   ├── capsule.cpython-37.pyc
│   ├── dataset.cpython-37.pyc
│   ├── sugbcon.cpython-37.pyc
│   ├── utils_mp.cpython-37.pyc
│   └── evaluation.cpython-37.pyc
├── README.md
├── model.py
├── subgcon.py
├── subgraph.py
├── evaluation.py
├── train.py
└── utils_mp.py
/dataset/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /subgraph/README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /subgraph/.ipynb_checkpoints/README-checkpoint.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__pycache__/DGI.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/DGI.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/capsule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/capsule.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/sugbcon.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/sugbcon.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils_mp.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/utils_mp.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzjiao/Subg-Con/HEAD/__pycache__/evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Subg-Con 2 | Sub-graph Contrast for Scalable Self-Supervised Graph Representation Learning (Jiao *et al.*, ICDM 2020): [https://arxiv.org/abs/2009.10273](https://arxiv.org/abs/2009.10273) 3 | 4 | 5 | ## Overview 6 | Here we provide an implementation of Subg-Con in PyTorch and Torch Geometric. The repository is organised as follows: 7 | - `subgcon.py` is the implementation of the Subg-Con pipeline; 8 | - `subgraph.py` is the implementation of the subgraph extractor; 9 | - `model.py` is the implementation of components for Subg-Con, including a GNN layer, a pooling layer, and a scoring function; 10 | - `utils_mp.py` contains the necessary processing subroutines; 11 | - `dataset/` will contain the automatically downloaded datasets; 12 | - `subgraph/` will contain the processed subgraphs. 13 | 14 | Finally, `train.py` puts all of the above together and may be used to execute a full training run. 15 | 16 | 17 | ## Dependencies 18 | - Python 3.7.3 19 | - [PyTorch](https://github.com/pytorch/pytorch) 1.5.1 20 | - [torch_geometric](https://github.com/rusty1s/pytorch_geometric) 1.4.3 21 | - scikit-learn 0.23.2 22 | - scipy 1.5.2 23 | - cytoolz 0.10.0 24 | 25 | 26 | ## Reference 27 | If you make use of Subg-Con in your research, please cite the following in your manuscript: 28 | 29 | ``` 30 | @article{jiao2020sub, 31 | title={Sub-graph Contrast for Scalable Self-Supervised Graph Representation Learning}, 32 | author={Jiao, Yizhu and Xiong, Yun and Zhang, Jiawei and Zhang, Yao and Zhang, Tianqi and Zhu, Yangyong}, 33 | journal={arXiv preprint arXiv:2009.10273}, 34 | year={2020} 35 | } 36 | ``` 37 | 38 | 39 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Linear 5 | from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, global_add_pool, SAGPooling 6 | 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, in_channels, hidden_channels): 10 | super(Encoder, self).__init__() 11 | self.hidden_channels = hidden_channels 12 | self.conv = GCNConv(in_channels, self.hidden_channels) 13 | self.prelu = nn.PReLU(self.hidden_channels) 14 | 15 | def forward(self, x, edge_index): 16 | x1 = self.conv(x, edge_index) 17 | x1 = self.prelu(x1) 18 | return x1 19 | 20 | 21 | class Pool(nn.Module): 22 | def __init__(self, in_channels, ratio=1.0): 23 | super(Pool, self).__init__() 24 | self.sag_pool = SAGPooling(in_channels, ratio) 25 | self.lin1 = torch.nn.Linear(in_channels * 2, in_channels) 26 | 27 | def forward(self, x, edge, batch, type='mean_pool'): 28 | if type == 'mean_pool': 29 | return global_mean_pool(x, batch) 30 | elif type == 'max_pool': 31 | return global_max_pool(x, batch) 32 | elif type == 'sum_pool': 33 | return global_add_pool(x, batch) 34 | elif type == 'sag_pool': 35 | x1, _, _, batch, _, _ = self.sag_pool(x,
edge, batch=batch) 36 | return global_mean_pool(x1, batch) 37 | 38 | 39 | 40 | 41 | class Scorer(nn.Module): 42 | def __init__(self, hidden_size): 43 | super(Scorer, self).__init__() 44 | self.weight = nn.Parameter(torch.Tensor(hidden_size, hidden_size)) 45 | 46 | def forward(self, input1, input2): 47 | output = torch.sigmoid(torch.sum(input1 * torch.matmul(input2, self.weight), dim = -1)) 48 | return output 49 | 50 | 51 | -------------------------------------------------------------------------------- /subgcon.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import Parameter 5 | import torch.nn.functional as F 6 | from torch_geometric.nn.inits import reset 7 | from sklearn.linear_model import LogisticRegression 8 | 9 | 10 | EPS = 1e-15 11 | 12 | class SugbCon(torch.nn.Module): 13 | 14 | def __init__(self, hidden_channels, encoder, pool, scorer): 15 | super(SugbCon, self).__init__() 16 | self.encoder = encoder 17 | self.hidden_channels = hidden_channels 18 | self.pool = pool 19 | self.scorer = scorer 20 | self.marginloss = nn.MarginRankingLoss(0.5) 21 | self.sigmoid = nn.Sigmoid() 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | reset(self.scorer) 26 | reset(self.encoder) 27 | reset(self.pool) 28 | 29 | def forward(self, x, edge_index, batch=None, index=None): 30 | r""" Return node and subgraph representations of each node before and after being shuffled """ 31 | hidden = self.encoder(x, edge_index) 32 | if index is None: 33 | return hidden 34 | 35 | z = hidden[index] 36 | summary = self.pool(hidden, edge_index, batch) 37 | return z, summary 38 | 39 | 40 | def loss(self, hidden1, summary1): 41 | r"""Computes the margin objective.""" 42 | 43 | shuf_index = torch.randperm(summary1.size(0)) 44 | 45 | hidden2 = hidden1[shuf_index] 46 | summary2 = summary1[shuf_index] 47 | 48 | logits_aa = torch.sigmoid(torch.sum(hidden1 * summary1, dim = -1)) 49 | logits_bb = torch.sigmoid(torch.sum(hidden2 * summary2, dim = -1)) 50 | logits_ab = torch.sigmoid(torch.sum(hidden1 * summary2, dim = -1)) 51 | logits_ba = torch.sigmoid(torch.sum(hidden2 * summary1, dim = -1)) 52 | 53 | TotalLoss = 0.0 54 | ones = torch.ones(logits_aa.size(0)).cuda(logits_aa.device) 55 | TotalLoss += self.marginloss(logits_aa, logits_ba, ones) 56 | TotalLoss += self.marginloss(logits_bb, logits_ab, ones) 57 | 58 | return TotalLoss 59 | 60 | 61 | def test(self, train_z, train_y, val_z, val_y, test_z, test_y, solver='lbfgs', 62 | multi_class='auto', *args, **kwargs): 63 | r"""Evaluates latent space quality via a logistic regression downstream task.""" 64 | clf = LogisticRegression(solver=solver, multi_class=multi_class, *args, 65 | **kwargs).fit(train_z.detach().cpu().numpy(), 66 | train_y.detach().cpu().numpy()) 67 | val_acc = clf.score(val_z.detach().cpu().numpy(), val_y.detach().cpu().numpy()) 68 | test_acc = clf.score(test_z.detach().cpu().numpy(), test_y.detach().cpu().numpy()) 69 | return val_acc, test_acc -------------------------------------------------------------------------------- /subgraph.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import Parameter 5 | import torch.nn.functional as F 6 | from torch_geometric.nn.inits import reset 7 | from sklearn.linear_model import LogisticRegression 8 | 9 | 10 | EPS = 1e-15 11 | 12 | class SugbCon(torch.nn.Module): 13 | 14 | def 
__init__(self, hidden_channels, encoder, pool, scorer): 15 | super(SugbCon, self).__init__() 16 | self.encoder = encoder 17 | self.hidden_channels = hidden_channels 18 | self.pool = pool 19 | self.scorer = scorer 20 | self.marginloss = nn.MarginRankingLoss(0.5) 21 | self.sigmoid = nn.Sigmoid() 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | reset(self.scorer) 26 | reset(self.encoder) 27 | reset(self.pool) 28 | 29 | def forward(self, x, edge_index, batch=None, index=None): 30 | r""" Return node and subgraph representations of each node before and after being shuffled """ 31 | hidden = self.encoder(x, edge_index) 32 | if index is None: 33 | return hidden 34 | 35 | z = hidden[index] 36 | summary = self.pool(hidden, edge_index, batch) 37 | return z, summary 38 | 39 | 40 | def loss(self, hidden1, summary1): 41 | r"""Computes the margin objective.""" 42 | 43 | shuf_index = torch.randperm(summary1.size(0)) 44 | 45 | hidden2 = hidden1[shuf_index] 46 | summary2 = summary1[shuf_index] 47 | 48 | logits_aa = torch.sigmoid(torch.sum(hidden1 * summary1, dim = -1)) 49 | logits_bb = torch.sigmoid(torch.sum(hidden2 * summary2, dim = -1)) 50 | logits_ab = torch.sigmoid(torch.sum(hidden1 * summary2, dim = -1)) 51 | logits_ba = torch.sigmoid(torch.sum(hidden2 * summary1, dim = -1)) 52 | 53 | TotalLoss = 0.0 54 | ones = torch.ones(logits_aa.size(0)).cuda(logits_aa.device) 55 | TotalLoss += self.marginloss(logits_aa, logits_ba, ones) 56 | TotalLoss += self.marginloss(logits_bb, logits_ab, ones) 57 | 58 | return TotalLoss 59 | 60 | 61 | def test(self, train_z, train_y, val_z, val_y, test_z, test_y, solver='lbfgs', 62 | multi_class='auto', *args, **kwargs): 63 | r"""Evaluates latent space quality via a logistic regression downstream task.""" 64 | clf = LogisticRegression(solver=solver, multi_class=multi_class, *args, 65 | **kwargs).fit(train_z.detach().cpu().numpy(), 66 | train_y.detach().cpu().numpy()) 67 | val_acc = clf.score(val_z.detach().cpu().numpy(), val_y.detach().cpu().numpy()) 68 | test_acc = clf.score(test_z.detach().cpu().numpy(), test_y.detach().cpu().numpy()) 69 | return val_acc, test_acc -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import metrics 3 | import sklearn.metrics as Metrics 4 | import networkx as nx 5 | import scipy.sparse as sp 6 | from collections import defaultdict 7 | 8 | def process_set(sets): 9 | comms = [] 10 | labels = set(sets) 11 | for label in labels: 12 | comm = set(np.where(sets == label)[0]) 13 | comms.append(comm) 14 | return comms 15 | 16 | def NMI_helper(found_sets, GT_sets): 17 | NMI_result = metrics.normalized_mutual_info_score(GT_sets,found_sets) 18 | return NMI_result 19 | 20 | def F_score_helper(GT, found, common_elements): 21 | len_common = len(common_elements) 22 | precision = float(len_common)/len(found) 23 | if precision == 0: 24 | return 0 25 | 26 | recall = float(len_common)/len(GT) 27 | if recall == 0: 28 | return 0 29 | return (2*precision*recall)/(precision+recall) 30 | 31 | 32 | def cal_F_score_helper(found_sets, GT_sets): 33 | d1 = {} #best match for an extracted community 34 | d2 = {} #best match for a known community 35 | 36 | for i in range(len(GT_sets)): 37 | gt = GT_sets[i] 38 | f_max = 0 39 | 40 | for j in range(len(found_sets)): 41 | f = found_sets[j] 42 | 43 | common_elements = gt.intersection(f) 44 | if len(common_elements) == 0: 45 | temp = 0 46 | else: 
47 | temp = F_score_helper(gt, f, common_elements) 48 | 49 | f_max = max(f_max,temp) 50 | 51 | d1[j] = max(d1.get(j,0),temp) 52 | 53 | d2[i] = f_max 54 | 55 | return d1, d2 56 | 57 | 58 | def cal_F_score(found_sets, GT_sets, verbose=False): 59 | found_sets = process_set(found_sets) 60 | GT_sets = process_set(GT_sets) 61 | d1,d2 = cal_F_score_helper(found_sets, GT_sets) 62 | 63 | if d1 is None: 64 | return [0]*3 65 | 66 | vals1 = sum(d1.values())/len(d1) 67 | vals2 = sum(d2.values())/len(d2) 68 | f_score = vals1 + vals2 69 | f_score /= 2 70 | f_score = round(f_score,4) 71 | vals1 = round(vals1,4) 72 | vals2 = round(vals2,4) 73 | 74 | return f_score, vals1, vals2 75 | 76 | def matched(true,pred): 77 | max_idx = max(max(true),max(pred)) 78 | cm = Metrics.confusion_matrix(true,pred,labels=np.arange(0,max_idx+1)) 79 | shifted_mat = np.zeros((cm.shape[0]*2,cm.shape[0] * 2)) 80 | shifted_mat[:cm.shape[0],cm.shape[0]:] = cm 81 | g = nx.from_numpy_matrix(shifted_mat) 82 | match = nx.max_weight_matching(g) 83 | unmatched = set(np.arange(0,cm.shape[0])) 84 | label_map = {} 85 | for m in match: 86 | p,t = max(m),min(m) 87 | unmatched.remove(t) 88 | label_map[p] = t 89 | unmatched = list(unmatched) 90 | for i in range(cm.shape[0],cm.shape[0]*2): 91 | if not i in label_map: 92 | label_map[i] = unmatched[-1] 93 | unmatched.pop() 94 | 95 | for i in range(len(pred)): 96 | pred[i] = label_map[pred[i]+cm.shape[0]] 97 | return pred 98 | 99 | def matched_cm(true,pred): 100 | max_idx = max(max(true),max(pred)) 101 | pred = matched(true,pred) 102 | cm = Metrics.confusion_matrix(true,pred,labels=np.arange(0,max_idx+1)) 103 | return cm 104 | 105 | def matched_ac(pred, true): 106 | pred = matched(true,pred) 107 | ac = Metrics.accuracy_score(true,pred) 108 | return ac 109 | #score = cal_F_score(found_comms, groundtruth_comms) 110 | #print ("f1:",score[0]) 111 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #CUDA_VISIBLE_DEVICES=0 /remote-home/mzhong/anaconda3/bin/python train.py --subgraph_size 10 --batch_size 200 2 | import argparse, os 3 | import math 4 | import torch 5 | import random 6 | import numpy as np 7 | from torch_geometric.datasets import Planetoid 8 | import torch_geometric.transforms as T 9 | 10 | from utils_mp import Subgraph, preprocess 11 | from subgcon import SugbCon 12 | from model import Encoder, Scorer, Pool 13 | 14 | 15 | def get_parser(): 16 | parser = argparse.ArgumentParser(description='Description: Script to run our model.') 17 | parser.add_argument('--dataset',help='Cora, Citeseer or Pubmed.
Default=Cora', default='Cora') 18 | parser.add_argument('--batch_size', type=int, help='batch size', default=500) 19 | parser.add_argument('--subgraph_size', type=int, help='subgraph size', default=20) 20 | parser.add_argument('--n_order', type=int, help='order of neighbor nodes', default=10) 21 | parser.add_argument('--hidden_size', type=int, help='hidden size', default=1024) 22 | return parser 23 | 24 | if __name__ == '__main__': 25 | parser = get_parser() 26 | try: 27 | args = parser.parse_args() 28 | except: 29 | exit() 30 | print (args) 31 | 32 | # Loading data 33 | data = Planetoid(root='./dataset/' + args.dataset, name=args.dataset) 34 | num_classes = data.num_classes 35 | data = data[0] 36 | num_node = data.x.size(0) 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | 39 | # Setting up the subgraph extractor 40 | ppr_path = './subgraph/' + args.dataset 41 | subgraph = Subgraph(data.x, data.edge_index, ppr_path, args.subgraph_size, args.n_order) 42 | subgraph.build() 43 | 44 | # Setting up the model and optimizer 45 | model = SugbCon( 46 | hidden_channels=args.hidden_size, encoder=Encoder(data.num_features, args.hidden_size), 47 | pool=Pool(in_channels=args.hidden_size), 48 | scorer=Scorer(args.hidden_size)).to(device) 49 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 50 | 51 | 52 | def train(epoch): 53 | # Model training 54 | model.train() 55 | optimizer.zero_grad() 56 | sample_idx = random.sample(range(data.x.size(0)), args.batch_size) 57 | batch, index = subgraph.search(sample_idx) 58 | z, summary = model(batch.x.cuda(), batch.edge_index.cuda(), batch.batch.cuda(), index.cuda()) 59 | 60 | loss = model.loss(z, summary) 61 | loss.backward() 62 | optimizer.step() 63 | return loss.item() 64 | 65 | 66 | def get_all_node_emb(model, mask): 67 | # Obtain central node embs from subgraphs 68 | node_list = np.arange(0, num_node, 1)[mask] 69 | list_size = node_list.size 70 | z = torch.Tensor(list_size, args.hidden_size).cuda() 71 | group_nb = math.ceil(list_size/args.batch_size) 72 | for i in range(group_nb): 73 | maxx = min(list_size, (i + 1) * args.batch_size) 74 | minn = i * args.batch_size 75 | batch, index = subgraph.search(node_list[minn:maxx]) 76 | node, _ = model(batch.x.cuda(), batch.edge_index.cuda(), batch.batch.cuda(), index.cuda()) 77 | z[minn:maxx] = node 78 | return z 79 | 80 | 81 | def test(model): 82 | # Model testing 83 | model.eval() 84 | with torch.no_grad(): 85 | train_z = get_all_node_emb(model, data.train_mask) 86 | val_z = get_all_node_emb(model, data.val_mask) 87 | test_z = get_all_node_emb(model, data.test_mask) 88 | 89 | train_y = data.y[data.train_mask] 90 | val_y = data.y[data.val_mask] 91 | test_y = data.y[data.test_mask] 92 | val_acc, test_acc = model.test(train_z, train_y, val_z, val_y, test_z, test_y) 93 | print('val_acc = {} test_acc = {}'.format(val_acc, test_acc)) 94 | return val_acc, test_acc 95 | 96 | 97 | print('Start training !!!') 98 | best_acc_from_val = 0 99 | best_val_acc = 0 100 | best_ts_acc = 0 101 | max_val = 0 102 | stop_cnt = 0 103 | patience = 20 104 | 105 | for epoch in range(10000): 106 | loss = train(epoch) 107 | print('epoch = {}, loss = {}'.format(epoch, loss)) 108 | val_acc, test_acc = test(model) 109 | best_val_acc = max(best_val_acc, val_acc) 110 | best_ts_acc = max(best_ts_acc, test_acc) 111 | if val_acc >= max_val: 112 | max_val = val_acc 113 | best_acc_from_val = test_acc 114 | stop_cnt = 0 115 | else: 116 | stop_cnt += 1 117 | print('best_val_acc = {}, best_test_acc = 
{}'.format(best_val_acc, best_ts_acc)) 118 | if stop_cnt >= patience: 119 | break 120 | print('best_acc_from_val = {}'.format(best_acc_from_val)) 121 | 122 | -------------------------------------------------------------------------------- /utils_mp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from cytoolz import curry 5 | import multiprocessing as mp 6 | from scipy import sparse as sp 7 | from sklearn.preprocessing import normalize, StandardScaler 8 | from torch_geometric.data import Data, Batch 9 | 10 | 11 | def standardize(feat, mask): 12 | scaler = StandardScaler() 13 | scaler.fit(feat[mask]) 14 | new_feat = torch.FloatTensor(scaler.transform(feat)) 15 | return new_feat 16 | 17 | 18 | def preprocess(features): 19 | rowsum = np.array(features.sum(1)) 20 | r_inv = np.power(rowsum, -1).flatten() 21 | r_inv[np.isinf(r_inv)] = 0. 22 | r_mat_inv = sp.diags(r_inv) 23 | features = r_mat_inv.dot(features) 24 | return torch.tensor(features) 25 | 26 | 27 | class PPR: 28 | #Node-wise personalized pagerank 29 | def __init__(self, adj_mat, maxsize=200, n_order=2, alpha=0.85): 30 | self.n_order = n_order 31 | self.maxsize = maxsize 32 | self.adj_mat = adj_mat 33 | self.P = normalize(adj_mat, norm='l1', axis=0) 34 | self.d = np.array(adj_mat.sum(1)).squeeze() 35 | 36 | def search(self, seed, alpha=0.85): 37 | x = sp.csc_matrix((np.ones(1), ([seed], np.zeros(1, dtype=int))), shape=[self.P.shape[0], 1]) 38 | r = x.copy() 39 | for _ in range(self.n_order): 40 | x = (1 - alpha) * r + alpha * self.P @ x 41 | scores = x.data / (self.d[x.indices] + 1e-9) 42 | 43 | idx = scores.argsort()[::-1][:self.maxsize] 44 | neighbor = np.array(x.indices[idx]) 45 | 46 | seed_idx = np.where(neighbor == seed)[0] 47 | if seed_idx.size == 0: 48 | neighbor = np.append(np.array([seed]), neighbor) 49 | else : 50 | seed_idx = seed_idx[0] 51 | neighbor[seed_idx], neighbor[0] = neighbor[0], neighbor[seed_idx] 52 | 53 | assert np.where(neighbor == seed)[0].size == 1 54 | assert np.where(neighbor == seed)[0][0] == 0 55 | 56 | return neighbor 57 | 58 | @curry 59 | def process(self, path, seed): 60 | ppr_path = os.path.join(path, 'ppr{}'.format(seed)) 61 | if not os.path.isfile(ppr_path) or os.stat(ppr_path).st_size == 0: 62 | print ('Processing node {}.'.format(seed)) 63 | neighbor = self.search(seed) 64 | torch.save(neighbor, ppr_path) 65 | else : 66 | print ('File of node {} exists.'.format(seed)) 67 | 68 | def search_all(self, node_num, path): 69 | neighbor = {} 70 | if os.path.isfile(path+'_neighbor') and os.stat(path+'_neighbor').st_size != 0: 71 | print ("Exists neighbor file") 72 | neighbor = torch.load(path+'_neighbor') 73 | else : 74 | print ("Extracting subgraphs") 75 | os.system('mkdir {}'.format(path)) 76 | with mp.Pool() as pool: 77 | list(pool.imap_unordered(self.process(path), list(range(node_num)), chunksize=1000)) 78 | 79 | print ("Finish Extracting") 80 | for i in range(node_num): 81 | neighbor[i] = torch.load(os.path.join(path, 'ppr{}'.format(i))) 82 | torch.save(neighbor, path+'_neighbor') 83 | os.system('rm -r {}'.format(path)) 84 | print ("Finish Writing") 85 | return neighbor 86 | 87 | 88 | class Subgraph: 89 | #Class for subgraph extraction 90 | 91 | def __init__(self, x, edge_index, path, maxsize=50, n_order=10): 92 | self.x = x 93 | self.path = path 94 | self.edge_index = np.array(edge_index) 95 | self.edge_num = edge_index[0].size(0) 96 | self.node_num = x.size(0) 97 | self.maxsize = maxsize 98 | 99 |
self.sp_adj = sp.csc_matrix((np.ones(self.edge_num), (edge_index[0], edge_index[1])), 100 | shape=[self.node_num, self.node_num]) 101 | self.ppr = PPR(self.sp_adj, n_order=n_order) 102 | 103 | self.neighbor = {} 104 | self.adj_list = {} 105 | self.subgraph = {} 106 | 107 | def process_adj_list(self): 108 | for i in range(self.node_num): 109 | self.adj_list[i] = set() 110 | for i in range(self.edge_num): 111 | u, v = self.edge_index[0][i], self.edge_index[1][i] 112 | self.adj_list[u].add(v) 113 | self.adj_list[v].add(u) 114 | 115 | def adjust_edge(self, idx): 116 | #Generate edges for subgraphs 117 | dic = {} 118 | for i in range(len(idx)): 119 | dic[idx[i]] = i 120 | 121 | new_index = [[], []] 122 | nodes = set(idx) 123 | for i in idx: 124 | edge = list(self.adj_list[i] & nodes) 125 | edge = [dic[_] for _ in edge] 126 | #edge = [_ for _ in edge if _ > i] 127 | new_index[0] += len(edge) * [dic[i]] 128 | new_index[1] += edge 129 | return torch.LongTensor(new_index) 130 | 131 | def adjust_x(self, idx): 132 | #Generate node features for subgraphs 133 | return self.x[idx] 134 | 135 | def build(self): 136 | #Extract subgraphs for all nodes 137 | if os.path.isfile(self.path+'_subgraph') and os.stat(self.path+'_subgraph').st_size != 0: 138 | print ("Exists subgraph file") 139 | self.subgraph = torch.load(self.path+'_subgraph') 140 | return 141 | 142 | self.neighbor = self.ppr.search_all(self.node_num, self.path) 143 | self.process_adj_list() 144 | for i in range(self.node_num): 145 | nodes = self.neighbor[i][:self.maxsize] 146 | x = self.adjust_x(nodes) 147 | edge = self.adjust_edge(nodes) 148 | self.subgraph[i] = Data(x, edge) 149 | torch.save(self.subgraph, self.path+'_subgraph') 150 | 151 | def search(self, node_list): 152 | #Extract subgraphs for nodes in the list 153 | batch = [] 154 | index = [] 155 | size = 0 156 | for node in node_list: 157 | batch.append(self.subgraph[node]) 158 | index.append(size) 159 | size += self.subgraph[node].x.size(0) 160 | index = torch.tensor(index) 161 | batch = Batch().from_data_list(batch) 162 | return batch, index 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | --------------------------------------------------------------------------------
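For reference, here is a condensed sketch of how the pieces above fit together, mirroring the steps in `train.py` with its default hyperparameters (Cora, `hidden_size=1024`, `batch_size=500`, `subgraph_size=20`, `n_order=10`). Like the original training code, it assumes a CUDA device is available; it is an illustration of the pipeline, not a replacement for the script.

```python
# Condensed sketch of the Subg-Con pipeline, following train.py (defaults assumed).
import random
import torch
from torch_geometric.datasets import Planetoid

from utils_mp import Subgraph
from subgcon import SugbCon
from model import Encoder, Pool, Scorer

# Load the dataset (downloaded into ./dataset/ on first use).
dataset = Planetoid(root='./dataset/Cora', name='Cora')
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Precompute PPR-ranked subgraphs for every node; results are cached under ./subgraph/.
subgraph = Subgraph(data.x, data.edge_index, './subgraph/Cora', maxsize=20, n_order=10)
subgraph.build()

# Encoder + pooling + scorer wired into the contrastive model.
model = SugbCon(hidden_channels=1024,
                encoder=Encoder(data.num_features, 1024),
                pool=Pool(in_channels=1024),
                scorer=Scorer(1024)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One contrastive training step on a random batch of central nodes.
model.train()
optimizer.zero_grad()
sample_idx = random.sample(range(data.x.size(0)), 500)
batch, index = subgraph.search(sample_idx)
z, summary = model(batch.x.cuda(), batch.edge_index.cuda(),
                   batch.batch.cuda(), index.cuda())
loss = model.loss(z, summary)
loss.backward()
optimizer.step()
```

The full script repeats this step with early stopping on validation accuracy and evaluates the learned embeddings with the logistic-regression probe in `model.test`.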