├── README.md ├── data ├── Enron │ ├── eval_14.npz │ ├── graph.pkl │ └── train_pairs_n2v_14.pkl └── email_uci │ ├── features.npz │ ├── graph.pkl │ ├── graphs.npz │ └── graphs.pkl ├── eval └── link_prediction.py ├── model_checkpoints └── model.pt ├── models ├── __init__.py ├── layers.py └── model.py ├── raw_data └── Enron │ ├── process.py │ ├── vis.digraph.allEdges.json │ └── vis.graph.nodeList.json ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── minibatch.py ├── preprocess.py ├── random_walk.py └── utilities.py /README.md: -------------------------------------------------------------------------------- 1 | # DySAT: Deep Neural Representation Learning on Dynamic Graphs via Self-Attention Networks 2 | This is a pytorch implementation of DySAT. All codes are adapted from official [implementation in TensorFlow](https://github.com/aravindsankar28/DySAT). This implementation is only tested using dataset Enron, and the results is inconsistent with official results (better than that). Code review and contribution is welcome! 3 | 4 | # Raw Data Process 5 | ``` 6 | cd raw_data/Enron 7 | pyhton process.py 8 | ``` 9 | The processed data will stored at 'data/Enron" 10 | 11 | # Training 12 | ``` 13 | python train --dataset Enron --time_steps 16 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /data/Enron/eval_14.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/data/Enron/eval_14.npz -------------------------------------------------------------------------------- /data/Enron/graph.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/data/Enron/graph.pkl -------------------------------------------------------------------------------- /data/Enron/train_pairs_n2v_14.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/data/Enron/train_pairs_n2v_14.pkl -------------------------------------------------------------------------------- /data/email_uci/features.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/data/email_uci/features.npz -------------------------------------------------------------------------------- /data/email_uci/graph.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/data/email_uci/graph.pkl -------------------------------------------------------------------------------- /data/email_uci/graphs.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/data/email_uci/graphs.npz -------------------------------------------------------------------------------- /data/email_uci/graphs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/data/email_uci/graphs.pkl -------------------------------------------------------------------------------- /eval/link_prediction.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | from sklearn import linear_model 5 | from collections import defaultdict 6 | import random 7 | 8 | np.random.seed(123) 9 | operatorTypes = ["HAD"] 10 | 11 | 12 | def write_to_csv(test_results, output_name, model_name, dataset, time_steps, mod='val'): 13 | """Output result scores to a csv file for result logging""" 14 | with open(output_name, 'a+') as f: 15 | for op in test_results: 16 | print("{} results ({})".format(model_name, mod), test_results[op]) 17 | _, best_auc = test_results[op] 18 | f.write("{},{},{},{},{},{},{}\n".format(dataset, time_steps, model_name, op, mod, "AUC", best_auc)) 19 | 20 | 21 | def get_link_score(fu, fv, operator): 22 | """Given a pair of embeddings, compute link feature based on operator (such as Hadammad product, etc.)""" 23 | fu = np.array(fu) 24 | fv = np.array(fv) 25 | if operator == "HAD": 26 | return np.multiply(fu, fv) 27 | else: 28 | raise NotImplementedError 29 | 30 | 31 | def get_link_feats(links, source_embeddings, target_embeddings, operator): 32 | """Compute link features for a list of pairs""" 33 | features = [] 34 | for l in links: 35 | a, b = l[0], l[1] 36 | f = get_link_score(source_embeddings[a], target_embeddings[b], operator) 37 | features.append(f) 38 | return features 39 | 40 | 41 | def get_random_split(train_pos, train_neg, val_pos, val_neg, test_pos, test_neg): 42 | """ Randomly split a given set of train, val and test examples""" 43 | all_data_pos = [] 44 | all_data_neg = [] 45 | 46 | all_data_pos.extend(train_pos) 47 | all_data_neg.extend(train_neg) 48 | all_data_pos.extend(test_pos) 49 | all_data_neg.extend(test_neg) 50 | 51 | # re-define train_pos, train_neg, test_pos, test_neg. 52 | random.shuffle(all_data_pos) 53 | random.shuffle(all_data_neg) 54 | 55 | train_pos = all_data_pos[:int(0.2 * len(all_data_pos))] 56 | train_neg = all_data_neg[:int(0.2 * len(all_data_neg))] 57 | 58 | test_pos = all_data_pos[int(0.2 * len(all_data_pos)):] 59 | test_neg = all_data_neg[int(0.2 * len(all_data_neg)):] 60 | print("# train :", len(train_pos) + len(train_neg), "# val :", len(val_pos) + len(val_neg), 61 | "#test :", len(test_pos) + len(test_neg)) 62 | return train_pos, train_neg, val_pos, val_neg, test_pos, test_neg 63 | 64 | 65 | def evaluate_classifier(train_pos, train_neg, val_pos, val_neg, test_pos, test_neg, source_embeds, target_embeds): 66 | """Downstream logistic regression classifier to evaluate link prediction""" 67 | test_results = defaultdict(lambda: []) 68 | val_results = defaultdict(lambda: []) 69 | 70 | test_auc = get_roc_score_t(test_pos, test_neg, source_embeds, target_embeds) 71 | val_auc = get_roc_score_t(val_pos, val_neg, source_embeds, target_embeds) 72 | 73 | # Compute AUC based on sigmoid(u^T v) without classifier training. 74 | test_results['SIGMOID'].extend([test_auc, test_auc]) 75 | val_results['SIGMOID'].extend([val_auc, val_auc]) 76 | 77 | test_pred_true = defaultdict(lambda: []) 78 | val_pred_true = defaultdict(lambda: []) 79 | 80 | for operator in operatorTypes: 81 | train_pos_feats = np.array(get_link_feats(train_pos, source_embeds, target_embeds, operator)) 82 | train_neg_feats = np.array(get_link_feats(train_neg, source_embeds, target_embeds, operator)) 83 | val_pos_feats = np.array(get_link_feats(val_pos, source_embeds, target_embeds, operator)) 84 | val_neg_feats = np.array(get_link_feats(val_neg, source_embeds, target_embeds, operator)) 85 | test_pos_feats = np.array(get_link_feats(test_pos, source_embeds, target_embeds, operator)) 86 | test_neg_feats = np.array(get_link_feats(test_neg, source_embeds, target_embeds, operator)) 87 | 88 | train_pos_labels = np.array([1] * len(train_pos_feats)) 89 | train_neg_labels = np.array([-1] * len(train_neg_feats)) 90 | val_pos_labels = np.array([1] * len(val_pos_feats)) 91 | val_neg_labels = np.array([-1] * len(val_neg_feats)) 92 | 93 | test_pos_labels = np.array([1] * len(test_pos_feats)) 94 | test_neg_labels = np.array([-1] * len(test_neg_feats)) 95 | train_data = np.vstack((train_pos_feats, train_neg_feats)) 96 | train_labels = np.append(train_pos_labels, train_neg_labels) 97 | 98 | val_data = np.vstack((val_pos_feats, val_neg_feats)) 99 | val_labels = np.append(val_pos_labels, val_neg_labels) 100 | 101 | test_data = np.vstack((test_pos_feats, test_neg_feats)) 102 | test_labels = np.append(test_pos_labels, test_neg_labels) 103 | 104 | logistic = linear_model.LogisticRegression() 105 | logistic.fit(train_data, train_labels) 106 | test_predict = logistic.predict_proba(test_data)[:, 1] 107 | val_predict = logistic.predict_proba(val_data)[:, 1] 108 | 109 | test_roc_score = roc_auc_score(test_labels, test_predict) 110 | val_roc_score = roc_auc_score(val_labels, val_predict) 111 | 112 | val_results[operator].extend([val_roc_score, val_roc_score]) 113 | test_results[operator].extend([test_roc_score, test_roc_score]) 114 | 115 | val_pred_true[operator].extend(zip(val_predict, val_labels)) 116 | test_pred_true[operator].extend(zip(test_predict, test_labels)) 117 | 118 | return val_results, test_results, val_pred_true, test_pred_true 119 | 120 | 121 | def get_roc_score_t(edges_pos, edges_neg, source_emb, target_emb): 122 | """Given test examples, edges_pos: +ve edges, edges_neg: -ve edges, return ROC scores for a given snapshot""" 123 | def sigmoid(x): 124 | return 1 / (1 + np.exp(-x)) 125 | 126 | # Predict on test set of edges 127 | adj_rec = np.dot(source_emb, target_emb.T) 128 | pred = [] 129 | pos = [] 130 | for e in edges_pos: 131 | pred.append(sigmoid(adj_rec[e[0], e[1]])) 132 | pos.append(1.0) 133 | 134 | pred_neg = [] 135 | neg = [] 136 | for e in edges_neg: 137 | pred_neg.append(sigmoid(adj_rec[e[0], e[1]])) 138 | neg.append(0.0) 139 | 140 | pred_all = np.hstack([pred, pred_neg]) 141 | labels_all = np.hstack([np.ones(len(pred)), np.zeros(len(pred_neg))]) 142 | roc_score = roc_auc_score(labels_all, pred_all) 143 | return roc_score 144 | -------------------------------------------------------------------------------- /model_checkpoints/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/model_checkpoints/model.pt -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/models/__init__.py -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : layers.py 4 | @Time : 2021/02/18 14:30:13 5 | @Author : Fei gao 6 | @Contact : feig@mail.bnu.edu.cn 7 | BNU, Beijing, China 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch_geometric.utils import softmax 14 | from torch_scatter import scatter 15 | 16 | import copy 17 | 18 | 19 | class StructuralAttentionLayer(nn.Module): 20 | def __init__(self, 21 | input_dim, 22 | output_dim, 23 | n_heads, 24 | attn_drop, 25 | ffd_drop, 26 | residual): 27 | super(StructuralAttentionLayer, self).__init__() 28 | self.out_dim = output_dim // n_heads 29 | self.n_heads = n_heads 30 | self.act = nn.ELU() 31 | 32 | self.lin = nn.Linear(input_dim, n_heads * self.out_dim, bias=False) 33 | self.att_l = nn.Parameter(torch.Tensor(1, n_heads, self.out_dim)) 34 | self.att_r = nn.Parameter(torch.Tensor(1, n_heads, self.out_dim)) 35 | 36 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2) 37 | 38 | self.attn_drop = nn.Dropout(attn_drop) 39 | self.ffd_drop = nn.Dropout(ffd_drop) 40 | 41 | self.residual = residual 42 | if self.residual: 43 | self.lin_residual = nn.Linear(input_dim, n_heads * self.out_dim, bias=False) 44 | 45 | self.xavier_init() 46 | 47 | def forward(self, graph): 48 | graph = copy.deepcopy(graph) 49 | edge_index = graph.edge_index 50 | edge_weight = graph.edge_weight.reshape(-1, 1) 51 | H, C = self.n_heads, self.out_dim 52 | x = self.lin(graph.x).view(-1, H, C) # [N, heads, out_dim] 53 | # attention 54 | alpha_l = (x * self.att_l).sum(dim=-1).squeeze() # [N, heads] 55 | alpha_r = (x * self.att_r).sum(dim=-1).squeeze() 56 | alpha_l = alpha_l[edge_index[0]] # [num_edges, heads] 57 | alpha_r = alpha_r[edge_index[1]] 58 | alpha = alpha_r + alpha_l 59 | alpha = edge_weight * alpha 60 | alpha = self.leaky_relu(alpha) 61 | coefficients = softmax(alpha, edge_index[1]) # [num_edges, heads] 62 | 63 | # dropout 64 | if self.training: 65 | coefficients = self.attn_drop(coefficients) 66 | x = self.ffd_drop(x) 67 | x_j = x[edge_index[0]] # [num_edges, heads, out_dim] 68 | 69 | # output 70 | out = self.act(scatter(x_j * coefficients[:, :, None], edge_index[1], dim=0, reduce="sum")) 71 | out = out.reshape(-1, self.n_heads*self.out_dim) #[num_nodes, output_dim] 72 | if self.residual: 73 | out = out + self.lin_residual(graph.x) 74 | graph.x = out 75 | return graph 76 | 77 | def xavier_init(self): 78 | nn.init.xavier_uniform_(self.att_l) 79 | nn.init.xavier_uniform_(self.att_r) 80 | 81 | 82 | class TemporalAttentionLayer(nn.Module): 83 | def __init__(self, 84 | input_dim, 85 | n_heads, 86 | num_time_steps, 87 | attn_drop, 88 | residual): 89 | super(TemporalAttentionLayer, self).__init__() 90 | self.n_heads = n_heads 91 | self.num_time_steps = num_time_steps 92 | self.residual = residual 93 | 94 | # define weights 95 | self.position_embeddings = nn.Parameter(torch.Tensor(num_time_steps, input_dim)) 96 | self.Q_embedding_weights = nn.Parameter(torch.Tensor(input_dim, input_dim)) 97 | self.K_embedding_weights = nn.Parameter(torch.Tensor(input_dim, input_dim)) 98 | self.V_embedding_weights = nn.Parameter(torch.Tensor(input_dim, input_dim)) 99 | # ff 100 | self.lin = nn.Linear(input_dim, input_dim, bias=True) 101 | # dropout 102 | self.attn_dp = nn.Dropout(attn_drop) 103 | self.xavier_init() 104 | 105 | 106 | def forward(self, inputs): 107 | """In: attn_outputs (of StructuralAttentionLayer at each snapshot):= [N, T, F]""" 108 | # 1: Add position embeddings to input 109 | position_inputs = torch.arange(0,self.num_time_steps).reshape(1, -1).repeat(inputs.shape[0], 1).long().to(inputs.device) 110 | temporal_inputs = inputs + self.position_embeddings[position_inputs] # [N, T, F] 111 | 112 | # 2: Query, Key based multi-head self attention. 113 | q = torch.tensordot(temporal_inputs, self.Q_embedding_weights, dims=([2],[0])) # [N, T, F] 114 | k = torch.tensordot(temporal_inputs, self.K_embedding_weights, dims=([2],[0])) # [N, T, F] 115 | v = torch.tensordot(temporal_inputs, self.V_embedding_weights, dims=([2],[0])) # [N, T, F] 116 | 117 | # 3: Split, concat and scale. 118 | split_size = int(q.shape[-1]/self.n_heads) 119 | q_ = torch.cat(torch.split(q, split_size_or_sections=split_size, dim=2), dim=0) # [hN, T, F/h] 120 | k_ = torch.cat(torch.split(k, split_size_or_sections=split_size, dim=2), dim=0) # [hN, T, F/h] 121 | v_ = torch.cat(torch.split(v, split_size_or_sections=split_size, dim=2), dim=0) # [hN, T, F/h] 122 | 123 | outputs = torch.matmul(q_, k_.permute(0,2,1)) # [hN, T, T] 124 | outputs = outputs / (self.num_time_steps ** 0.5) 125 | # 4: Masked (causal) softmax to compute attention weights. 126 | diag_val = torch.ones_like(outputs[0]) 127 | tril = torch.tril(diag_val) 128 | masks = tril[None, :, :].repeat(outputs.shape[0], 1, 1) # [h*N, T, T] 129 | padding = torch.ones_like(masks) * (-2**32+1) 130 | outputs = torch.where(masks==0, padding, outputs) 131 | outputs = F.softmax(outputs, dim=2) 132 | self.attn_wts_all = outputs # [h*N, T, T] 133 | 134 | # 5: Dropout on attention weights. 135 | if self.training: 136 | outputs = self.attn_dp(outputs) 137 | outputs = torch.matmul(outputs, v_) # [hN, T, F/h] 138 | outputs = torch.cat(torch.split(outputs, split_size_or_sections=int(outputs.shape[0]/self.n_heads), dim=0), dim=2) # [N, T, F] 139 | 140 | # 6: Feedforward and residual 141 | outputs = self.feedforward(outputs) 142 | if self.residual: 143 | outputs = outputs + temporal_inputs 144 | return outputs 145 | 146 | def feedforward(self, inputs): 147 | outputs = F.relu(self.lin(inputs)) 148 | return outputs + inputs 149 | 150 | 151 | def xavier_init(self): 152 | nn.init.xavier_uniform_(self.position_embeddings) 153 | nn.init.xavier_uniform_(self.Q_embedding_weights) 154 | nn.init.xavier_uniform_(self.K_embedding_weights) 155 | nn.init.xavier_uniform_(self.V_embedding_weights) 156 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : model.py 4 | @Time : 2021/02/19 21:10:00 5 | @Author : Fei gao 6 | @Contact : feig@mail.bnu.edu.cn 7 | BNU, Beijing, China 8 | ''' 9 | import copy 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.modules.loss import BCEWithLogitsLoss 15 | 16 | from models.layers import StructuralAttentionLayer, TemporalAttentionLayer 17 | from utils.utilities import fixed_unigram_candidate_sampler 18 | 19 | class DySAT(nn.Module): 20 | def __init__(self, args, num_features, time_length): 21 | """[summary] 22 | 23 | Args: 24 | args ([type]): [description] 25 | time_length (int): Total timesteps in dataset. 26 | """ 27 | super(DySAT, self).__init__() 28 | self.args = args 29 | if args.window < 0: 30 | self.num_time_steps = time_length 31 | else: 32 | self.num_time_steps = min(time_length, args.window + 1) # window = 0 => only self. 33 | self.num_features = num_features 34 | 35 | self.structural_head_config = list(map(int, args.structural_head_config.split(","))) 36 | self.structural_layer_config = list(map(int, args.structural_layer_config.split(","))) 37 | self.temporal_head_config = list(map(int, args.temporal_head_config.split(","))) 38 | self.temporal_layer_config = list(map(int, args.temporal_layer_config.split(","))) 39 | self.spatial_drop = args.spatial_drop 40 | self.temporal_drop = args.temporal_drop 41 | 42 | self.structural_attn, self.temporal_attn = self.build_model() 43 | 44 | self.bceloss = BCEWithLogitsLoss() 45 | 46 | def forward(self, graphs): 47 | 48 | # Structural Attention forward 49 | structural_out = [] 50 | for t in range(0, self.num_time_steps): 51 | structural_out.append(self.structural_attn(graphs[t])) 52 | structural_outputs = [g.x[:,None,:] for g in structural_out] # list of [Ni, 1, F] 53 | 54 | # padding outputs along with Ni 55 | maximum_node_num = structural_outputs[-1].shape[0] 56 | out_dim = structural_outputs[-1].shape[-1] 57 | structural_outputs_padded = [] 58 | for out in structural_outputs: 59 | zero_padding = torch.zeros(maximum_node_num-out.shape[0], 1, out_dim).to(out.device) 60 | padded = torch.cat((out, zero_padding), dim=0) 61 | structural_outputs_padded.append(padded) 62 | structural_outputs_padded = torch.cat(structural_outputs_padded, dim=1) # [N, T, F] 63 | 64 | # Temporal Attention forward 65 | temporal_out = self.temporal_attn(structural_outputs_padded) 66 | 67 | return temporal_out 68 | 69 | def build_model(self): 70 | input_dim = self.num_features 71 | 72 | # 1: Structural Attention Layers 73 | structural_attention_layers = nn.Sequential() 74 | for i in range(len(self.structural_layer_config)): 75 | layer = StructuralAttentionLayer(input_dim=input_dim, 76 | output_dim=self.structural_layer_config[i], 77 | n_heads=self.structural_head_config[i], 78 | attn_drop=self.spatial_drop, 79 | ffd_drop=self.spatial_drop, 80 | residual=self.args.residual) 81 | structural_attention_layers.add_module(name="structural_layer_{}".format(i), module=layer) 82 | input_dim = self.structural_layer_config[i] 83 | 84 | # 2: Temporal Attention Layers 85 | input_dim = self.structural_layer_config[-1] 86 | temporal_attention_layers = nn.Sequential() 87 | for i in range(len(self.temporal_layer_config)): 88 | layer = TemporalAttentionLayer(input_dim=input_dim, 89 | n_heads=self.temporal_head_config[i], 90 | num_time_steps=self.num_time_steps, 91 | attn_drop=self.temporal_drop, 92 | residual=self.args.residual) 93 | temporal_attention_layers.add_module(name="temporal_layer_{}".format(i), module=layer) 94 | input_dim = self.temporal_layer_config[i] 95 | 96 | return structural_attention_layers, temporal_attention_layers 97 | 98 | def get_loss(self, feed_dict): 99 | node_1, node_2, node_2_negative, graphs = feed_dict.values() 100 | # run gnn 101 | final_emb = self.forward(graphs) # [N, T, F] 102 | self.graph_loss = 0 103 | for t in range(self.num_time_steps - 1): 104 | emb_t = final_emb[:, t, :].squeeze() #[N, F] 105 | source_node_emb = emb_t[node_1[t]] 106 | tart_node_pos_emb = emb_t[node_2[t]] 107 | tart_node_neg_emb = emb_t[node_2_negative[t]] 108 | pos_score = torch.sum(source_node_emb*tart_node_pos_emb, dim=1) 109 | neg_score = -torch.sum(source_node_emb[:, None, :]*tart_node_neg_emb, dim=2).flatten() 110 | pos_loss = self.bceloss(pos_score, torch.ones_like(pos_score)) 111 | neg_loss = self.bceloss(neg_score, torch.ones_like(neg_score)) 112 | graphloss = pos_loss + self.args.neg_weight*neg_loss 113 | self.graph_loss += graphloss 114 | return self.graph_loss 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /raw_data/Enron/process.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import itertools 4 | from collections import defaultdict 5 | from itertools import islice, chain 6 | 7 | import networkx as nx 8 | import numpy as np 9 | import pickle as pkl 10 | from scipy.sparse import csr_matrix 11 | 12 | from datetime import datetime 13 | from datetime import timedelta 14 | import dateutil.parser 15 | 16 | 17 | def lines_per_n(f, n): 18 | for line in f: 19 | yield ''.join(chain([line], itertools.islice(f, n - 1))) 20 | 21 | def getDateTimeFromISO8601String(s): 22 | d = dateutil.parser.parse(s) 23 | return d 24 | 25 | if __name__ == "__main__": 26 | 27 | node_data = defaultdict(lambda : ()) 28 | with open('vis.graph.nodeList.json') as f: 29 | for chunk in lines_per_n(f, 5): 30 | chunk = chunk.split("\n") 31 | id_string = chunk[1].split(":")[1] 32 | x = [x.start() for x in re.finditer('\"', id_string)] 33 | id = id_string[x[0]+1:x[1]] 34 | 35 | name_string = chunk[2].split(":")[1] 36 | x = [x.start() for x in re.finditer('\"', name_string)] 37 | name = name_string[x[0]+1:x[1]] 38 | 39 | idx_string = chunk[3].split(":")[1] 40 | x1 = idx_string.find('(') 41 | x2 = idx_string.find(')') 42 | idx = idx_string[x1+1:x2] 43 | 44 | print("ID:{}, IDX:{:<4}, NAME:{}".format(id, idx, name)) 45 | node_data[name] = (id,idx) 46 | 47 | links = [] 48 | ts = [] 49 | with open('vis.digraph.allEdges.json') as f: 50 | for chunk in lines_per_n(f, 5): 51 | chunk = chunk.split("\n") 52 | 53 | name_string = chunk[2].split(":")[1] 54 | x = [x.start() for x in re.finditer('\"', name_string)] 55 | from_id, to_id = name_string[x[0]+1:x[1]].split("_") 56 | 57 | time_string = chunk[3].split("ISODate")[1] 58 | x = [x.start() for x in re.finditer('\"', time_string)] 59 | timestamp = getDateTimeFromISO8601String(time_string[x[0]+1:x[1]]) 60 | ts.append(timestamp) 61 | links.append((from_id, to_id, timestamp)) 62 | print (min(ts), max(ts)) 63 | print ("# interactions", len(links)) 64 | links.sort(key =lambda x: x[2]) 65 | 66 | # split edges 67 | SLICE_MONTHS = 2 68 | START_DATE = min(ts) + timedelta(200) 69 | END_DATE = max(ts) - timedelta(200) 70 | print("Spliting Time Interval: \n Start Time : {}, End Time : {}".format(START_DATE, END_DATE)) 71 | 72 | slice_links = defaultdict(lambda: nx.MultiGraph()) 73 | for (a, b, time) in links: 74 | datetime_object = time 75 | if datetime_object > END_DATE: 76 | months_diff = (END_DATE - START_DATE).days//30 77 | else: 78 | months_diff = (datetime_object - START_DATE).days//30 79 | slice_id = months_diff // SLICE_MONTHS 80 | slice_id = max(slice_id, 0) 81 | 82 | if slice_id not in slice_links.keys(): 83 | slice_links[slice_id] = nx.MultiGraph() 84 | if slice_id > 0: 85 | slice_links[slice_id].add_nodes_from(slice_links[slice_id-1].nodes(data=True)) 86 | assert (len(slice_links[slice_id].edges()) ==0) 87 | slice_links[slice_id].add_edge(a,b, date=datetime_object) 88 | 89 | # print statics of each graph 90 | used_nodes = [] 91 | for id, slice in slice_links.items(): 92 | print("In snapshoot {:<2}, #Nodes={:<5}, #Edges={:<5}".format(id, \ 93 | slice.number_of_nodes(), slice.number_of_edges())) 94 | for node in slice.nodes(): 95 | if not node in used_nodes: 96 | used_nodes.append(node) 97 | # remap nodes in graphs. Cause start time is not zero, the node index is not consistent 98 | nodes_consistent_map = {node:idx for idx, node in enumerate(used_nodes)} 99 | for id, slice in slice_links.items(): 100 | slice_links[id] = nx.relabel_nodes(slice, nodes_consistent_map) 101 | 102 | # One-Hot features 103 | onehot = np.identity(slice_links[max(slice_links.keys())].number_of_nodes()) 104 | graphs = [] 105 | for id, slice in slice_links.items(): 106 | tmp_feature = [] 107 | for node in slice.nodes(): 108 | tmp_feature.append(onehot[node]) 109 | slice.graph["feature"] = csr_matrix(tmp_feature) 110 | graphs.append(slice) 111 | 112 | # save 113 | save_path = "../../data/Enron/graph.pkl" 114 | with open(save_path, "wb") as f: 115 | pkl.dump(graphs, f) 116 | print("Processed Data Saved at {}".format(save_path)) 117 | -------------------------------------------------------------------------------- /raw_data/Enron/vis.graph.nodeList.json: -------------------------------------------------------------------------------- 1 | { 2 | "_id" : ObjectId("55098b62251497209062421f"), 3 | "name" : "albert.meyers@enron.com", 4 | "idx" : NumberInt(0) 5 | } 6 | { 7 | "_id" : ObjectId("55098b622514972090624220"), 8 | "name" : "andrea.ring@enron.com", 9 | "idx" : NumberInt(1) 10 | } 11 | { 12 | "_id" : ObjectId("55098b622514972090624221"), 13 | "name" : "andrew.lewis@enron.com", 14 | "idx" : NumberInt(2) 15 | } 16 | { 17 | "_id" : ObjectId("55098b622514972090624222"), 18 | "name" : "andy.zipper@enron.com", 19 | "idx" : NumberInt(3) 20 | } 21 | { 22 | "_id" : ObjectId("55098b622514972090624223"), 23 | "name" : "barry.tycholiz@enron.com", 24 | "idx" : NumberInt(4) 25 | } 26 | { 27 | "_id" : ObjectId("55098b622514972090624224"), 28 | "name" : "benjamin.rogers@enron.com", 29 | "idx" : NumberInt(5) 30 | } 31 | { 32 | "_id" : ObjectId("55098b622514972090624225"), 33 | "name" : "bill.rapp@enron.com", 34 | "idx" : NumberInt(6) 35 | } 36 | { 37 | "_id" : ObjectId("55098b622514972090624226"), 38 | "name" : "bill.williams@enron.com", 39 | "idx" : NumberInt(7) 40 | } 41 | { 42 | "_id" : ObjectId("55098b622514972090624227"), 43 | "name" : "brad.mckay@enron.com", 44 | "idx" : NumberInt(8) 45 | } 46 | { 47 | "_id" : ObjectId("55098b622514972090624228"), 48 | "name" : "brenda.whitehead@enron.com", 49 | "idx" : NumberInt(9) 50 | } 51 | { 52 | "_id" : ObjectId("55098b622514972090624229"), 53 | "name" : "cara.semperger@enron.com", 54 | "idx" : NumberInt(10) 55 | } 56 | { 57 | "_id" : ObjectId("55098b62251497209062422a"), 58 | "name" : "charles.weldon@enron.com", 59 | "idx" : NumberInt(11) 60 | } 61 | { 62 | "_id" : ObjectId("55098b62251497209062422b"), 63 | "name" : "chris.dorland@enron.com", 64 | "idx" : NumberInt(12) 65 | } 66 | { 67 | "_id" : ObjectId("55098b62251497209062422c"), 68 | "name" : "chris.germany@enron.com", 69 | "idx" : NumberInt(13) 70 | } 71 | { 72 | "_id" : ObjectId("55098b62251497209062422d"), 73 | "name" : "clint.dean@enron.com", 74 | "idx" : NumberInt(14) 75 | } 76 | { 77 | "_id" : ObjectId("55098b62251497209062422e"), 78 | "name" : "cooper.richey@enron.com", 79 | "idx" : NumberInt(15) 80 | } 81 | { 82 | "_id" : ObjectId("55098b62251497209062422f"), 83 | "name" : "craig.dean@enron.com", 84 | "idx" : NumberInt(16) 85 | } 86 | { 87 | "_id" : ObjectId("55098b622514972090624230"), 88 | "name" : "dan.hyvl@enron.com", 89 | "idx" : NumberInt(17) 90 | } 91 | { 92 | "_id" : ObjectId("55098b622514972090624231"), 93 | "name" : "dana.davis@enron.com", 94 | "idx" : NumberInt(18) 95 | } 96 | { 97 | "_id" : ObjectId("55098b622514972090624232"), 98 | "name" : "danny.mccarty@enron.com", 99 | "idx" : NumberInt(19) 100 | } 101 | { 102 | "_id" : ObjectId("55098b622514972090624233"), 103 | "name" : "daren.farmer@enron.com", 104 | "idx" : NumberInt(20) 105 | } 106 | { 107 | "_id" : ObjectId("55098b622514972090624234"), 108 | "name" : "darrell.schoolcraft@enron.com", 109 | "idx" : NumberInt(21) 110 | } 111 | { 112 | "_id" : ObjectId("55098b622514972090624235"), 113 | "name" : "darron.giron@enron.com", 114 | "idx" : NumberInt(22) 115 | } 116 | { 117 | "_id" : ObjectId("55098b622514972090624236"), 118 | "name" : "david.delainey@enron.com", 119 | "idx" : NumberInt(23) 120 | } 121 | { 122 | "_id" : ObjectId("55098b622514972090624237"), 123 | "name" : "debra.bailey@enron.com", 124 | "idx" : NumberInt(24) 125 | } 126 | { 127 | "_id" : ObjectId("55098b622514972090624238"), 128 | "name" : "debra.perlingiere@enron.com", 129 | "idx" : NumberInt(25) 130 | } 131 | { 132 | "_id" : ObjectId("55098b622514972090624239"), 133 | "name" : "diana.scholtes@enron.com", 134 | "idx" : NumberInt(26) 135 | } 136 | { 137 | "_id" : ObjectId("55098b62251497209062423a"), 138 | "name" : "don.baughman@enron.com", 139 | "idx" : NumberInt(27) 140 | } 141 | { 142 | "_id" : ObjectId("55098b62251497209062423b"), 143 | "name" : "drew.fossum@enron.com", 144 | "idx" : NumberInt(28) 145 | } 146 | { 147 | "_id" : ObjectId("55098b62251497209062423c"), 148 | "name" : "dutch.quigley@enron.com", 149 | "idx" : NumberInt(29) 150 | } 151 | { 152 | "_id" : ObjectId("55098b62251497209062423d"), 153 | "name" : "elizabeth.sager@enron.com", 154 | "idx" : NumberInt(30) 155 | } 156 | { 157 | "_id" : ObjectId("55098b62251497209062423e"), 158 | "name" : "eric.bass@enron.com", 159 | "idx" : NumberInt(31) 160 | } 161 | { 162 | "_id" : ObjectId("55098b62251497209062423f"), 163 | "name" : "eric.saibi@enron.com", 164 | "idx" : NumberInt(32) 165 | } 166 | { 167 | "_id" : ObjectId("55098b622514972090624240"), 168 | "name" : "errol.mclaughlin@enron.com", 169 | "idx" : NumberInt(33) 170 | } 171 | { 172 | "_id" : ObjectId("55098b622514972090624241"), 173 | "name" : "fletcher.sturm@enron.com", 174 | "idx" : NumberInt(34) 175 | } 176 | { 177 | "_id" : ObjectId("55098b622514972090624242"), 178 | "name" : "frank.ermis@enron.com", 179 | "idx" : NumberInt(35) 180 | } 181 | { 182 | "_id" : ObjectId("55098b622514972090624243"), 183 | "name" : "geir.solberg@enron.com", 184 | "idx" : NumberInt(36) 185 | } 186 | { 187 | "_id" : ObjectId("55098b622514972090624244"), 188 | "name" : "geoff.storey@enron.com", 189 | "idx" : NumberInt(37) 190 | } 191 | { 192 | "_id" : ObjectId("55098b622514972090624245"), 193 | "name" : "gerald.nemec@enron.com", 194 | "idx" : NumberInt(38) 195 | } 196 | { 197 | "_id" : ObjectId("55098b622514972090624246"), 198 | "name" : "greg.whalley@enron.com", 199 | "idx" : NumberInt(39) 200 | } 201 | { 202 | "_id" : ObjectId("55098b622514972090624247"), 203 | "name" : "gretel.smith@enron.com", 204 | "idx" : NumberInt(40) 205 | } 206 | { 207 | "_id" : ObjectId("55098b622514972090624248"), 208 | "name" : "holden.salisbury@enron.com", 209 | "idx" : NumberInt(41) 210 | } 211 | { 212 | "_id" : ObjectId("55098b622514972090624249"), 213 | "name" : "hunter.shively@enron.com", 214 | "idx" : NumberInt(42) 215 | } 216 | { 217 | "_id" : ObjectId("55098b62251497209062424a"), 218 | "name" : "j.harris@enron.com", 219 | "idx" : NumberInt(43) 220 | } 221 | { 222 | "_id" : ObjectId("55098b62251497209062424b"), 223 | "name" : "james.derrick@enron.com", 224 | "idx" : NumberInt(44) 225 | } 226 | { 227 | "_id" : ObjectId("55098b62251497209062424c"), 228 | "name" : "james.steffes@enron.com", 229 | "idx" : NumberInt(45) 230 | } 231 | { 232 | "_id" : ObjectId("55098b62251497209062424d"), 233 | "name" : "jane.tholt@enron.com", 234 | "idx" : NumberInt(46) 235 | } 236 | { 237 | "_id" : ObjectId("55098b62251497209062424e"), 238 | "name" : "jason.williams@enron.com", 239 | "idx" : NumberInt(47) 240 | } 241 | { 242 | "_id" : ObjectId("55098b62251497209062424f"), 243 | "name" : "jason.wolfe@enron.com", 244 | "idx" : NumberInt(48) 245 | } 246 | { 247 | "_id" : ObjectId("55098b622514972090624250"), 248 | "name" : "jay.reitmeyer@enron.com", 249 | "idx" : NumberInt(49) 250 | } 251 | { 252 | "_id" : ObjectId("55098b622514972090624251"), 253 | "name" : "jeff.dasovich@enron.com", 254 | "idx" : NumberInt(50) 255 | } 256 | { 257 | "_id" : ObjectId("55098b622514972090624252"), 258 | "name" : "jeff.king@enron.com", 259 | "idx" : NumberInt(51) 260 | } 261 | { 262 | "_id" : ObjectId("55098b622514972090624253"), 263 | "name" : "jeff.skilling@enron.com", 264 | "idx" : NumberInt(52) 265 | } 266 | { 267 | "_id" : ObjectId("55098b622514972090624254"), 268 | "name" : "jeffrey.hodge@enron.com", 269 | "idx" : NumberInt(53) 270 | } 271 | { 272 | "_id" : ObjectId("55098b622514972090624255"), 273 | "name" : "jeffrey.shankman@enron.com", 274 | "idx" : NumberInt(54) 275 | } 276 | { 277 | "_id" : ObjectId("55098b622514972090624256"), 278 | "name" : "jim.schwieger@enron.com", 279 | "idx" : NumberInt(55) 280 | } 281 | { 282 | "_id" : ObjectId("55098b622514972090624257"), 283 | "name" : "joannie.williamson@enron.com", 284 | "idx" : NumberInt(56) 285 | } 286 | { 287 | "_id" : ObjectId("55098b622514972090624258"), 288 | "name" : "joe.parks@enron.com", 289 | "idx" : NumberInt(57) 290 | } 291 | { 292 | "_id" : ObjectId("55098b622514972090624259"), 293 | "name" : "joe.quenet@enron.com", 294 | "idx" : NumberInt(58) 295 | } 296 | { 297 | "_id" : ObjectId("55098b62251497209062425a"), 298 | "name" : "joe.stepenovitch@enron.com", 299 | "idx" : NumberInt(59) 300 | } 301 | { 302 | "_id" : ObjectId("55098b62251497209062425b"), 303 | "name" : "john.forney@enron.com", 304 | "idx" : NumberInt(60) 305 | } 306 | { 307 | "_id" : ObjectId("55098b62251497209062425c"), 308 | "name" : "john.griffith@enron.com", 309 | "idx" : NumberInt(61) 310 | } 311 | { 312 | "_id" : ObjectId("55098b62251497209062425d"), 313 | "name" : "john.lavorato@enron.com", 314 | "idx" : NumberInt(62) 315 | } 316 | { 317 | "_id" : ObjectId("55098b62251497209062425e"), 318 | "name" : "john.zufferli@enron.com", 319 | "idx" : NumberInt(63) 320 | } 321 | { 322 | "_id" : ObjectId("55098b62251497209062425f"), 323 | "name" : "jonathan.mckay@enron.com", 324 | "idx" : NumberInt(64) 325 | } 326 | { 327 | "_id" : ObjectId("55098b622514972090624260"), 328 | "name" : "juan.hernandez@enron.com", 329 | "idx" : NumberInt(65) 330 | } 331 | { 332 | "_id" : ObjectId("55098b622514972090624261"), 333 | "name" : "judy.hernandez@enron.com", 334 | "idx" : NumberInt(66) 335 | } 336 | { 337 | "_id" : ObjectId("55098b622514972090624262"), 338 | "name" : "judy.townsend@enron.com", 339 | "idx" : NumberInt(67) 340 | } 341 | { 342 | "_id" : ObjectId("55098b622514972090624263"), 343 | "name" : "kam.keiser@enron.com", 344 | "idx" : NumberInt(68) 345 | } 346 | { 347 | "_id" : ObjectId("55098b622514972090624264"), 348 | "name" : "kate.symes@enron.com", 349 | "idx" : NumberInt(69) 350 | } 351 | { 352 | "_id" : ObjectId("55098b622514972090624265"), 353 | "name" : "kay.mann@enron.com", 354 | "idx" : NumberInt(70) 355 | } 356 | { 357 | "_id" : ObjectId("55098b622514972090624266"), 358 | "name" : "keith.holst@enron.com", 359 | "idx" : NumberInt(71) 360 | } 361 | { 362 | "_id" : ObjectId("55098b622514972090624267"), 363 | "name" : "kenneth.lay@enron.com", 364 | "idx" : NumberInt(72) 365 | } 366 | { 367 | "_id" : ObjectId("55098b622514972090624268"), 368 | "name" : "kevin.hyatt@enron.com", 369 | "idx" : NumberInt(73) 370 | } 371 | { 372 | "_id" : ObjectId("55098b622514972090624269"), 373 | "name" : "kevin.presto@enron.com", 374 | "idx" : NumberInt(74) 375 | } 376 | { 377 | "_id" : ObjectId("55098b62251497209062426a"), 378 | "name" : "kevin.ruscitti@enron.com", 379 | "idx" : NumberInt(75) 380 | } 381 | { 382 | "_id" : ObjectId("55098b62251497209062426b"), 383 | "name" : "kim.ward@enron.com", 384 | "idx" : NumberInt(76) 385 | } 386 | { 387 | "_id" : ObjectId("55098b62251497209062426c"), 388 | "name" : "kimberly.watson@enron.com", 389 | "idx" : NumberInt(77) 390 | } 391 | { 392 | "_id" : ObjectId("55098b62251497209062426d"), 393 | "name" : "larry.campbell@enron.com", 394 | "idx" : NumberInt(78) 395 | } 396 | { 397 | "_id" : ObjectId("55098b62251497209062426e"), 398 | "name" : "larry.may@enron.com", 399 | "idx" : NumberInt(79) 400 | } 401 | { 402 | "_id" : ObjectId("55098b62251497209062426f"), 403 | "name" : "lindy.donoho@enron.com", 404 | "idx" : NumberInt(80) 405 | } 406 | { 407 | "_id" : ObjectId("55098b622514972090624270"), 408 | "name" : "lisa.gang@enron.com", 409 | "idx" : NumberInt(81) 410 | } 411 | { 412 | "_id" : ObjectId("55098b622514972090624271"), 413 | "name" : "liz.taylor@enron.com", 414 | "idx" : NumberInt(82) 415 | } 416 | { 417 | "_id" : ObjectId("55098b622514972090624272"), 418 | "name" : "louise.kitchen@enron.com", 419 | "idx" : NumberInt(83) 420 | } 421 | { 422 | "_id" : ObjectId("55098b622514972090624273"), 423 | "name" : "lynn.blair@enron.com", 424 | "idx" : NumberInt(84) 425 | } 426 | { 427 | "_id" : ObjectId("55098b622514972090624274"), 428 | "name" : "m..smith@enron.com", 429 | "idx" : NumberInt(85) 430 | } 431 | { 432 | "_id" : ObjectId("55098b622514972090624275"), 433 | "name" : "marie.heard@enron.com", 434 | "idx" : NumberInt(86) 435 | } 436 | { 437 | "_id" : ObjectId("55098b622514972090624276"), 438 | "name" : "mark.e.haedicke@enron.com", 439 | "idx" : NumberInt(87) 440 | } 441 | { 442 | "_id" : ObjectId("55098b622514972090624277"), 443 | "name" : "mark.mcconnell@enron.com", 444 | "idx" : NumberInt(88) 445 | } 446 | { 447 | "_id" : ObjectId("55098b622514972090624278"), 448 | "name" : "mark.taylor@enron.com", 449 | "idx" : NumberInt(89) 450 | } 451 | { 452 | "_id" : ObjectId("55098b622514972090624279"), 453 | "name" : "mark.whitt@enron.com", 454 | "idx" : NumberInt(90) 455 | } 456 | { 457 | "_id" : ObjectId("55098b62251497209062427a"), 458 | "name" : "martin.cuilla@enron.com", 459 | "idx" : NumberInt(91) 460 | } 461 | { 462 | "_id" : ObjectId("55098b62251497209062427b"), 463 | "name" : "mary.fischer@enron.com", 464 | "idx" : NumberInt(92) 465 | } 466 | { 467 | "_id" : ObjectId("55098b62251497209062427c"), 468 | "name" : "matt.motley@enron.com", 469 | "idx" : NumberInt(93) 470 | } 471 | { 472 | "_id" : ObjectId("55098b62251497209062427d"), 473 | "name" : "matt.smith@enron.com", 474 | "idx" : NumberInt(94) 475 | } 476 | { 477 | "_id" : ObjectId("55098b62251497209062427e"), 478 | "name" : "matthew.lenhart@enron.com", 479 | "idx" : NumberInt(95) 480 | } 481 | { 482 | "_id" : ObjectId("55098b62251497209062427f"), 483 | "name" : "michele.lokay@enron.com", 484 | "idx" : NumberInt(96) 485 | } 486 | { 487 | "_id" : ObjectId("55098b622514972090624280"), 488 | "name" : "michelle.cash@enron.com", 489 | "idx" : NumberInt(97) 490 | } 491 | { 492 | "_id" : ObjectId("55098b622514972090624281"), 493 | "name" : "mike.carson@enron.com", 494 | "idx" : NumberInt(98) 495 | } 496 | { 497 | "_id" : ObjectId("55098b622514972090624282"), 498 | "name" : "mike.grigsby@enron.com", 499 | "idx" : NumberInt(99) 500 | } 501 | { 502 | "_id" : ObjectId("55098b622514972090624283"), 503 | "name" : "mike.maggi@enron.com", 504 | "idx" : NumberInt(100) 505 | } 506 | { 507 | "_id" : ObjectId("55098b622514972090624284"), 508 | "name" : "mike.mcconnell@enron.com", 509 | "idx" : NumberInt(101) 510 | } 511 | { 512 | "_id" : ObjectId("55098b622514972090624285"), 513 | "name" : "mike.swerzbin@enron.com", 514 | "idx" : NumberInt(102) 515 | } 516 | { 517 | "_id" : ObjectId("55098b622514972090624286"), 518 | "name" : "monika.causholli@enron.com", 519 | "idx" : NumberInt(103) 520 | } 521 | { 522 | "_id" : ObjectId("55098b622514972090624287"), 523 | "name" : "monique.sanchez@enron.com", 524 | "idx" : NumberInt(104) 525 | } 526 | { 527 | "_id" : ObjectId("55098b622514972090624288"), 528 | "name" : "patrice.mims@enron.com", 529 | "idx" : NumberInt(105) 530 | } 531 | { 532 | "_id" : ObjectId("55098b622514972090624289"), 533 | "name" : "paul.thomas@enron.com", 534 | "idx" : NumberInt(106) 535 | } 536 | { 537 | "_id" : ObjectId("55098b62251497209062428a"), 538 | "name" : "peter.keavey@enron.com", 539 | "idx" : NumberInt(107) 540 | } 541 | { 542 | "_id" : ObjectId("55098b62251497209062428b"), 543 | "name" : "phillip.love@enron.com", 544 | "idx" : NumberInt(108) 545 | } 546 | { 547 | "_id" : ObjectId("55098b62251497209062428c"), 548 | "name" : "phillip.platter@enron.com", 549 | "idx" : NumberInt(109) 550 | } 551 | { 552 | "_id" : ObjectId("55098b62251497209062428d"), 553 | "name" : "randall.gay@enron.com", 554 | "idx" : NumberInt(110) 555 | } 556 | { 557 | "_id" : ObjectId("55098b62251497209062428e"), 558 | "name" : "richard.ring@enron.com", 559 | "idx" : NumberInt(111) 560 | } 561 | { 562 | "_id" : ObjectId("55098b62251497209062428f"), 563 | "name" : "richard.sanders@enron.com", 564 | "idx" : NumberInt(112) 565 | } 566 | { 567 | "_id" : ObjectId("55098b622514972090624290"), 568 | "name" : "richard.shapiro@enron.com", 569 | "idx" : NumberInt(113) 570 | } 571 | { 572 | "_id" : ObjectId("55098b622514972090624291"), 573 | "name" : "rick.buy@enron.com", 574 | "idx" : NumberInt(114) 575 | } 576 | { 577 | "_id" : ObjectId("55098b622514972090624292"), 578 | "name" : "rob.gay@enron.com", 579 | "idx" : NumberInt(115) 580 | } 581 | { 582 | "_id" : ObjectId("55098b622514972090624293"), 583 | "name" : "robert.benson@enron.com", 584 | "idx" : NumberInt(116) 585 | } 586 | { 587 | "_id" : ObjectId("55098b622514972090624294"), 588 | "name" : "rod.hayslett@enron.com", 589 | "idx" : NumberInt(117) 590 | } 591 | { 592 | "_id" : ObjectId("55098b622514972090624295"), 593 | "name" : "ryan.slinger@enron.com", 594 | "idx" : NumberInt(118) 595 | } 596 | { 597 | "_id" : ObjectId("55098b622514972090624296"), 598 | "name" : "sally.beck@enron.com", 599 | "idx" : NumberInt(119) 600 | } 601 | { 602 | "_id" : ObjectId("55098b622514972090624297"), 603 | "name" : "sandra.brawner@enron.com", 604 | "idx" : NumberInt(120) 605 | } 606 | { 607 | "_id" : ObjectId("55098b622514972090624298"), 608 | "name" : "sara.shackleton@enron.com", 609 | "idx" : NumberInt(121) 610 | } 611 | { 612 | "_id" : ObjectId("55098b622514972090624299"), 613 | "name" : "scott.hendrickson@enron.com", 614 | "idx" : NumberInt(122) 615 | } 616 | { 617 | "_id" : ObjectId("55098b62251497209062429a"), 618 | "name" : "scott.neal@enron.com", 619 | "idx" : NumberInt(123) 620 | } 621 | { 622 | "_id" : ObjectId("55098b62251497209062429b"), 623 | "name" : "shelley.corman@enron.com", 624 | "idx" : NumberInt(124) 625 | } 626 | { 627 | "_id" : ObjectId("55098b62251497209062429c"), 628 | "name" : "stacy.dickson@enron.com", 629 | "idx" : NumberInt(125) 630 | } 631 | { 632 | "_id" : ObjectId("55098b62251497209062429d"), 633 | "name" : "stanley.horton@enron.com", 634 | "idx" : NumberInt(126) 635 | } 636 | { 637 | "_id" : ObjectId("55098b62251497209062429e"), 638 | "name" : "stephanie.panus@enron.com", 639 | "idx" : NumberInt(127) 640 | } 641 | { 642 | "_id" : ObjectId("55098b62251497209062429f"), 643 | "name" : "steven.kean@enron.com", 644 | "idx" : NumberInt(128) 645 | } 646 | { 647 | "_id" : ObjectId("55098b6225149720906242a0"), 648 | "name" : "steven.south@enron.com", 649 | "idx" : NumberInt(129) 650 | } 651 | { 652 | "_id" : ObjectId("55098b6225149720906242a1"), 653 | "name" : "susan.pereira@enron.com", 654 | "idx" : NumberInt(130) 655 | } 656 | { 657 | "_id" : ObjectId("55098b6225149720906242a2"), 658 | "name" : "susan.scott@enron.com", 659 | "idx" : NumberInt(131) 660 | } 661 | { 662 | "_id" : ObjectId("55098b6225149720906242a3"), 663 | "name" : "t..lucci@enron.com", 664 | "idx" : NumberInt(132) 665 | } 666 | { 667 | "_id" : ObjectId("55098b6225149720906242a4"), 668 | "name" : "tana.jones@enron.com", 669 | "idx" : NumberInt(133) 670 | } 671 | { 672 | "_id" : ObjectId("55098b6225149720906242a5"), 673 | "name" : "teb.lokey@enron.com", 674 | "idx" : NumberInt(134) 675 | } 676 | { 677 | "_id" : ObjectId("55098b6225149720906242a6"), 678 | "name" : "theresa.staab@enron.com", 679 | "idx" : NumberInt(135) 680 | } 681 | { 682 | "_id" : ObjectId("55098b6225149720906242a7"), 683 | "name" : "thomas.martin@enron.com", 684 | "idx" : NumberInt(136) 685 | } 686 | { 687 | "_id" : ObjectId("55098b6225149720906242a8"), 688 | "name" : "tom.donohoe@enron.com", 689 | "idx" : NumberInt(137) 690 | } 691 | { 692 | "_id" : ObjectId("55098b6225149720906242a9"), 693 | "name" : "tori.kuykendall@enron.com", 694 | "idx" : NumberInt(138) 695 | } 696 | { 697 | "_id" : ObjectId("55098b6225149720906242aa"), 698 | "name" : "tracy.geaccone@enron.com", 699 | "idx" : NumberInt(139) 700 | } 701 | { 702 | "_id" : ObjectId("55098b6225149720906242ab"), 703 | "name" : "vince.kaminski@enron.com", 704 | "idx" : NumberInt(140) 705 | } 706 | { 707 | "_id" : ObjectId("55098b6225149720906242ac"), 708 | "name" : "vladi.pimenov@enron.com", 709 | "idx" : NumberInt(141) 710 | } 711 | { 712 | "_id" : ObjectId("55098b6225149720906242ad"), 713 | "name" : "w..white@enron.com", 714 | "idx" : NumberInt(142) 715 | } 716 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dill==0.2.8.1 2 | networkx==2.5 3 | numpy==1.20.1 4 | scikit-learn==0.24.1 5 | scipy==1.6.1 6 | torch==1.7.0 7 | torch-cluster==1.5.8 8 | torch-geometric==1.6.1 9 | torch-scatter==2.0.5 10 | torch-sparse==0.6.8 11 | torch-spline-conv==1.2.0 12 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : train.py 4 | @Time : 2021/02/20 10:25:13 5 | @Author : Fei gao 6 | @Contact : feig@mail.bnu.edu.cn 7 | BNU, Beijing, China 8 | ''' 9 | import argparse 10 | import networkx as nx 11 | import numpy as np 12 | import dill 13 | import pickle as pkl 14 | import scipy 15 | from torch.utils.data import DataLoader 16 | 17 | from utils.preprocess import load_graphs, get_context_pairs, get_evaluation_data 18 | from utils.minibatch import MyDataset 19 | from utils.utilities import to_device 20 | from eval.link_prediction import evaluate_classifier 21 | from models.model import DySAT 22 | 23 | import torch 24 | torch.autograd.set_detect_anomaly(True) 25 | 26 | def inductive_graph(graph_former, graph_later): 27 | """Create the adj_train so that it includes nodes from (t+1) 28 | but only edges from t: this is for the purpose of inductive testing. 29 | 30 | Args: 31 | graph_former ([type]): [description] 32 | graph_later ([type]): [description] 33 | """ 34 | newG = nx.MultiGraph() 35 | newG.add_nodes_from(graph_later.nodes(data=True)) 36 | newG.add_edges_from(graph_former.edges(data=False)) 37 | return newG 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--time_steps', type=int, nargs='?', default=16, 43 | help="total time steps used for train, eval and test") 44 | # Experimental settings. 45 | parser.add_argument('--dataset', type=str, nargs='?', default='Enron', 46 | help='dataset name') 47 | parser.add_argument('--GPU_ID', type=int, nargs='?', default=0, 48 | help='GPU_ID (0/1 etc.)') 49 | parser.add_argument('--epochs', type=int, nargs='?', default=200, 50 | help='# epochs') 51 | parser.add_argument('--val_freq', type=int, nargs='?', default=1, 52 | help='Validation frequency (in epochs)') 53 | parser.add_argument('--test_freq', type=int, nargs='?', default=1, 54 | help='Testing frequency (in epochs)') 55 | parser.add_argument('--batch_size', type=int, nargs='?', default=512, 56 | help='Batch size (# nodes)') 57 | parser.add_argument('--featureless', type=bool, nargs='?', default=True, 58 | help='True if one-hot encoding.') 59 | parser.add_argument("--early_stop", type=int, default=10, 60 | help="patient") 61 | # 1-hot encoding is input as a sparse matrix - hence no scalability issue for large datasets. 62 | # Tunable hyper-params 63 | # TODO: Implementation has not been verified, performance may not be good. 64 | parser.add_argument('--residual', type=bool, nargs='?', default=True, 65 | help='Use residual') 66 | # Number of negative samples per positive pair. 67 | parser.add_argument('--neg_sample_size', type=int, nargs='?', default=10, 68 | help='# negative samples per positive') 69 | # Walk length for random walk sampling. 70 | parser.add_argument('--walk_len', type=int, nargs='?', default=20, 71 | help='Walk length for random walk sampling') 72 | # Weight for negative samples in the binary cross-entropy loss function. 73 | parser.add_argument('--neg_weight', type=float, nargs='?', default=1.0, 74 | help='Weightage for negative samples') 75 | parser.add_argument('--learning_rate', type=float, nargs='?', default=0.01, 76 | help='Initial learning rate for self-attention model.') 77 | parser.add_argument('--spatial_drop', type=float, nargs='?', default=0.1, 78 | help='Spatial (structural) attention Dropout (1 - keep probability).') 79 | parser.add_argument('--temporal_drop', type=float, nargs='?', default=0.5, 80 | help='Temporal attention Dropout (1 - keep probability).') 81 | parser.add_argument('--weight_decay', type=float, nargs='?', default=0.0005, 82 | help='Initial learning rate for self-attention model.') 83 | # Architecture params 84 | parser.add_argument('--structural_head_config', type=str, nargs='?', default='16,8,8', 85 | help='Encoder layer config: # attention heads in each GAT layer') 86 | parser.add_argument('--structural_layer_config', type=str, nargs='?', default='128', 87 | help='Encoder layer config: # units in each GAT layer') 88 | parser.add_argument('--temporal_head_config', type=str, nargs='?', default='16', 89 | help='Encoder layer config: # attention heads in each Temporal layer') 90 | parser.add_argument('--temporal_layer_config', type=str, nargs='?', default='128', 91 | help='Encoder layer config: # units in each Temporal layer') 92 | parser.add_argument('--position_ffn', type=str, nargs='?', default='True', 93 | help='Position wise feedforward') 94 | parser.add_argument('--window', type=int, nargs='?', default=-1, 95 | help='Window for temporal attention (default : -1 => full)') 96 | args = parser.parse_args() 97 | print(args) 98 | 99 | #graphs, feats, adjs = load_graphs(args.dataset) 100 | graphs, adjs = load_graphs(args.dataset) 101 | if args.featureless == True: 102 | feats = [scipy.sparse.identity(adjs[args.time_steps - 1].shape[0]).tocsr()[range(0, x.shape[0]), :] for x in adjs if 103 | x.shape[0] <= adjs[args.time_steps - 1].shape[0]] 104 | 105 | assert args.time_steps <= len(adjs), "Time steps is illegal" 106 | 107 | context_pairs_train = get_context_pairs(graphs, adjs) 108 | 109 | # Load evaluation data for link prediction. 110 | train_edges_pos, train_edges_neg, val_edges_pos, val_edges_neg, \ 111 | test_edges_pos, test_edges_neg = get_evaluation_data(graphs) 112 | print("No. Train: Pos={}, Neg={} \nNo. Val: Pos={}, Neg={} \nNo. Test: Pos={}, Neg={}".format( 113 | len(train_edges_pos), len(train_edges_neg), len(val_edges_pos), len(val_edges_neg), 114 | len(test_edges_pos), len(test_edges_neg))) 115 | 116 | # Create the adj_train so that it includes nodes from (t+1) but only edges from t: this is for the purpose of 117 | # inductive testing. 118 | new_G = inductive_graph(graphs[args.time_steps-2], graphs[args.time_steps-1]) 119 | graphs[args.time_steps-1] = new_G 120 | adjs[args.time_steps-1] = nx.adjacency_matrix(new_G) 121 | 122 | # build dataloader and model 123 | device = torch.device("cuda:0") 124 | dataset = MyDataset(args, graphs, feats, adjs, context_pairs_train) 125 | dataloader = DataLoader(dataset, 126 | batch_size=args.batch_size, 127 | shuffle=True, 128 | num_workers=10, 129 | collate_fn=MyDataset.collate_fn) 130 | #dataloader = NodeMinibatchIterator(args, graphs, feats, adjs, context_pairs_train, device) 131 | model = DySAT(args, feats[0].shape[1], args.time_steps).to(device) 132 | opt = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 133 | 134 | # in training 135 | best_epoch_val = 0 136 | patient = 0 137 | for epoch in range(args.epochs): 138 | model.train() 139 | epoch_loss = [] 140 | for idx, feed_dict in enumerate(dataloader): 141 | feed_dict = to_device(feed_dict, device) 142 | opt.zero_grad() 143 | loss = model.get_loss(feed_dict) 144 | loss.backward() 145 | opt.step() 146 | epoch_loss.append(loss.item()) 147 | 148 | model.eval() 149 | emb = model(feed_dict["graphs"])[:, -2, :].detach().cpu().numpy() 150 | val_results, test_results, _, _ = evaluate_classifier(train_edges_pos, 151 | train_edges_neg, 152 | val_edges_pos, 153 | val_edges_neg, 154 | test_edges_pos, 155 | test_edges_neg, 156 | emb, 157 | emb) 158 | epoch_auc_val = val_results["HAD"][1] 159 | epoch_auc_test = test_results["HAD"][1] 160 | 161 | if epoch_auc_val > best_epoch_val: 162 | best_epoch_val = epoch_auc_val 163 | torch.save(model.state_dict(), "./model_checkpoints/model.pt") 164 | patient = 0 165 | else: 166 | patient += 1 167 | if patient > args.early_stop: 168 | break 169 | 170 | print("Epoch {:<3}, Loss = {:.3f}, Val AUC {:.3f} Test AUC {:.3f}".format(epoch, 171 | np.mean(epoch_loss), 172 | epoch_auc_val, 173 | epoch_auc_test)) 174 | # Test Best Model 175 | model.load_state_dict(torch.load("./model_checkpoints/model.pt")) 176 | model.eval() 177 | emb = model(feed_dict["graphs"])[:, -2, :].detach().cpu().numpy() 178 | val_results, test_results, _, _ = evaluate_classifier(train_edges_pos, 179 | train_edges_neg, 180 | val_edges_pos, 181 | val_edges_neg, 182 | test_edges_pos, 183 | test_edges_neg, 184 | emb, 185 | emb) 186 | auc_val = val_results["HAD"][1] 187 | auc_test = test_results["HAD"][1] 188 | print("Best Test AUC = {:.3f}".format(auc_test)) 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FeiGSSS/DySAT_pytorch/6230d7dcfe83e3f2f85b8b771dc8f57e261004e6/utils/__init__.py -------------------------------------------------------------------------------- /utils/minibatch.py: -------------------------------------------------------------------------------- 1 | from typing import DefaultDict 2 | from collections import defaultdict 3 | from torch.functional import Tensor 4 | from torch_geometric.data import Data 5 | from utils.utilities import fixed_unigram_candidate_sampler 6 | import torch 7 | import numpy as np 8 | import torch_geometric as tg 9 | import scipy.sparse as sp 10 | 11 | 12 | import torch 13 | from torch.utils.data import Dataset 14 | 15 | class MyDataset(Dataset): 16 | def __init__(self, args, graphs, features, adjs, context_pairs): 17 | super(MyDataset, self).__init__() 18 | self.args = args 19 | self.graphs = graphs 20 | self.features = [self._preprocess_features(feat) for feat in features] 21 | self.adjs = [self._normalize_graph_gcn(a) for a in adjs] 22 | self.time_steps = args.time_steps 23 | self.context_pairs = context_pairs 24 | self.max_positive = args.neg_sample_size 25 | self.train_nodes = list(self.graphs[self.time_steps-1].nodes()) # all nodes in the graph. 26 | self.min_t = max(self.time_steps - self.args.window - 1, 0) if args.window > 0 else 0 27 | self.degs = self.construct_degs() 28 | self.pyg_graphs = self._build_pyg_graphs() 29 | self.__createitems__() 30 | 31 | def _normalize_graph_gcn(self, adj): 32 | """GCN-based normalization of adjacency matrix (scipy sparse format). Output is in tuple format""" 33 | adj = sp.coo_matrix(adj, dtype=np.float32) 34 | adj_ = adj + sp.eye(adj.shape[0], dtype=np.float32) 35 | rowsum = np.array(adj_.sum(1), dtype=np.float32) 36 | degree_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5).flatten(), dtype=np.float32) 37 | adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt).tocoo() 38 | return adj_normalized 39 | 40 | def _preprocess_features(self, features): 41 | """Row-normalize feature matrix and convert to tuple representation""" 42 | features = np.array(features.todense()) 43 | rowsum = np.array(features.sum(1)) 44 | r_inv = np.power(rowsum, -1).flatten() 45 | r_inv[np.isinf(r_inv)] = 0. 46 | r_mat_inv = sp.diags(r_inv) 47 | features = r_mat_inv.dot(features) 48 | return features 49 | 50 | def construct_degs(self): 51 | """ Compute node degrees in each graph snapshot.""" 52 | # different from the original implementation 53 | # degree is counted using multi graph 54 | degs = [] 55 | for i in range(self.min_t, self.time_steps): 56 | G = self.graphs[i] 57 | deg = [] 58 | for nodeid in G.nodes(): 59 | deg.append(G.degree(nodeid)) 60 | degs.append(deg) 61 | return degs 62 | 63 | def _build_pyg_graphs(self): 64 | pyg_graphs = [] 65 | for feat, adj in zip(self.features, self.adjs): 66 | x = torch.Tensor(feat) 67 | edge_index, edge_weight = tg.utils.from_scipy_sparse_matrix(adj) 68 | data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight) 69 | pyg_graphs.append(data) 70 | return pyg_graphs 71 | 72 | def __len__(self): 73 | return len(self.train_nodes) 74 | 75 | def __getitem__(self, index): 76 | node = self.train_nodes[index] 77 | return self.data_items[node] 78 | 79 | def __createitems__(self): 80 | self.data_items = {} 81 | for node in list(self.graphs[self.time_steps-1].nodes()): 82 | feed_dict = {} 83 | node_1_all_time = [] 84 | node_2_all_time = [] 85 | for t in range(self.min_t, self.time_steps): 86 | node_1 = [] 87 | node_2 = [] 88 | if len(self.context_pairs[t][node]) > self.max_positive: 89 | node_1.extend([node]* self.max_positive) 90 | node_2.extend(np.random.choice(self.context_pairs[t][node], self.max_positive, replace=False)) 91 | else: 92 | node_1.extend([node]* len(self.context_pairs[t][node])) 93 | node_2.extend(self.context_pairs[t][node]) 94 | assert len(node_1) == len(node_2) 95 | node_1_all_time.append(node_1) 96 | node_2_all_time.append(node_2) 97 | 98 | node_1_list = [torch.LongTensor(node) for node in node_1_all_time] 99 | node_2_list = [torch.LongTensor(node) for node in node_2_all_time] 100 | node_2_negative = [] 101 | for t in range(len(node_2_list)): 102 | degree = self.degs[t] 103 | node_positive = node_2_list[t][:, None] 104 | node_negative = fixed_unigram_candidate_sampler(true_clasees=node_positive, 105 | num_true=1, 106 | num_sampled=self.args.neg_sample_size, 107 | unique=False, 108 | distortion=0.75, 109 | unigrams=degree) 110 | node_2_negative.append(node_negative) 111 | node_2_neg_list = [torch.LongTensor(node) for node in node_2_negative] 112 | feed_dict['node_1']=node_1_list 113 | feed_dict['node_2']=node_2_list 114 | feed_dict['node_2_neg']=node_2_neg_list 115 | feed_dict["graphs"] = self.pyg_graphs 116 | 117 | self.data_items[node] = feed_dict 118 | 119 | @staticmethod 120 | def collate_fn(samples): 121 | batch_dict = {} 122 | for key in ["node_1", "node_2", "node_2_neg"]: 123 | data_list = [] 124 | for sample in samples: 125 | data_list.append(sample[key]) 126 | concate = [] 127 | for t in range(len(data_list[0])): 128 | concate.append(torch.cat([data[t] for data in data_list])) 129 | batch_dict[key] = concate 130 | batch_dict["graphs"] = samples[0]["graphs"] 131 | return batch_dict 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import dill 4 | import pickle as pkl 5 | import networkx as nx 6 | import scipy.sparse as sp 7 | 8 | from sklearn.model_selection import train_test_split 9 | from utils.utilities import run_random_walks_n2v 10 | 11 | np.random.seed(123) 12 | 13 | def load_graphs(dataset_str): 14 | """Load graph snapshots given the name of dataset""" 15 | with open("data/{}/{}".format(dataset_str, "graph.pkl"), "rb") as f: 16 | graphs = pkl.load(f) 17 | print("Loaded {} graphs ".format(len(graphs))) 18 | adjs = [nx.adjacency_matrix(g) for g in graphs] 19 | return graphs, adjs 20 | 21 | def get_context_pairs(graphs, adjs): 22 | """ Load/generate context pairs for each snapshot through random walk sampling.""" 23 | print("Computing training pairs ...") 24 | context_pairs_train = [] 25 | for i in range(len(graphs)): 26 | context_pairs_train.append(run_random_walks_n2v(graphs[i], adjs[i], num_walks=10, walk_len=20)) 27 | 28 | return context_pairs_train 29 | 30 | def get_evaluation_data(graphs): 31 | """ Load train/val/test examples to evaluate link prediction performance""" 32 | eval_idx = len(graphs) - 2 33 | eval_graph = graphs[eval_idx] 34 | next_graph = graphs[eval_idx+1] 35 | print("Generating eval data ....") 36 | train_edges, train_edges_false, val_edges, val_edges_false, test_edges, test_edges_false = \ 37 | create_data_splits(eval_graph, next_graph, val_mask_fraction=0.2, 38 | test_mask_fraction=0.6) 39 | 40 | return train_edges, train_edges_false, val_edges, val_edges_false, test_edges, test_edges_false 41 | 42 | def create_data_splits(graph, next_graph, val_mask_fraction=0.2, test_mask_fraction=0.6): 43 | edges_next = np.array(list(nx.Graph(next_graph).edges())) 44 | edges_positive = [] # Constraint to restrict new links to existing nodes. 45 | for e in edges_next: 46 | if graph.has_node(e[0]) and graph.has_node(e[1]): 47 | edges_positive.append(e) 48 | edges_positive = np.array(edges_positive) # [E, 2] 49 | edges_negative = negative_sample(edges_positive, graph.number_of_nodes(), next_graph) 50 | 51 | 52 | train_edges_pos, test_pos, train_edges_neg, test_neg = train_test_split(edges_positive, 53 | edges_negative, test_size=val_mask_fraction+test_mask_fraction) 54 | val_edges_pos, test_edges_pos, val_edges_neg, test_edges_neg = train_test_split(test_pos, 55 | test_neg, test_size=test_mask_fraction/(test_mask_fraction+val_mask_fraction)) 56 | 57 | return train_edges_pos, train_edges_neg, val_edges_pos, val_edges_neg, test_edges_pos, test_edges_neg 58 | 59 | def negative_sample(edges_pos, nodes_num, next_graph): 60 | edges_neg = [] 61 | while len(edges_neg) < len(edges_pos): 62 | idx_i = np.random.randint(0, nodes_num) 63 | idx_j = np.random.randint(0, nodes_num) 64 | if idx_i == idx_j: 65 | continue 66 | if next_graph.has_edge(idx_i, idx_j) or next_graph.has_edge(idx_j, idx_i): 67 | continue 68 | if edges_neg: 69 | if [idx_i, idx_j] in edges_neg or [idx_j, idx_i] in edges_neg: 70 | continue 71 | edges_neg.append([idx_i, idx_j]) 72 | return edges_neg -------------------------------------------------------------------------------- /utils/random_walk.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | import random 4 | 5 | # DISCLAIMER: 6 | # Parts of this code file are derived from 7 | # https://github.com/aditya-grover/node2vec 8 | 9 | '''Random walk sampling code''' 10 | 11 | class Graph_RandomWalk(): 12 | def __init__(self, nx_G, is_directed, p, q): 13 | self.G = nx_G 14 | self.is_directed = is_directed 15 | self.p = p 16 | self.q = q 17 | 18 | def node2vec_walk(self, walk_length, start_node): 19 | ''' 20 | Simulate a random walk starting from start node. 21 | ''' 22 | G = self.G 23 | alias_nodes = self.alias_nodes 24 | alias_edges = self.alias_edges 25 | 26 | walk = [start_node] 27 | 28 | while len(walk) < walk_length: 29 | cur = walk[-1] 30 | cur_nbrs = sorted(G.neighbors(cur)) 31 | if len(cur_nbrs) > 0: 32 | if len(walk) == 1: 33 | walk.append(cur_nbrs[alias_draw(alias_nodes[cur][0], alias_nodes[cur][1])]) 34 | else: 35 | prev = walk[-2] 36 | next = cur_nbrs[alias_draw(alias_edges[(prev, cur)][0], 37 | alias_edges[(prev, cur)][1])] 38 | walk.append(next) 39 | else: 40 | break 41 | return walk 42 | 43 | def simulate_walks(self, num_walks, walk_length): 44 | ''' 45 | Repeatedly simulate random walks from each node. 46 | ''' 47 | G = self.G 48 | walks = [] 49 | nodes = list(G.nodes()) 50 | for walk_iter in range(num_walks): 51 | random.shuffle(nodes) 52 | for node in nodes: 53 | walks.append(self.node2vec_walk(walk_length=walk_length, start_node=node)) 54 | 55 | return walks 56 | 57 | def get_alias_edge(self, src, dst): 58 | ''' 59 | Get the alias edge setup lists for a given edge. 60 | ''' 61 | G = self.G 62 | p = self.p 63 | q = self.q 64 | 65 | unnormalized_probs = [] 66 | for dst_nbr in sorted(G.neighbors(dst)): 67 | if dst_nbr == src: 68 | unnormalized_probs.append(G[dst][dst_nbr]['weight']/p) 69 | elif G.has_edge(dst_nbr, src): 70 | unnormalized_probs.append(G[dst][dst_nbr]['weight']) 71 | else: 72 | unnormalized_probs.append(G[dst][dst_nbr]['weight']/q) 73 | norm_const = sum(unnormalized_probs) 74 | normalized_probs = [float(u_prob)/norm_const for u_prob in unnormalized_probs] 75 | 76 | return alias_setup(normalized_probs) 77 | 78 | def preprocess_transition_probs(self): 79 | ''' 80 | Preprocessing of transition probabilities for guiding the random walks. 81 | ''' 82 | G = self.G 83 | is_directed = self.is_directed 84 | 85 | alias_nodes = {} 86 | for node in G.nodes(): 87 | unnormalized_probs = [G[node][nbr]['weight'] for nbr in sorted(G.neighbors(node))] 88 | norm_const = sum(unnormalized_probs) 89 | normalized_probs = [float(u_prob)/norm_const for u_prob in unnormalized_probs] 90 | alias_nodes[node] = alias_setup(normalized_probs) 91 | 92 | alias_edges = {} 93 | triads = {} 94 | 95 | if is_directed: 96 | for edge in G.edges(): 97 | alias_edges[edge] = self.get_alias_edge(edge[0], edge[1]) 98 | else: 99 | for edge in G.edges(): 100 | alias_edges[edge] = self.get_alias_edge(edge[0], edge[1]) 101 | alias_edges[(edge[1], edge[0])] = self.get_alias_edge(edge[1], edge[0]) 102 | 103 | self.alias_nodes = alias_nodes 104 | self.alias_edges = alias_edges 105 | 106 | return 107 | 108 | 109 | def alias_setup(probs): 110 | ''' 111 | Compute utility lists for non-uniform sampling from discrete distributions. 112 | Refer to https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 113 | for details 114 | ''' 115 | K = len(probs) 116 | q = np.zeros(K) 117 | J = np.zeros(K, dtype=np.int) 118 | 119 | smaller = [] 120 | larger = [] 121 | for kk, prob in enumerate(probs): 122 | q[kk] = K*prob 123 | if q[kk] < 1.0: 124 | smaller.append(kk) 125 | else: 126 | larger.append(kk) 127 | 128 | while len(smaller) > 0 and len(larger) > 0: 129 | small = smaller.pop() 130 | large = larger.pop() 131 | 132 | J[small] = large 133 | q[large] = q[large] + q[small] - 1.0 134 | if q[large] < 1.0: 135 | smaller.append(large) 136 | else: 137 | larger.append(large) 138 | 139 | return J, q 140 | 141 | def alias_draw(J, q): 142 | ''' 143 | Draw sample from a non-uniform discrete distribution using alias sampling. 144 | ''' 145 | K = len(J) 146 | 147 | kk = int(np.floor(np.random.rand()*K)) 148 | if np.random.rand() < q[kk]: 149 | return kk 150 | else: 151 | return J[kk] 152 | -------------------------------------------------------------------------------- /utils/utilities.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import copy 4 | import networkx as nx 5 | from collections import defaultdict 6 | from sklearn.preprocessing import MultiLabelBinarizer 7 | from utils.random_walk import Graph_RandomWalk 8 | 9 | import torch 10 | 11 | 12 | """Random walk-based pair generation.""" 13 | 14 | def run_random_walks_n2v(graph, adj, num_walks, walk_len): 15 | """ In: Graph and list of nodes 16 | Out: (target, context) pairs from random walk sampling using 17 | the sampling strategy of node2vec (deepwalk)""" 18 | nx_G = nx.Graph() 19 | for e in graph.edges(): 20 | nx_G.add_edge(e[0], e[1]) 21 | for edge in graph.edges(): 22 | nx_G[edge[0]][edge[1]]['weight'] = adj[edge[0], edge[1]] 23 | 24 | G = Graph_RandomWalk(nx_G, False, 1.0, 1.0) 25 | G.preprocess_transition_probs() 26 | walks = G.simulate_walks(num_walks, walk_len) 27 | WINDOW_SIZE = 10 28 | pairs = defaultdict(list) 29 | pairs_cnt = 0 30 | for walk in walks: 31 | for word_index, word in enumerate(walk): 32 | for nb_word in walk[max(word_index - WINDOW_SIZE, 0): min(word_index + WINDOW_SIZE, len(walk)) + 1]: 33 | if nb_word != word: 34 | pairs[word].append(nb_word) 35 | pairs_cnt += 1 36 | print("# nodes with random walk samples: {}".format(len(pairs))) 37 | print("# sampled pairs: {}".format(pairs_cnt)) 38 | return pairs 39 | 40 | def fixed_unigram_candidate_sampler(true_clasees, 41 | num_true, 42 | num_sampled, 43 | unique, 44 | distortion, 45 | unigrams): 46 | # TODO: implementate distortion to unigrams 47 | assert true_clasees.shape[1] == num_true 48 | samples = [] 49 | for i in range(true_clasees.shape[0]): 50 | dist = copy.deepcopy(unigrams) 51 | candidate = list(range(len(dist))) 52 | taboo = true_clasees[i].cpu().tolist() 53 | for tabo in sorted(taboo, reverse=True): 54 | candidate.remove(tabo) 55 | dist.pop(tabo) 56 | sample = np.random.choice(candidate, size=num_sampled, replace=unique, p=dist/np.sum(dist)) 57 | samples.append(sample) 58 | return samples 59 | 60 | def to_device(batch, device): 61 | feed_dict = copy.deepcopy(batch) 62 | node_1, node_2, node_2_negative, graphs = feed_dict.values() 63 | # to device 64 | feed_dict["node_1"] = [x.to(device) for x in node_1] 65 | feed_dict["node_2"] = [x.to(device) for x in node_2] 66 | feed_dict["node_2_neg"] = [x.to(device) for x in node_2_negative] 67 | feed_dict["graphs"] = [g.to(device) for g in graphs] 68 | 69 | return feed_dict 70 | 71 | 72 | 73 | 74 | 75 | 76 | --------------------------------------------------------------------------------