├── LICENSE ├── README.md ├── karate.edgelist └── link_prediction.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Andrew Docherty 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # node2vec_linkprediction 2 | Testing link prediction using Node2Vec 3 | 4 | Installation 5 | ============ 6 | 7 | Requirements: 8 | ------------- 9 | * git 10 | * Python 2.7 11 | * gensim 12 | * networkx 13 | * numpy 14 | * matplotlib 15 | * scikit-learn 16 | * node2vec 17 | 18 | To install using Anaconda: 19 | -------------------------- 20 | 21 | 1) To install on Mac OS or Linux, download and install Anaconda (2 or 3) from the following website: 22 | https://www.continuum.io/downloads 23 | 24 | 2) At a command prompt, create a python 2.7 environment and install required packages: 25 | 26 | conda create -n py27 python=2.7 numpy ipython matplotlib seaborn networkx gensim scikit-learn 27 | 28 | 3) Switch to this environment: 29 | 30 | source activate py27 31 | 32 | 4) Get node2vec python code: 33 | 34 | git clone https://github.com/aditya-grover/node2vec.git 35 | 36 | 5) Copy node2vec.py to the link prediction code directory (run this from that directory): 37 | 38 | cp node2vec/src/node2vec.py . 39 | 40 | Usage 41 | ===== 42 | 43 | To use the link_prediction code, we assume the graph data is saved as an edgelist, with each pair of connected nodes on a separate line: 44 | 45 | Example edgelist: 46 | 1 2 47 | 3 4 48 | 4 2 49 | 50 | A task must be specified, which is one of: 51 | 52 | * *edgeencoding*: Test the node2vec embedding using different edge functions, and analyse their performance. 53 | 54 | * *sensitivity*: Run a parameter sensitivity test on the node2vec parameters of q, p, r, l, d, and k. 55 | 56 | * *gridsearch*: Run a grid search on the node2vec parameters of q, p. 
57 | 58 | For example, to test the edge encodings for the graph AstroPh.edgelist, with averaging over five random walk samplings in node2vec: 59 | 60 | python link_prediction.py edgeembedding --input AstroPh.edgelist --num_experiments 5 61 | 62 | For help on the options, use: 63 | 64 | python link_prediction.py --help 65 | 66 | The default values for the experiments and parameter search settings are in the code link_prediction.py. 67 | -------------------------------------------------------------------------------- /karate.edgelist: -------------------------------------------------------------------------------- 1 | 1 32 2 | 1 22 3 | 1 20 4 | 1 18 5 | 1 14 6 | 1 13 7 | 1 12 8 | 1 11 9 | 1 9 10 | 1 8 11 | 1 7 12 | 1 6 13 | 1 5 14 | 1 4 15 | 1 3 16 | 1 2 17 | 2 31 18 | 2 22 19 | 2 20 20 | 2 18 21 | 2 14 22 | 2 8 23 | 2 4 24 | 2 3 25 | 3 14 26 | 3 9 27 | 3 10 28 | 3 33 29 | 3 29 30 | 3 28 31 | 3 8 32 | 3 4 33 | 4 14 34 | 4 13 35 | 4 8 36 | 5 11 37 | 5 7 38 | 6 17 39 | 6 11 40 | 6 7 41 | 7 17 42 | 9 34 43 | 9 33 44 | 9 33 45 | 10 34 46 | 14 34 47 | 15 34 48 | 15 33 49 | 16 34 50 | 16 33 51 | 19 34 52 | 19 33 53 | 20 34 54 | 21 34 55 | 21 33 56 | 23 34 57 | 23 33 58 | 24 30 59 | 24 34 60 | 24 33 61 | 24 28 62 | 24 26 63 | 25 32 64 | 25 28 65 | 25 26 66 | 26 32 67 | 27 34 68 | 27 30 69 | 28 34 70 | 29 34 71 | 29 32 72 | 30 34 73 | 30 33 74 | 31 34 75 | 31 33 76 | 32 34 77 | 32 33 78 | 33 34 -------------------------------------------------------------------------------- /link_prediction.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | from __future__ import print_function, division 4 | 5 | import pickle 6 | import argparse 7 | import os 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import networkx as nx 11 | import node2vec 12 | from gensim.models import Word2Vec 13 | from sklearn import metrics, model_selection, pipeline 14 | from sklearn.linear_model import LogisticRegression 15 | from 
from sklearn.preprocessing import StandardScaler

