├── LICENSE
├── README.md
├── dataset.zip
├── main.py
├── models
│   ├── graphcnn.py
│   └── mlp.py
├── requirements.txt
└── util.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Weihua Hu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # How Powerful are Graph Neural Networks?
2 | 
3 | This repository is the official PyTorch implementation of the experiments in the following paper:
4 | 
5 | Keyulu Xu*, Weihua Hu*, Jure Leskovec, Stefanie Jegelka. How Powerful are Graph Neural Networks? ICLR 2019.
6 | 
7 | [arXiv](https://arxiv.org/abs/1810.00826) [OpenReview](https://openreview.net/forum?id=ryGs6iA5Km)
8 | 
9 | If you make use of the code/experiments or the GIN algorithm in your work, please cite our paper (BibTeX below).
10 | ```
11 | @inproceedings{
12 | xu2018how,
13 | title={How Powerful are Graph Neural Networks?},
14 | author={Keyulu Xu and Weihua Hu and Jure Leskovec and Stefanie Jegelka},
15 | booktitle={International Conference on Learning Representations},
16 | year={2019},
17 | url={https://openreview.net/forum?id=ryGs6iA5Km},
18 | }
19 | ```
20 | 
21 | ## Installation
22 | Install PyTorch following the instructions on the [official website](https://pytorch.org/). The code has been tested with PyTorch versions 0.4.1 and 1.0.0.
23 | 
24 | Then install the other dependencies.
25 | ```
26 | pip install -r requirements.txt
27 | ```
28 | 
29 | ## Test run
30 | Unzip the dataset file
31 | ```
32 | unzip dataset.zip
33 | ```
34 | 
35 | and run
36 | 
37 | ```
38 | python main.py
39 | ```
40 | 
41 | The default parameters are not the best-performing hyper-parameters used to reproduce our results in the paper. Hyper-parameters need to be specified through the command-line arguments. Please refer to our paper for details on how we set the hyper-parameters. For instance, for the COLLAB and IMDB datasets, you need to add `--degree_as_tag` so that the node degrees are used as input node features.
42 | 
43 | To see the full list of hyper-parameters that can be specified, type
44 | ```
45 | python main.py --help
46 | ```
47 | 
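48 | For example, a run on one of the social-network datasets might look as follows (a hypothetical configuration: the folder name IMDBBINARY is assumed to match the unzipped dataset, and the exact per-dataset hyper-parameters are given in the paper):
49 | ```
50 | python main.py --dataset IMDBBINARY --degree_as_tag --fold_idx 0
51 | ```
52 | 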
53 | 
54 | ## Cross-validation strategy in the paper
55 | The cross-validation in our paper only uses training and validation sets (no test set) due to the small dataset sizes.
56 | Specifically, after obtaining the 10 validation curves corresponding to the 10 folds, we first averaged the validation curves across the folds (yielding a single averaged validation curve), and then selected the epoch that achieved the maximum averaged validation accuracy. Finally, the standard deviation over the 10 folds was computed at the selected epoch.
57 | 
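58 | As a minimal sketch of that selection step (assuming the per-fold, per-epoch validation accuracies have been collected into a 10 x num_epochs array, here from a hypothetical file `val_acc.txt`):
59 | ```
60 | import numpy as np
61 | 
62 | acc = np.loadtxt('val_acc.txt')   # shape (10, num_epochs): one row per fold
63 | avg_curve = acc.mean(axis=0)      # averaged validation curve over the 10 folds
64 | best_epoch = avg_curve.argmax()   # epoch with the maximum averaged accuracy
65 | print(avg_curve[best_epoch], acc[:, best_epoch].std())
66 | ```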
--------------------------------------------------------------------------------
/dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weihua916/powerful-gnns/9a2ce8ac3e99278307093a464a95caf0fb04b602/dataset.zip
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import numpy as np
7 | 
8 | from tqdm import tqdm
9 | 
10 | from util import load_data, separate_data
11 | from models.graphcnn import GraphCNN
12 | 
13 | criterion = nn.CrossEntropyLoss()
14 | 
15 | def train(args, model, device, train_graphs, optimizer, epoch):
16 |     model.train()
17 | 
18 |     total_iters = args.iters_per_epoch
19 |     pbar = tqdm(range(total_iters), unit='batch')
20 | 
21 |     loss_accum = 0
22 |     for pos in pbar:
23 |         selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]
24 | 
25 |         batch_graph = [train_graphs[idx] for idx in selected_idx]
26 |         output = model(batch_graph)
27 | 
28 |         labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device)
29 | 
30 |         #compute loss
31 |         loss = criterion(output, labels)
32 | 
33 |         #backprop
34 |         if optimizer is not None:
35 |             optimizer.zero_grad()
36 |             loss.backward()
37 |             optimizer.step()
38 | 
39 | 
40 |         loss = loss.detach().cpu().numpy()
41 |         loss_accum += loss
42 | 
43 |         #report
44 |         pbar.set_description('epoch: %d' % (epoch))
45 | 
46 |     average_loss = loss_accum/total_iters
47 |     print("loss training: %f" % (average_loss))
48 | 
49 |     return average_loss
50 | 
51 | ###pass data to the model in minibatches during testing to avoid memory overflow (no backpropagation is performed)
52 | def pass_data_iteratively(model, graphs, minibatch_size = 64):
53 |     model.eval()
54 |     output = []
55 |     idx = np.arange(len(graphs))
56 |     for i in range(0, len(graphs), minibatch_size):
57 |         sampled_idx = idx[i:i+minibatch_size]
58 |         if len(sampled_idx) == 0:
59 |             continue
60 |         output.append(model([graphs[j] for j in sampled_idx]).detach())
61 |     return torch.cat(output, 0)
62 | 
63 | def test(args, model, device, train_graphs, test_graphs, epoch):
64 |     model.eval()
65 | 
66 |     output = pass_data_iteratively(model, train_graphs)
67 |     pred = output.max(1, keepdim=True)[1]
68 |     labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
69 |     correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
70 |     acc_train = correct / float(len(train_graphs))
71 | 
72 |     output = pass_data_iteratively(model, test_graphs)
73 |     pred = output.max(1, keepdim=True)[1]
74 |     labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device)
75 |     correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
76 |     acc_test = correct / float(len(test_graphs))
77 | 
78 |     print("accuracy train: %f test: %f" % (acc_train, acc_test))
79 | 
80 |     return acc_train, acc_test
81 | 
82 | def main():
83 |     # Training settings
84 |     # Note: Hyper-parameters need to be tuned in order to obtain the results reported in the paper.
85 |     parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification')
86 |     parser.add_argument('--dataset', type=str, default="MUTAG",
87 |                         help='name of dataset (default: MUTAG)')
88 |     parser.add_argument('--device', type=int, default=0,
89 |                         help='which gpu to use if any (default: 0)')
90 |     parser.add_argument('--batch_size', type=int, default=32,
91 |                         help='input batch size for training (default: 32)')
92 |     parser.add_argument('--iters_per_epoch', type=int, default=50,
93 |                         help='number of iterations per epoch (default: 50)')
94 |     parser.add_argument('--epochs', type=int, default=350,
95 |                         help='number of epochs to train (default: 350)')
96 |     parser.add_argument('--lr', type=float, default=0.01,
97 |                         help='learning rate (default: 0.01)')
98 |     parser.add_argument('--seed', type=int, default=0,
99 |                         help='random seed for splitting the dataset into 10 folds (default: 0)')
100 |     parser.add_argument('--fold_idx', type=int, default=0,
101 |                         help='the index of the fold in 10-fold cross-validation; must be less than 10')
102 |     parser.add_argument('--num_layers', type=int, default=5,
103 |                         help='number of layers INCLUDING the input one (default: 5)')
104 |     parser.add_argument('--num_mlp_layers', type=int, default=2,
105 |                         help='number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.')
106 |     parser.add_argument('--hidden_dim', type=int, default=64,
107 |                         help='number of hidden units (default: 64)')
108 |     parser.add_argument('--final_dropout', type=float, default=0.5,
109 |                         help='final layer dropout (default: 0.5)')
110 |     parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],
111 |                         help='Pooling over all nodes in a graph: sum or average')
112 |     parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],
113 |                         help='Pooling over neighboring nodes: sum, average or max')
114 |     parser.add_argument('--learn_eps', action="store_true",
115 |                         help='Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy, though.')
116 |     parser.add_argument('--degree_as_tag', action="store_true",
117 |                         help='use node degrees as input node features (a heuristic for unlabeled graphs)')
118 |     parser.add_argument('--filename', type = str, default = "",
119 |                         help='output file')
120 |     args = parser.parse_args()
121 | 
122 |     #set up seeds and gpu device
123 |     torch.manual_seed(0)
124 |     np.random.seed(0)
125 |     device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
126 |     if torch.cuda.is_available():
127 |         torch.cuda.manual_seed_all(0)
128 | 
129 |     graphs, num_classes = load_data(args.dataset, args.degree_as_tag)
130 | 
131 |     ##10-fold cross validation. Conduct an experiment on the fold specified by args.fold_idx.
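132 |     ## To follow the paper's protocol, run this script once for each of the 10 folds
133 |     ## (--fold_idx 0 through 9) and aggregate the per-epoch validation accuracies
134 |     ## across the folds offline, as described in the README.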
135 |     train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)
136 | 
137 |     model = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type, args.neighbor_pooling_type, device).to(device)
138 | 
139 |     optimizer = optim.Adam(model.parameters(), lr=args.lr)
140 |     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
141 | 
142 | 
143 |     for epoch in range(1, args.epochs + 1):
144 |         scheduler.step()  # note: on PyTorch >= 1.1, scheduler.step() should instead be called after the optimizer updates
145 | 
146 |         avg_loss = train(args, model, device, train_graphs, optimizer, epoch)
147 |         acc_train, acc_test = test(args, model, device, train_graphs, test_graphs, epoch)
148 | 
149 |         if args.filename != "":
150 |             with open(args.filename, 'a') as f:  # append, so results from every epoch are kept
151 |                 f.write("%f %f %f" % (avg_loss, acc_train, acc_test))
152 |                 f.write("\n")
153 |         print("")
154 | 
155 |     print(model.eps)
156 | 
157 | 
158 | if __name__ == '__main__':
159 |     main()
160 | 
--------------------------------------------------------------------------------
/models/graphcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | import sys
6 | sys.path.append("models/")
7 | from mlp import MLP
8 | 
9 | class GraphCNN(nn.Module):
10 |     def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device):
11 |         '''
12 |             num_layers: number of layers in the neural networks (INCLUDING the input layer)
13 |             num_mlp_layers: number of layers in mlps (EXCLUDING the input layer)
14 |             input_dim: dimensionality of input features
15 |             hidden_dim: dimensionality of hidden units at ALL layers
16 |             output_dim: number of classes for prediction
17 |             final_dropout: dropout ratio on the final linear layer
18 |             learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether.
19 |             neighbor_pooling_type: how to aggregate neighbors (sum, average, or max)
20 |             graph_pooling_type: how to aggregate all nodes in a graph (sum or average)
21 |             device: which device to use
22 |         '''
23 | 
24 |         super(GraphCNN, self).__init__()
25 | 
26 |         self.final_dropout = final_dropout
27 |         self.device = device
28 |         self.num_layers = num_layers
29 |         self.graph_pooling_type = graph_pooling_type
30 |         self.neighbor_pooling_type = neighbor_pooling_type
31 |         self.learn_eps = learn_eps
32 |         self.eps = nn.Parameter(torch.zeros(self.num_layers-1))
33 | 
34 |         ###List of MLPs
35 |         self.mlps = torch.nn.ModuleList()
36 | 
37 |         ###List of batchnorms applied to the output of MLP (input of the final prediction linear layer)
38 |         self.batch_norms = torch.nn.ModuleList()
39 | 
40 |         for layer in range(self.num_layers-1):
41 |             if layer == 0:
42 |                 self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim))
43 |             else:
44 |                 self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim))
45 | 
46 |             self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
47 | 
48 |         #Linear functions that map the hidden representations at different layers into a prediction score
49 |         self.linears_prediction = torch.nn.ModuleList()
50 |         for layer in range(num_layers):
51 |             if layer == 0:
52 |                 self.linears_prediction.append(nn.Linear(input_dim, output_dim))
53 |             else:
54 |                 self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))
55 | 
56 | 
57 |     def __preprocess_neighbors_maxpool(self, batch_graph):
58 |         ###create padded_neighbor_list in concatenated graph
59 | 
60 |         #compute the maximum number of neighbors within the graphs in the current minibatch
61 |         max_deg = max([graph.max_neighbor for graph in batch_graph])
62 | 
63 |         padded_neighbor_list = []
64 |         start_idx = [0]
65 | 
66 | 
67 |         for i, graph in enumerate(batch_graph):
68 |             start_idx.append(start_idx[i] + len(graph.g))
69 |             padded_neighbors = []
70 |             for j in range(len(graph.neighbors)):
71 |                 #add offset values to the neighbor indices
72 |                 pad = [n + start_idx[i] for n in graph.neighbors[j]]
73 |                 #padding; dummy data is assumed to be stored at index -1
74 |                 pad.extend([-1]*(max_deg - len(pad)))
75 | 
76 |                 #Add center nodes in the maxpooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
77 |                 if not self.learn_eps:
78 |                     pad.append(j + start_idx[i])
79 | 
80 |                 padded_neighbors.append(pad)
81 |             padded_neighbor_list.extend(padded_neighbors)
82 | 
83 |         return torch.LongTensor(padded_neighbor_list)
84 | 
85 | 
86 |     def __preprocess_neighbors_sumavepool(self, batch_graph):
87 |         ###create a block-diagonal sparse adjacency matrix for the concatenated batch
88 | 
89 |         edge_mat_list = []
90 |         start_idx = [0]
91 |         for i, graph in enumerate(batch_graph):
92 |             start_idx.append(start_idx[i] + len(graph.g))
93 |             edge_mat_list.append(graph.edge_mat + start_idx[i])
94 |         Adj_block_idx = torch.cat(edge_mat_list, 1)
95 |         Adj_block_elem = torch.ones(Adj_block_idx.shape[1])
96 | 
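97 |         # e.g., for a batch of two graphs with 2 and 3 nodes, start_idx is [0, 2, 5]; each
98 |         # graph's edge indices are shifted by its offset, so all edges fall into disjoint
99 |         # diagonal blocks of a single 5 x 5 sparse adjacency matrix (hypothetical sizes for illustration)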
100 | 
101 |         #Add self-loops to the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
102 |         if not self.learn_eps:
103 |             num_node = start_idx[-1]
104 |             self_loop_edge = torch.LongTensor([range(num_node), range(num_node)])
105 |             elem = torch.ones(num_node)
106 |             Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
107 |             Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)
108 | 
109 |         Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1],start_idx[-1]]))
110 | 
111 |         return Adj_block.to(self.device)
112 | 
113 | 
114 |     def __preprocess_graphpool(self, batch_graph):
115 |         ###create a sum or average pooling sparse matrix over all nodes in each graph (num graphs x num nodes)
116 | 
117 |         start_idx = [0]
118 | 
119 |         #compute the node offset of each graph in the concatenated batch
120 |         for i, graph in enumerate(batch_graph):
121 |             start_idx.append(start_idx[i] + len(graph.g))
122 | 
123 |         idx = []
124 |         elem = []
125 |         for i, graph in enumerate(batch_graph):
126 |             ###average pooling
127 |             if self.graph_pooling_type == "average":
128 |                 elem.extend([1./len(graph.g)]*len(graph.g))
129 | 
130 |             else:
131 |             ###sum pooling
132 |                 elem.extend([1]*len(graph.g))
133 | 
134 |             idx.extend([[i, j] for j in range(start_idx[i], start_idx[i+1], 1)])
135 |         elem = torch.FloatTensor(elem)
136 |         idx = torch.LongTensor(idx).transpose(0,1)
137 |         graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))
138 | 
139 |         return graph_pool.to(self.device)
140 | 
141 |     def maxpool(self, h, padded_neighbor_list):
142 |         ###pad with the column-wise minimum, so the dummy entries can never win the max-pooling
143 | 
144 |         dummy = torch.min(h, dim = 0)[0]
145 |         h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
146 |         pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0]
147 |         return pooled_rep
148 | 
149 | 
150 |     def next_layer_eps(self, h, layer, padded_neighbor_list = None, Adj_block = None):
151 |         ###pool neighboring nodes and center nodes separately, with epsilon reweighting
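152 | 
153 |         # GIN update: h_v <- MLP((1 + eps)*h_v + aggregate({h_u : u in N(v)})), where eps is the
154 |         # learnable per-layer scalar self.eps[layer]; batch norm and ReLU are applied afterwards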
155 |         if self.neighbor_pooling_type == "max":
156 |             ##If max pooling
157 |             pooled = self.maxpool(h, padded_neighbor_list)
158 |         else:
159 |             #If sum or average pooling
160 |             pooled = torch.spmm(Adj_block, h)
161 |             if self.neighbor_pooling_type == "average":
162 |                 #If average pooling
163 |                 degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
164 |                 pooled = pooled/degree
165 | 
166 |         #Reweights the center node representation when aggregating it with its neighbors
167 |         pooled = pooled + (1 + self.eps[layer])*h
168 |         pooled_rep = self.mlps[layer](pooled)
169 |         h = self.batch_norms[layer](pooled_rep)
170 | 
171 |         #non-linearity
172 |         h = F.relu(h)
173 |         return h
174 | 
175 | 
176 |     def next_layer(self, h, layer, padded_neighbor_list = None, Adj_block = None):
177 |         ###pool neighboring nodes and center nodes altogether
178 | 
179 |         if self.neighbor_pooling_type == "max":
180 |             ##If max pooling
181 |             pooled = self.maxpool(h, padded_neighbor_list)
182 |         else:
183 |             #If sum or average pooling
184 |             pooled = torch.spmm(Adj_block, h)
185 |             if self.neighbor_pooling_type == "average":
186 |                 #If average pooling
187 |                 degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
188 |                 pooled = pooled/degree
189 | 
190 |         #representation of neighboring and center nodes
191 |         pooled_rep = self.mlps[layer](pooled)
192 | 
193 |         h = self.batch_norms[layer](pooled_rep)
194 | 
195 |         #non-linearity
196 |         h = F.relu(h)
197 |         return h
198 | 
199 | 
200 |     def forward(self, batch_graph):
201 |         X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)
202 |         graph_pool = self.__preprocess_graphpool(batch_graph)
203 | 
204 |         if self.neighbor_pooling_type == "max":
205 |             padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
206 |         else:
207 |             Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)
208 | 
209 |         #list of hidden representations at each layer (including input)
210 |         hidden_rep = [X_concat]
211 |         h = X_concat
212 | 
213 |         for layer in range(self.num_layers-1):
214 |             if self.neighbor_pooling_type == "max" and self.learn_eps:
215 |                 h = self.next_layer_eps(h, layer, padded_neighbor_list = padded_neighbor_list)
216 |             elif not self.neighbor_pooling_type == "max" and self.learn_eps:
217 |                 h = self.next_layer_eps(h, layer, Adj_block = Adj_block)
218 |             elif self.neighbor_pooling_type == "max" and not self.learn_eps:
219 |                 h = self.next_layer(h, layer, padded_neighbor_list = padded_neighbor_list)
220 |             elif not self.neighbor_pooling_type == "max" and not self.learn_eps:
221 |                 h = self.next_layer(h, layer, Adj_block = Adj_block)
222 | 
223 |             hidden_rep.append(h)
224 | 
225 |         score_over_layer = 0
226 | 
227 |         #perform pooling over all nodes in each graph at every layer
228 |         for layer, h in enumerate(hidden_rep):
229 |             pooled_h = torch.spmm(graph_pool, h)
230 |             score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training = self.training)
231 | 
232 |         return score_over_layer
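233 | 
234 | # A minimal smoke test (hypothetical usage sketch; assumes graphs were loaded via util.load_data):
235 | #   graphs, num_classes = load_data("MUTAG", False)
236 | #   model = GraphCNN(5, 2, graphs[0].node_features.shape[1], 64, num_classes,
237 | #                    0.5, False, "sum", "sum", torch.device("cpu"))
238 | #   scores = model(graphs[:4])   # tensor of shape (4, num_classes)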
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | ###MLP with linear output
6 | class MLP(nn.Module):
7 |     def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
8 |         '''
9 |             num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to a linear model.
10 |             input_dim: dimensionality of input features
11 |             hidden_dim: dimensionality of hidden units at ALL layers
12 |             output_dim: number of classes for prediction
13 |         '''
14 | 
15 |         super(MLP, self).__init__()
16 | 
17 |         self.linear_or_not = True #default is linear model
18 |         self.num_layers = num_layers
19 | 
20 |         if num_layers < 1:
21 |             raise ValueError("number of layers should be positive!")
22 |         elif num_layers == 1:
23 |             #Linear model
24 |             self.linear = nn.Linear(input_dim, output_dim)
25 |         else:
26 |             #Multi-layer model
27 |             self.linear_or_not = False
28 |             self.linears = torch.nn.ModuleList()
29 |             self.batch_norms = torch.nn.ModuleList()
30 | 
31 |             self.linears.append(nn.Linear(input_dim, hidden_dim))
32 |             for layer in range(num_layers - 2):
33 |                 self.linears.append(nn.Linear(hidden_dim, hidden_dim))
34 |             self.linears.append(nn.Linear(hidden_dim, output_dim))
35 | 
36 |             for layer in range(num_layers - 1):
37 |                 self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
38 | 
39 |     def forward(self, x):
40 |         if self.linear_or_not:
41 |             #If linear model
42 |             return self.linear(x)
43 |         else:
44 |             #If MLP
45 |             h = x
46 |             for layer in range(self.num_layers - 1):
47 |                 h = F.relu(self.batch_norms[layer](self.linears[layer](h)))
48 |             return self.linears[self.num_layers - 1](h)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | numpy
3 | networkx
4 | scikit-learn
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | import random
4 | import torch
5 | from sklearn.model_selection import StratifiedKFold
6 | 
7 | class S2VGraph(object):
8 |     def __init__(self, g, label, node_tags=None, node_features=None):
9 |         '''
10 |             g: a networkx graph
11 |             label: an integer graph label
12 |             node_tags: a list of integer node tags
13 |             node_features: a torch float tensor, the one-hot representation of the node tags, used as input to the neural nets
14 |             edge_mat: a torch long tensor containing the edge list; used to create the torch sparse tensor
15 |             neighbors: list of neighbors (without self-loops)
16 |         '''
17 |         self.label = label
18 |         self.g = g
19 |         self.node_tags = node_tags
20 |         self.neighbors = []
21 |         self.node_features = 0
22 |         self.edge_mat = 0
23 | 
24 |         self.max_neighbor = 0
25 | 
26 | 
27 | def load_data(dataset, degree_as_tag):
28 |     '''
29 |         dataset: name of the dataset
30 |         degree_as_tag: if True, use node degrees as input node tags
31 |     '''
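32 | 
33 |     # Dataset text format (as parsed below): the first line gives the number of graphs;
34 |     # each graph starts with a line "n l" (number of nodes, graph label), followed by
35 |     # one line per node: "tag num_neighbors neighbor_0 ... [optional float attributes]".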
36 |     print('loading data')
37 |     g_list = []
38 |     label_dict = {}
39 |     feat_dict = {}
40 | 
41 |     with open('dataset/%s/%s.txt' % (dataset, dataset), 'r') as f:
42 |         n_g = int(f.readline().strip())
43 |         for i in range(n_g):
44 |             row = f.readline().strip().split()
45 |             n, l = [int(w) for w in row]
46 |             if l not in label_dict:
47 |                 mapped = len(label_dict)
48 |                 label_dict[l] = mapped
49 |             g = nx.Graph()
50 |             node_tags = []
51 |             node_features = []
52 |             n_edges = 0
53 |             for j in range(n):
54 |                 g.add_node(j)
55 |                 row = f.readline().strip().split()
56 |                 tmp = int(row[1]) + 2
57 |                 if tmp == len(row):
58 |                     # no node attributes
59 |                     row = [int(w) for w in row]
60 |                     attr = None
61 |                 else:
62 |                     row, attr = [int(w) for w in row[:tmp]], np.array([float(w) for w in row[tmp:]])
63 |                 if row[0] not in feat_dict:
64 |                     mapped = len(feat_dict)
65 |                     feat_dict[row[0]] = mapped
66 |                 node_tags.append(feat_dict[row[0]])
67 | 
68 |                 if attr is not None:  # store node attributes when present
69 |                     node_features.append(attr)
70 | 
71 |                 n_edges += row[1]
72 |                 for k in range(2, len(row)):
73 |                     g.add_edge(j, row[k])
74 | 
75 |             if node_features != []:
76 |                 node_features = np.stack(node_features)
77 |                 node_feature_flag = True
78 |             else:
79 |                 node_features = None
80 |                 node_feature_flag = False
81 | 
82 |             assert len(g) == n
83 | 
84 |             g_list.append(S2VGraph(g, l, node_tags))
85 | 
86 |     #add neighbor lists, remapped labels, and edge_mat
87 |     for g in g_list:
88 |         g.neighbors = [[] for i in range(len(g.g))]
89 |         for i, j in g.g.edges():
90 |             g.neighbors[i].append(j)
91 |             g.neighbors[j].append(i)
92 |         degree_list = []
93 |         for i in range(len(g.g)):
94 |             degree_list.append(len(g.neighbors[i]))
95 |         g.max_neighbor = max(degree_list)
96 | 
97 |         g.label = label_dict[g.label]
98 | 
99 |         edges = [list(pair) for pair in g.g.edges()]
100 |         edges.extend([[i, j] for j, i in edges])   # make the edge list symmetric
101 | 
102 |         g.edge_mat = torch.LongTensor(edges).transpose(0,1)
103 | 
104 |     if degree_as_tag:
105 |         for g in g_list:
106 |             g.node_tags = list(dict(g.g.degree).values())
107 | 
108 |     #Extracting unique tag labels
109 |     tagset = set()
110 |     for g in g_list:
111 |         tagset = tagset.union(set(g.node_tags))
112 | 
113 |     tagset = list(tagset)
114 |     tag2index = {tagset[i]:i for i in range(len(tagset))}
115 | 
116 |     #one-hot encode the node tags as input features
117 |     for g in g_list:
118 |         g.node_features = torch.zeros(len(g.node_tags), len(tagset))
119 |         g.node_features[range(len(g.node_tags)), [tag2index[tag] for tag in g.node_tags]] = 1
120 | 
121 |     print('# classes: %d' % len(label_dict))
122 |     print('# distinct node tags: %d' % len(tagset))
123 | 
124 |     print("# data: %d" % len(g_list))
125 | 
126 |     return g_list, len(label_dict)
127 | 
128 | def separate_data(graph_list, seed, fold_idx):
129 |     assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."
130 |     skf = StratifiedKFold(n_splits=10, shuffle = True, random_state = seed)
131 | 
132 |     labels = [graph.label for graph in graph_list]
133 |     idx_list = []
134 |     for idx in skf.split(np.zeros(len(labels)), labels):
135 |         idx_list.append(idx)
136 |     train_idx, test_idx = idx_list[fold_idx]
137 | 
138 |     train_graph_list = [graph_list[i] for i in train_idx]
139 |     test_graph_list = [graph_list[i] for i in test_idx]
140 | 
141 |     return train_graph_list, test_graph_list
142 | 
143 | 
--------------------------------------------------------------------------------