# Default parameters from node2vec paper (and for DeepWalk)
default_params = {
    'log2p': 0,                     # Parameter p, p = 2**log2p
    'log2q': 0,                     # Parameter q, q = 2**log2q
    'log2d': 7,                     # Feature size, dimensions = 2**log2d
    'num_walks': 10,                # Number of walks from each node
    'walk_length': 80,              # Walk length
    'window_size': 10,              # Context size for word2vec
    'edge_function': "hadamard",    # Default edge function to use
    "prop_pos": 0.5,                # Proportion of edges to remove and use as positive samples
    "prop_neg": 0.5,                # Number of non-edges to use as negative samples
                                    # (as a proportion of existing edges, same as prop_pos)
}

# For each searchable parameter: (values to try, axis label for plots).
# Raw strings keep the LaTeX backslash from being read as a string escape.
parameter_searches = {
    'log2p': (np.arange(-2, 3), r'$\log_2 p$'),
    'log2q': (np.arange(-2, 3), r'$\log_2 q$'),
    'log2d': (np.arange(4, 9), r'$\log_2 d$'),
    'num_walks': (np.arange(6, 21, 2), 'Number of walks, r'),
    'walk_length': (np.arange(40, 101, 10), 'Walk length, l'),
    'window_size': (np.arange(8, 21, 2), 'Context size, k'),
}

# Binary operators that combine two node embeddings into one edge feature.
edge_functions = {
    "hadamard": lambda a, b: a * b,
    "average": lambda a, b: 0.5 * (a + b),
    "l1": lambda a, b: np.abs(a - b),
    "l2": lambda a, b: np.abs(a - b) ** 2,
}


def parse_args():
    '''
    Parse the command-line arguments of the link prediction script.

    :return:
        The argparse namespace with attributes task, input, regen, iter,
        workers, num_experiments, weighted and directed.
    '''
    parser = argparse.ArgumentParser(description="Run node2vec.")

    parser.add_argument('task', type=str,
                        help="Task to run, one of 'gridsearch', 'edgeencoding', and 'sensitivity'")

    parser.add_argument('--input', nargs='?', default='karate.edgelist',
                        help='Input graph path')

    parser.add_argument('--regen', dest='regen', action='store_true',
                        help='Regenerate random positive/negative links')

    # NOTE(review): nothing in this module reads args.iter; learn_embeddings
    # uses its own niter default. Confirm whether it should be forwarded.
    parser.add_argument('--iter', default=1, type=int,
                        help='Number of epochs in SGD')

    parser.add_argument('--workers', type=int, default=8,
                        help='Number of parallel workers. Default is 8.')

    parser.add_argument('--num_experiments', type=int, default=5,
                        help='Number of experiments to average. Default is 5.')

    # Each on/off pair must share one dest, otherwise the "off" switch
    # writes to an attribute nobody reads and has no effect.
    parser.add_argument('--weighted', dest='weighted', action='store_true',
                        help='Boolean specifying (un)weighted. Default is unweighted.')
    parser.add_argument('--unweighted', dest='weighted', action='store_false')
    parser.set_defaults(weighted=False)

    parser.add_argument('--directed', dest='directed', action='store_true',
                        help='Graph is (un)directed. Default is undirected.')
    parser.add_argument('--undirected', dest='directed', action='store_false')
    parser.set_defaults(directed=False)

    return parser.parse_args()


class GraphN2V(node2vec.Graph):
    """
    node2vec graph that can hold out edges for link prediction.

    On top of node2vec.Graph it reads an edge list, selects positive
    (removed) and negative (non-existent) links, trains word2vec node
    embeddings and converts node pairs into edge feature vectors.
    """

    def __init__(self,
                 nx_G=None, is_directed=False,
                 prop_pos=0.5, prop_neg=0.5,
                 workers=1,
                 random_seed=None):
        """
        :param nx_G:
            Optional networkx graph; read_graph can set it later.
        :param is_directed:
            Stored as self.is_directed (presumably read by the inherited
            node2vec.Graph walk code, which is not shown in this file).
        :param prop_pos:
            Proportion of existing edges to remove as positive samples.
        :param prop_neg:
            Number of non-edges to sample as negatives, as a proportion
            of the number of existing edges.
        :param workers:
            Number of word2vec worker threads.
        :param random_seed:
            Seed for the RandomState used to select links.
        """
        self.G = nx_G
        self.is_directed = is_directed
        self.prop_pos = prop_pos
        self.prop_neg = prop_neg
        self.wvecs = None
        self.workers = workers
        self._rnd = np.random.RandomState(seed=random_seed)

    def read_graph(self, input, enforce_connectivity=True, weighted=False, directed=False):
        '''
        Read the input edge list into networkx and store it in self.G.

        :param input:
            Path of the edge list file.
        :param enforce_connectivity:
            If True and the graph is not connected, keep only its largest
            connected component.
        :param weighted:
            If True read a float weight from each line, otherwise give
            every edge weight 1.
        :param directed:
            If False the graph is converted to an undirected graph.
        '''
        if weighted:
            G = nx.read_edgelist(input, nodetype=int, data=(('weight', float),),
                                 create_using=nx.DiGraph())
        else:
            G = nx.read_edgelist(input, nodetype=int, create_using=nx.DiGraph())
            # G[u][v] is the edge data dict in every networkx version,
            # unlike the G.edge attribute which only exists in 1.x.
            for u, v in G.edges():
                G[u][v]['weight'] = 1

        if not directed:
            G = G.to_undirected()

        # Take largest connected subgraph
        # NOTE(review): nx.is_connected is only defined for undirected
        # graphs, so directed=True will likely fail here - confirm.
        if enforce_connectivity and not nx.is_connected(G):
            largest = max(nx.connected_components(G), key=len)
            G = G.subgraph(largest).copy()
            print("Input graph not connected: using largest connected subgraph")

        # Remove self-edges; the node list is built first so the graph is
        # not modified while it is being iterated.
        for node in [n for n in G.nodes() if G.has_edge(n, n)]:
            G.remove_edge(node, node)

        print("Read graph, nodes: %d, edges: %d" % (G.number_of_nodes(), G.number_of_edges()))
        self.G = G

    def learn_embeddings(self, walks, dimensions, window_size=10, niter=5):
        '''
        Learn embeddings by optimizing the Skipgram objective using SGD.

        The resulting keyed vectors are stored in self.wvecs.

        :param walks:
            Iterable of walks, each a sequence of node ids.
        :param dimensions:
            Size of the embedding vectors.
        :param window_size:
            word2vec context window.
        :param niter:
            Number of training epochs.
        '''
        # word2vec wants sentences of strings; a list comprehension (rather
        # than map) yields real lists on both Python 2 and Python 3.
        walks = [[str(node) for node in walk] for walk in walks]
        # NOTE(review): the size= and iter= keywords are the ones this file
        # was written against; newer gensim releases may name them
        # differently - confirm against the installed version.
        model = Word2Vec(walks,
                         size=dimensions,
                         window=window_size,
                         min_count=0,
                         sg=1,
                         workers=self.workers,
                         iter=niter)
        self.wvecs = model.wv

    def generate_pos_neg_links(self):
        """
        Select random existing edges in the graph to be positive links,
        and random non-edges to be negative links.

        Modifies the graph by removing the positive links. An edge is only
        removed if its end points remain connected afterwards. The result
        is stored in self._pos_edge_list and self._neg_edge_list.

        :raises RuntimeError: if the graph is not connected.
        :raises RuntimeWarning: if fewer positive or negative links than
            requested are available.
        """
        n_edges = self.G.number_of_edges()
        npos = int(self.prop_pos * n_edges)
        nneg = int(self.prop_neg * n_edges)

        if not nx.is_connected(self.G):
            raise RuntimeError("Input graph is not connected")

        non_edges = list(nx.non_edges(self.G))
        print("Finding %d of %d non-edges" % (nneg, len(non_edges)))

        # Check before sampling: choice(..., replace=False) would raise its
        # own ValueError first and make this check unreachable.
        if len(non_edges) < nneg:
            raise RuntimeWarning(
                "Only %d negative edges found" % (len(non_edges))
            )

        # Select non-edges without replacement (negative samples)
        rnd_inx = self._rnd.choice(len(non_edges), nneg, replace=False)
        neg_edge_list = [non_edges[ii] for ii in rnd_inx]

        print("Finding %d positive edges of %d total edges" % (npos, n_edges))

        # Find positive edges, and remove them. The edges are copied into a
        # list because they are indexed by position and the graph changes
        # while we go through them.
        edges = list(self.G.edges())
        pos_edge_list = []
        for eii in self._rnd.permutation(n_edges):
            # Stop once enough edges are found (also covers npos == 0)
            if len(pos_edge_list) >= npos:
                break

            u, v = edges[eii]

            # Keep a copy of the attributes so the edge can be restored
            data = dict(self.G[u][v])
            self.G.remove_edge(u, v)

            # Only accept the removal if the two end points are still
            # connected to each other through some other path.
            if nx.has_path(self.G, u, v):
                pos_edge_list.append((u, v))
                print("Found: %d    " % (len(pos_edge_list)), end="\r")
            else:
                self.G.add_edge(u, v, **data)

        if len(pos_edge_list) < npos:
            raise RuntimeWarning("Only %d positive edges found." % (len(pos_edge_list)))

        self._pos_edge_list = pos_edge_list
        self._neg_edge_list = neg_edge_list

    def get_selected_edges(self):
        """
        Return the selected links with their labels.

        :return:
            edges, labels: positive links followed by negative links, and
            an array holding 1 for each positive and 0 for each negative.
        """
        edges = self._pos_edge_list + self._neg_edge_list
        labels = np.zeros(len(edges))
        labels[:len(self._pos_edge_list)] = 1
        return edges, labels

    def train_embeddings(self, p, q, dimensions, num_walks, walk_length, window_size):
        """
        Calculate node embeddings with the specified parameters.

        Sets self.p and self.q, then calls preprocess_transition_probs and
        simulate_walks (both inherited from node2vec.Graph, not shown in
        this file) and trains word2vec on the resulting walks.

        :param p: node2vec parameter p
        :param q: node2vec parameter q
        :param dimensions: embedding size
        :param num_walks: number of walks passed to simulate_walks
        :param walk_length: walk length passed to simulate_walks
        :param window_size: word2vec context size
        :return: None; the embeddings are stored in self.wvecs
        """
        self.p = p
        self.q = q
        self.preprocess_transition_probs()
        walks = self.simulate_walks(num_walks, walk_length)
        self.learn_embeddings(
            walks, dimensions, window_size
        )

    def edges_to_features(self, edge_list, edge_function, dimensions):
        """
        Create an edge feature array from a list of node pairs, by applying
        edge_function to the embeddings of the two nodes of each pair.

        :param edge_list:
            Sequence of (node, node) pairs.
        :param edge_function:
            Function of two arguments taking the node features and returning
            an edge feature of size dimensions.
        :param dimensions:
            Size of each edge feature vector.
        :return:
            feature_vec, a float32 array of shape (len(edge_list), dimensions)
        """
        n_tot = len(edge_list)
        feature_vec = np.empty((n_tot, dimensions), dtype='f')

        for ii, (v1, v2) in enumerate(edge_list):
            # Embeddings are keyed by the node id converted to a string
            emb1 = np.asarray(self.wvecs[str(v1)])
            emb2 = np.asarray(self.wvecs[str(v2)])

            feature_vec[ii] = edge_function(emb1, emb2)

        return feature_vec


def _build_link_graph(args, prop_pos, prop_neg):
    """Read the input graph and draw one random set of positive/negative links."""
    graph = GraphN2V(is_directed=args.directed,
                     prop_pos=prop_pos,
                     prop_neg=prop_neg,
                     workers=args.workers)
    graph.read_graph(args.input,
                     weighted=args.weighted,
                     directed=args.directed)
    graph.generate_pos_neg_links()
    return graph


def _node2vec_args(params):
    """Convert a parameter dict into the positional arguments of train_embeddings."""
    # The searched values are numpy integers; cast sizes and counts to
    # plain ints before handing them to the walk and word2vec code.
    return (2.0 ** params['log2p'],
            2.0 ** params['log2q'],
            int(2 ** params['log2d']),
            int(params['num_walks']),
            int(params['walk_length']),
            int(params['window_size']))


def _fit_and_score(features_train, labels_train, features_test, labels_test):
    """Fit a scaled logistic regression and return (train AUC, test AUC)."""
    clf = pipeline.make_pipeline(StandardScaler(), LogisticRegression(C=1))
    clf.fit(features_train, labels_train)

    # Public scorer lookup instead of the private metrics.scorer module
    auc_scorer = metrics.get_scorer('roc_auc')
    auc_train = auc_scorer(clf, features_train, labels_train)
    auc_test = auc_scorer(clf, features_test, labels_test)
    return auc_train, auc_test


def _cell_edges(centers):
    """Return len(centers) + 1 cell boundaries whose cells are centred on centers."""
    centers = np.asarray(centers, dtype=float)
    if centers.size == 1:
        return np.array([centers[0] - 0.5, centers[0] + 0.5])
    mid = 0.5 * (centers[:-1] + centers[1:])
    first = 2.0 * centers[0] - mid[0]
    last = 2.0 * centers[-1] - mid[-1]
    return np.concatenate(([first], mid, [last]))


def create_train_test_graphs(args):
    """
    Create and cache train & test graphs.
    Will load from cache if exists unless --regen option is given.

    The cache file is "<basename of args.input>.graph" in the current
    working directory.

    :param args:
        Parsed command-line arguments (input, regen, workers, weighted,
        directed are read).
    :return:
        Gtrain, Gtest: Train & test graphs
    """
    # Remove half the edges, and the same number of "negative" edges
    prop_pos = default_params['prop_pos']
    prop_neg = default_params['prop_neg']

    cached_fn = "%s.graph" % (os.path.basename(args.input))
    if os.path.exists(cached_fn) and not args.regen:
        print("Loading link prediction graphs from %s" % cached_fn)
        # SECURITY: unpickling executes arbitrary code; only load cache
        # files that this script wrote itself.
        with open(cached_fn, 'rb') as f:
            cache_data = pickle.load(f)
        Gtrain = cache_data['g_train']
        Gtest = cache_data['g_test']

    else:
        print("Regenerating link prediction graphs")
        # Train and test graphs get independent random link selections
        Gtrain = _build_link_graph(args, prop_pos, prop_neg)
        Gtest = _build_link_graph(args, prop_pos, prop_neg)

        # Cache generated graphs
        cache_data = {'g_train': Gtrain, 'g_test': Gtest}
        with open(cached_fn, 'wb') as f:
            pickle.dump(cache_data, f)

    return Gtrain, Gtest


def test_edge_functions(args):
    """
    Compare all edge functions using the default node2vec parameters.

    For each of args.num_experiments repetitions the embeddings of the
    train and test graphs are re-learned, and for every edge function a
    classifier is fitted on the train links and scored on the test links.

    :param args: Parsed command-line arguments.
    :return:
        Dict mapping edge function name to the list of test AUCs.
    """
    Gtrain, Gtest = create_train_test_graphs(args)

    n2v_args = _node2vec_args(default_params)
    dimensions = n2v_args[2]

    # Train and test graphs, with different edges
    edges_train, labels_train = Gtrain.get_selected_edges()
    edges_test, labels_test = Gtest.get_selected_edges()

    # With fixed test & train graphs (these are expensive to generate)
    # we perform k iterations of the algorithm
    # TODO: It would be nice if the walks had a settable random seed
    aucs = {name: [] for name in edge_functions}
    for trial in range(args.num_experiments):
        print("Iteration %d of %d" % (trial + 1, args.num_experiments))

        # Learn embeddings with current parameter values
        Gtrain.train_embeddings(*n2v_args)
        Gtest.train_embeddings(*n2v_args)

        for edge_fn_name, edge_fn in edge_functions.items():
            # Calculate edge embeddings using binary function
            edge_features_train = Gtrain.edges_to_features(edges_train, edge_fn, dimensions)
            edge_features_test = Gtest.edges_to_features(edges_test, edge_fn, dimensions)

            _, auc_test = _fit_and_score(edge_features_train, labels_train,
                                         edge_features_test, labels_test)
            aucs[edge_fn_name].append(auc_test)

    print("Edge function test performance (AUC):")
    for edge_name in aucs:
        auc_mean = np.mean(aucs[edge_name])
        auc_std = np.std(aucs[edge_name])
        print("[%s] mean: %.4g +/- %.3g" % (edge_name, auc_mean, auc_std))

    return aucs


def plot_parameter_sensitivity(args):
    """
    Plot the test AUC as a function of each entry of parameter_searches.

    One parameter is varied at a time while the others stay at their
    default values. The figure is saved as
    "sensitivity_<basename of args.input>.png" and then shown.

    :param args: Parsed command-line arguments.
    """
    # Train and test graphs, with different edges
    Gtrain, Gtest = create_train_test_graphs(args)
    edges_train, labels_train = Gtrain.get_selected_edges()
    edges_test, labels_test = Gtest.get_selected_edges()

    # Setup plot
    fig, axes = plt.subplots(2, int(np.ceil(len(parameter_searches)/2)))
    axes = axes.ravel()

    # Explore different parameters
    for ii, param in enumerate(parameter_searches):
        cparams = default_params.copy()
        param_values, xlabel = parameter_searches[param]
        param_aucs = []
        for pv in param_values:
            # Update current parameters and get values for experiment
            cparams[param] = pv
            n2v_args = _node2vec_args(cparams)
            dimensions = n2v_args[2]
            edge_fn = edge_functions[cparams['edge_function']]

            # With fixed test & train graphs (these are expensive to generate)
            # we perform num_experiments iterations of the algorithm, using
            # all positive & negative links in both graphs
            # TODO: It would be nice if the walks had a settable random seed
            cv_aucs = []
            for trial in range(args.num_experiments):
                print("Iteration %d of %d" % (trial + 1, args.num_experiments))
                # Learn embeddings with current parameter values
                Gtrain.train_embeddings(*n2v_args)
                Gtest.train_embeddings(*n2v_args)

                # Calculate edge embeddings using binary function
                edge_features_train = Gtrain.edges_to_features(edges_train, edge_fn, dimensions)
                edge_features_test = Gtest.edges_to_features(edges_test, edge_fn, dimensions)

                auc_train, auc_test = _fit_and_score(edge_features_train, labels_train,
                                                     edge_features_test, labels_test)
                cv_aucs.append(auc_test)

                print("%s = %.3f; AUC train: %.4g AUC test: %.4g"
                      % (param, pv, auc_train, auc_test))

            # Add mean of scores
            param_aucs.append(np.mean(cv_aucs))

        # Plot figure
        ax = axes[ii]
        ax.plot(param_values, param_aucs, 'r-', marker='s', ms=4)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('AUC')

    plt.tight_layout()
    sens_plot_fn = "sensitivity_%s.png" % (os.path.basename(args.input))
    plt.savefig(sens_plot_fn)
    plt.show()


def grid_search(args):
    """
    Grid search over log2p and log2q, scored by the classifier test AUC.

    Only the train graph is used: its selected links are split into
    args.num_experiments stratified folds, and the embeddings are
    re-learned for every fold.

    :param args: Parsed command-line arguments.
    :return:
        Array of test AUCs with shape (len(log2p values),
        len(log2q values), num_experiments).
    """
    Gtrain, Gtest = create_train_test_graphs(args)
    num_partitions = args.num_experiments

    # Parameter grid
    grid_parameters = ['log2p', 'log2q']
    grid_values = [np.arange(-2, 3), np.arange(-2, 3)]
    grid_shape = [len(values) for values in grid_values]

    # Store values in tensor
    grid_aucs = np.zeros(grid_shape + [num_partitions])

    # Explore different parameters
    cparams = default_params.copy()
    for grid_inx in np.ndindex(*grid_shape):
        for ii, param in enumerate(grid_parameters):
            cparams[param] = grid_values[ii][grid_inx[ii]]

        n2v_args = _node2vec_args(cparams)
        dimensions = n2v_args[2]
        edge_fn = edge_functions[cparams['edge_function']]

        # With fixed test & train graphs (these are expensive to generate)
        # we perform num_experiments iterations of the algorithm, using
        # different sets of links to train & test the linear classifier.
        # This really isn't k-fold CV as the embeddings are learned without
        # holdout of any data, but it will average over the random walks and
        # estimate how the linear classifier generalizes, at least.
        partitioner = model_selection.StratifiedKFold(num_partitions, shuffle=True)
        edges_all, edge_labels_all = Gtrain.get_selected_edges()

        # Iterate over folds
        folds = partitioner.split(edges_all, edge_labels_all)
        for fold, (train_inx, test_inx) in enumerate(folds):
            edges_train = [edges_all[jj] for jj in train_inx]
            labels_train = [edge_labels_all[jj] for jj in train_inx]
            edges_test = [edges_all[jj] for jj in test_inx]
            labels_test = [edge_labels_all[jj] for jj in test_inx]

            # Learn embeddings with current parameter values
            Gtrain.train_embeddings(*n2v_args)

            # Calculate edge embeddings using binary function
            edge_features_train = Gtrain.edges_to_features(edges_train, edge_fn, dimensions)
            edge_features_test = Gtrain.edges_to_features(edges_test, edge_fn, dimensions)

            auc_train, auc_test = _fit_and_score(edge_features_train, labels_train,
                                                 edge_features_test, labels_test)

            print("%s; AUC train: %.4g AUC test: %.4g"
                  % (grid_inx, auc_train, auc_test))

            # Add to grid scores
            grid_aucs[grid_inx + (fold,)] = auc_test

    # Now find the best:
    mean_aucs = grid_aucs.mean(axis=-1)

    print("AUC mean:")
    print(mean_aucs)

    print("AUC std dev:")
    print(grid_aucs.std(axis=-1))

    if len(grid_values) == 2:
        plt.figure()
        # pcolormesh expects the colour array indexed [y, x], while
        # mean_aucs is indexed [first parameter (x), second parameter (y)],
        # hence the transpose. Cell edges are passed so that every grid
        # point gets a full cell.
        plt.pcolormesh(_cell_edges(grid_values[0]),
                       _cell_edges(grid_values[1]),
                       mean_aucs.T)
        plt.colorbar()
        plt.xlabel(grid_parameters[0])
        plt.ylabel(grid_parameters[1])
        plt.show()

    return grid_aucs


def main():
    """Parse the command line and run the requested task."""
    args = parse_args()

    # Tasks are matched on their prefix, so e.g. "grid" selects gridsearch.
    if args.task.startswith("grid"):
        grid_search(args)

    elif args.task.startswith("edge"):
        test_edge_functions(args)

    elif args.task.startswith("sens"):
        plot_parameter_sensitivity(args)

    else:
        # Fail loudly instead of silently doing nothing on a typo
        raise SystemExit(
            "Unknown task %r; specify one of: edgeencoding, sensitivity, gridsearch"
            % (args.task,)
        )


if __name__ == "__main__":
    main()