├── images
│   ├── ok.md
│   ├── BFS.png
│   ├── SeqAdj.png
│   ├── archi_0.png
│   ├── archi_1.png
│   ├── archi_2.png
│   ├── table.png
│   ├── Projection.png
│   └── archi_macro.png
├── args.py
├── VRGC.ipynb
├── model.py
├── README.md
├── train_test_functions.py
└── data.py

/images/ok.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/images/BFS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/BFS.png
--------------------------------------------------------------------------------
/images/SeqAdj.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/SeqAdj.png
--------------------------------------------------------------------------------
/images/archi_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/archi_0.png
--------------------------------------------------------------------------------
/images/archi_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/archi_1.png
--------------------------------------------------------------------------------
/images/archi_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/archi_2.png
--------------------------------------------------------------------------------
/images/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/table.png
--------------------------------------------------------------------------------
/images/Projection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/Projection.png
--------------------------------------------------------------------------------
/images/archi_macro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/edouardpineau/Variational-Recurrent-Neural-Networks-for-Graph-Classification/HEAD/images/archi_macro.png
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | # ---- Program configuration ---- # 2 | 3 | class Args: 4 | def __init__(self, data_directory='Graph_datasets/', cuda=False, graph_name='ENZYMES'): 5 | """ 6 | Class arguments to initialize the VRGC problem parameters 7 | 8 | :param data_directory: location of the data (in the format downloadable at https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets) 9 | :param cuda: use of the CUDA library for GPU-based computation 10 | :param graph_name: name of the graph dataset 11 | 
""" 12 | 13 | self.data_directory = data_directory 14 | self.cuda = cuda 15 | self.graph_name = graph_name 16 | 17 | self.node_dims = {'MUTAG': 11, 'ENZYMES': 25, 'PROTEINS_full': 80, 'DD': 230, 18 | 'IMDB-BINARY': 134, 'IMDB-MULTI': 88, 'REDDIT-BINARY': 3068, 'REDDIT-MULTI-5K': 88, 'COLLAB':489} 19 | self.num_classes = {'MUTAG': 2, 'ENZYMES': 6, 'PROTEINS_full': 2, 'DD': 2, 20 | 'IMDB-BINARY': 2, 'IMDB-MULTI': 3, 'REDDIT-BINARY': 2, 'REDDIT-MULTI-5K': 5, 'COLLAB':3} 21 | 22 | # dimensions of the neural networks 23 | self.node_dim = self.node_dims[graph_name] 24 | self.num_layers = 2 25 | self.input_size_rnn = self.node_dims[graph_name] # input size for main RNN 26 | self.hidden_size_rnn = int(128) 27 | self.hidden_size_rnn_output = 16 28 | self.embedding_size_rnn = int(64) 29 | self.embedding_size_rnn_output = int(8) 30 | self.embedding_size_output = int(64) 31 | 32 | self.num_class = self.num_classes[graph_name] 33 | 34 | # coefficient of reconstruction loss in the total loss 35 | self.reco_importance = 0.1 36 | 37 | # ---- Training config ---- # 38 | self.loss = None 39 | self.batch_size = 128 40 | self.epochs = 2000 41 | self.epochs_log = 1 42 | 43 | self.lr = 0.001 44 | self.lr_rate = 0.3 45 | -------------------------------------------------------------------------------- /VRGC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "from args import *\n", 12 | "from train_test_functions import *\n", 13 | "from model import *\n", 14 | "from data import *" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "args = Args(cuda=torch.cuda.is_available(), graph_name='ENZYMES')" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "args.epochs = 2000\n", 33 | "args.batch_size = 128\n", 34 | "args.reco_importance = 0.1\n", 35 | "args.loss = nn.BCELoss()\n", 36 | "\n", 37 | "os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)\n", 38 | "print('CUDA', args.cuda)" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "scrolled": true 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "graphs = graph_load_batch(data_directory=args.data_directory, name=args.graph_name)\n", 50 | "\n", 51 | "dataloaders_train, dataloaders_test = create_loaders(graphs, args)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "scrolled": true 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "results = {}\n", 63 | "\n", 64 | "args.num_fold = None\n", 65 | "\n", 66 | "for i in range(10):\n", 67 | "\n", 68 | " print('Fold number: {:.0f}'.format(i+1))\n", 69 | " args.num_fold = i\n", 70 | "\n", 71 | " rnn_embedding = RecurrentEmbedding(input_size=args.node_dim, \n", 72 | " embedding_size=args.embedding_size_rnn,\n", 73 | " hidden_size=args.hidden_size_rnn, \n", 74 | " num_layers=args.num_layers, \n", 75 | " is_cuda=args.cuda)\n", 76 | "\n", 77 | " var = VAR(h_size=args.hidden_size_rnn, \n", 78 | " embedding_size=args.embedding_size_output,\n", 79 | " y_size=args.node_dim, \n", 80 | " is_cuda=args.cuda)\n", 81 | "\n", 82 | " rnn_classifier = RecurrentClassifier(input_size=args.hidden_size_rnn, \n", 83 | " embedding_size=args.embedding_size_rnn,\n", 84 
| " hidden_size=args.hidden_size_rnn, \n", 85 | " num_layers=args.num_layers, \n", 86 | " num_class=args.num_class,\n", 87 | " is_cuda=args.cuda)\n", 88 | "\n", 89 | " if args.cuda:\n", 90 | " rnn_embedding = rnn_embedding.cuda()\n", 91 | " var = var.cuda()\n", 92 | " rnn_classifier = rnn_classifier.cuda()\n", 93 | "\n", 94 | " learning_accuracy_test = classifier_train(args, \n", 95 | " dataloaders_train[i], \n", 96 | " dataloaders_test[i], \n", 97 | " rnn_embedding, var, rnn_classifier)\n", 98 | "\n", 99 | " accuracy_test, scores, predicted_labels, true_labels, vote = vote_test(args, \n", 100 | " rnn_embedding, \n", 101 | " var, \n", 102 | " rnn_classifier,\n", 103 | " dataloaders_test[i], \n", 104 | " num_iteration=100)\n", 105 | " \n", 106 | " results[i] = {'rnn': rnn_embedding, 'output': var, 'classifier_1': rnn_classifier,\n", 107 | " 'acc_test': accuracy_test, 'scores': scores}\n", 108 | "\n", 109 | "print([results[r]['acc_test'] for r in results])" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "print(np.mean([results[r]['acc_test'] for r in results]), \n", 119 | " np.std([results[r]['acc_test'] for r in results]))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.6.7" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 2 151 | } 152 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals, print_function, division 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 6 | 7 | 8 | class RecurrentEmbedding(nn.Module): 9 | """ 10 | Recurrent embedding function: end of the blue block in the paper 11 | """ 12 | 13 | def __init__(self, input_size, embedding_size, hidden_size, num_layers, is_cuda=False): 14 | super(RecurrentEmbedding, self).__init__() 15 | self.num_layers = num_layers 16 | self.hidden_size = hidden_size 17 | self.is_cuda = is_cuda 18 | self.input = nn.Linear(input_size, embedding_size) 19 | self.rnn = nn.GRU(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True) 20 | 21 | self.relu = nn.ReLU() 22 | 23 | # initialize 24 | 25 | self.hidden = None # need initialize before forward run 26 | 27 | for name, param in self.rnn.named_parameters(): 28 | if 'bias' in name: 29 | nn.init.constant_(param, 0.25) 30 | elif 'weight' in name: 31 | nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('sigmoid')) 32 | for m in self.modules(): 33 | if isinstance(m, nn.Linear): 34 | m.weight.data =nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu')) 35 | 36 | def init_hidden(self, batch_size): 37 | self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size) 38 | if self.is_cuda: 39 | self.hidden = self.hidden.cuda() 40 | 41 | def forward(self, 
input_raw, pack=False, input_len=None): 42 | input = self.input(input_raw) 43 | input = self.relu(input) 44 | 45 | if pack: 46 | input = pack_padded_sequence(input, input_len, batch_first=True) 47 | output_raw, self.hidden = self.rnn(input, self.hidden) 48 | if pack: 49 | output_raw = pad_packed_sequence(output_raw, batch_first=True)[0] 50 | 51 | return output_raw 52 | 53 | 54 | class VAR(nn.Module): 55 | """ 56 | Variational regularization: green block in the paper 57 | """ 58 | 59 | def __init__(self, h_size, embedding_size, y_size, is_cuda=False): 60 | super(VAR, self).__init__() 61 | self.encode_11 = nn.Linear(h_size, embedding_size) # mu 62 | self.encode_12 = nn.Linear(h_size, embedding_size) # lsgms 63 | 64 | self.decode_1 = nn.Linear(embedding_size, embedding_size) 65 | self.decode_2 = nn.Linear(embedding_size, y_size) # make edge prediction (reconstruct) 66 | self.relu = nn.ReLU() 67 | 68 | self.is_cuda = is_cuda 69 | 70 | for m in self.modules(): 71 | if isinstance(m, nn.Linear): 72 | m.weight.data =nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu')) 73 | 74 | def forward(self, h): 75 | # encoder 76 | z_mu = self.encode_11(h) 77 | z_lsgms = self.encode_12(h) 78 | # reparameterize 79 | z_sgm = z_lsgms.mul(0.5).exp_() 80 | 81 | if self.training: 82 | eps = torch.randn(z_sgm.size()) 83 | if self.is_cuda: 84 | eps = eps.cuda() 85 | 86 | z = eps*z_sgm + z_mu 87 | else: 88 | z = z_mu 89 | # decoder 90 | y = self.decode_1(z) 91 | y = self.relu(y) 92 | y = self.decode_2(y) 93 | return y, z_mu, z_lsgms 94 | 95 | 96 | class RecurrentClassifier(nn.Module): 97 | """ 98 | Recurrent classification: yellow block in the paper 99 | """ 100 | 101 | def __init__(self, input_size, embedding_size, hidden_size, num_layers, num_class=None, is_cuda=False): 102 | super(RecurrentClassifier, self).__init__() 103 | 104 | self.num_layers = num_layers 105 | self.hidden_size = hidden_size 106 | self.is_cuda = is_cuda 107 | 108 | self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True) 109 | self.output = nn.Sequential( 110 | nn.Linear(hidden_size, embedding_size), 111 | nn.ReLU(), 112 | nn.Linear(embedding_size, num_class), 113 | nn.Softmax(dim=1)) 114 | 115 | self.relu = nn.ReLU() 116 | self.hidden = None 117 | 118 | for name, param in self.rnn.named_parameters(): 119 | if 'bias' in name: 120 | nn.init.constant_(param, 0.25) 121 | elif 'weight' in name: 122 | nn.init.xavier_uniform_(param, gain=nn.init.calculate_gain('sigmoid')) 123 | for m in self.modules(): 124 | if isinstance(m, nn.Linear): 125 | m.weight.data =nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu')) 126 | 127 | def init_hidden(self, batch_size): 128 | self.hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size) 129 | if self.is_cuda: 130 | self.hidden = self.hidden.cuda() 131 | 132 | def forward(self, input_raw, pack=True, input_len=None): 133 | sample = input_raw 134 | if pack: 135 | sample = pack_padded_sequence(sample, input_len, batch_first=True) 136 | output_raw, self.hidden = self.rnn(sample, self.hidden) 137 | l_pred = self.output(self.hidden[-1]) 138 | return l_pred, output_raw 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Recurrent Neural Networks for Graph Classification 2 | 3 | This work was presented at the ICLR 2019 workshop session [Representation Learning on Graphs and 
Manifolds](https://rlgm.github.io/).

Our paper can be found here: https://rlgm.github.io/papers/9.pdf

The notebook VRGC.ipynb requires the datasets, which can be downloaded [here](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets).

### Abstract

We address the problem of graph classification based only on structural information. Inspired by natural language processing (NLP) techniques, our model sequentially embeds information to estimate class membership probabilities. In addition, we experiment with NLP-like variational regularization techniques that make the model predict the next node in the sequence as it reads it. We show experimentally that our model achieves state-of-the-art classification results on several standard molecular datasets. Finally, we perform a qualitative analysis and give some insight into whether node prediction helps the model classify graphs better.

### Multi-task learning

Multi-task learning is a powerful lever for learning rich representations in NLP [1]. We propose to use it for our problem: the classification task and an auxiliary node prediction task are trained jointly, as sketched below.

![Macro architecture of the VRGC model](images/archi_macro.png)
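For concreteness, here is a minimal sketch of how the two tasks share a single objective. It mirrors the loss composition in `train_vrgc_epoch` (train_test_functions.py); the helper name `vrgc_loss` is introduced here for illustration only.

```python
def vrgc_loss(loss_classifier, loss_bce, loss_kl, reco_importance=0.1):
    """Joint multi-task objective: graph classification is the main task;
    next-node reconstruction (BCE) and the KL term of the variational
    regularizer are auxiliary tasks, down-weighted by reco_importance
    (0.1 by default in args.py)."""
    return reco_importance * (loss_bce + loss_kl) + loss_classifier
```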

#### Graph preprocessing

We use a BFS node-ordering procedure to transform each graph into a sequence of nodes, as in [2]; a condensed version of this ordering is sketched after Figure 2.

##### Breadth-first search with random root R for graph enumeration


![BFS node ordering](images/BFS.png)

Figure 2: Example of a BFS node ordering.
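As a self-contained illustration, the sketch below condenses the `bfs_seq` routine from data.py (the function name `bfs_order` and the karate-club example are ours, not part of the repository):

```python
import networkx as nx

def bfs_order(graph, root):
    """Nodes of `graph` in breadth-first order from `root`, level by level,
    as computed by bfs_seq in data.py (condensed sketch)."""
    successors = dict(nx.bfs_successors(graph, root))
    order, frontier = [root], [root]
    while frontier:
        frontier = [child for node in frontier for child in successors.get(node, [])]
        order += frontier
    return order

# A new random root is drawn at every pass over a graph, so one graph
# yields many different node sequences during training.
print(bfs_order(nx.karate_club_graph(), root=0))
```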

##### Sequential truncated node adjacency

Each node is related only to its two closest neighbors in BFS order, which yields a low-dimensional sequence representation of the graph; a condensed version of the truncation is sketched after Figure 3.


![Sequential truncated node adjacency](images/SeqAdj.png)

Figure 3: Example of sequential truncated node adjacency.
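The following sketch condenses `encode_adj` from data.py; with `max_prev_node=2` it reproduces the two-neighbor truncation illustrated above (the function and variable names here are ours):

```python
import numpy as np

def truncated_adjacency(adj, max_prev_node):
    """Row i encodes the links from the (i+1)-th node in BFS order to its
    max_prev_node most recent predecessors, most recent first
    (condensed version of encode_adj in data.py)."""
    adj = np.tril(adj, k=-1)[1:, :-1]           # strict lower triangle, shifted
    out = np.zeros((adj.shape[0], max_prev_node))
    for i in range(adj.shape[0]):
        window = adj[i, max(0, i - max_prev_node + 1):i + 1]
        out[i, :len(window)] = window[::-1]     # zero-padded on the right
    return out

# 4-node path graph: each node only connects to its direct predecessor
path = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]])
print(truncated_adjacency(path, max_prev_node=2))
```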

##### Complete graph-to-sequence embedding

Thanks to the memory structure of the RNN, each node embedding contains the BFS-related adjacency information of both the current node and the previous ones; a minimal sketch is given after Figure 4.


![Complete graph-to-sequence embedding](images/archi_0.png)

Figure 4: Example of a complete graph-to-sequence embedding.
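A minimal sketch of what the embedding block computes, assuming the ENZYMES dimensions from args.py (it strips away the packing of variable-length sequences handled by `RecurrentEmbedding` in model.py):

```python
import torch
import torch.nn as nn

node_dim, embedding_size, hidden_size = 25, 64, 128   # ENZYMES defaults (args.py)
embed = nn.Sequential(nn.Linear(node_dim, embedding_size), nn.ReLU())
rnn = nn.GRU(embedding_size, hidden_size, num_layers=2, batch_first=True)

x = torch.rand(4, 30, node_dim)   # 4 graphs, 30 BFS steps of truncated adjacency
h_seq, _ = rnn(embed(x))          # h_seq[:, t, :] summarizes nodes 1..t+1 of each graph
```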

#### Recurrent neural network for sequence classification

The truncated, BFS-ordered sequence of embedded nodes is read by a recurrent network, and the fully connected (FC) classifier is fed with the resulting representation (see the sketch after Figure 5).


![Recurrent classifier](images/archi_1.png)

Figure 5: Recurrent classifier for sequence classification.
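A minimal sketch of the classification head, mirroring `RecurrentClassifier` in model.py with the ENZYMES dimensions from args.py (packing of variable-length sequences is again omitted):

```python
import torch
import torch.nn as nn

hidden_size, embedding_size, num_class = 128, 64, 6   # ENZYMES defaults (args.py)
rnn = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
head = nn.Sequential(nn.Linear(hidden_size, embedding_size), nn.ReLU(),
                     nn.Linear(embedding_size, num_class), nn.Softmax(dim=1))

h_seq = torch.rand(4, 30, hidden_size)   # node embeddings from the previous block
_, hidden = rnn(h_seq)                   # hidden: (num_layers, batch, hidden_size)
class_probs = head(hidden[-1])           # one distribution per graph: (4, 6)
```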

#### Variational autoregressive (VAR) node prediction

A node prediction task is added to help the classifier. This task is performed by a variational autoencoder fed with the same sequence of embedded nodes as the recurrent classifier (see the sketch after Figure 6).


![Variational autoregressive node prediction](images/archi_2.png)

Figure 6: Variational autoregressive node prediction.
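A condensed sketch of the VAR block in model.py (the class name is ours, and unlike the repository code, which applies the sigmoid in the training loop, the sigmoid is folded into the module here):

```python
import torch
import torch.nn as nn

class NodePredictionVAR(nn.Module):
    def __init__(self, h_size=128, z_size=64, y_size=25):
        super().__init__()
        self.mu = nn.Linear(h_size, z_size)       # posterior mean
        self.logvar = nn.Linear(h_size, z_size)   # posterior log-variance
        self.decode = nn.Sequential(nn.Linear(z_size, z_size), nn.ReLU(),
                                    nn.Linear(z_size, y_size))

    def forward(self, h):
        mu, logvar = self.mu(h), self.logvar(h)
        # Reparameterization trick: sample while training, use the mean at test time
        if self.training:
            z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        else:
            z = mu
        y = torch.sigmoid(self.decode(z))   # next node's truncated adjacency probabilities
        return y, mu, logvar
```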

### Results

- Unlike our previous [graph classification work](https://github.com/edouardpineau/A-simple-baseline-algorithm-for-graph-classification), VRGC is not structurally invariant to node indexing. However, the model learns node-indexing invariance from numerous training iterations on randomly rooted, BFS-ordered sequential graph embeddings, and predictions from several random roots are aggregated by voting at test time (sketched after Figure 7).


![t-SNE projection of the latent states preceding classification](images/Projection.png)

Figure 7: t-SNE projection of the latent state preceding classification, for five graphs from four distinct classes, each initiated with 20 different BFS roots. Colors and markers represent the respective classes of the graphs.
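The sketch below shows the aggregation rule of `vote_test` (train_test_functions.py) on placeholder scores: each graph is scored from several randomly rooted BFS sequences, the per-class softmax outputs are combined with an elementwise maximum, and the argmax gives the final label.

```python
import numpy as np

def vote(prob_runs):
    """prob_runs: list of (num_graphs, num_class) softmax outputs, one per
    random BFS root. Keep the highest score seen for each class, then take
    the argmax (as in vote_test, train_test_functions.py)."""
    scores = prob_runs[0]
    for probs in prob_runs[1:]:
        scores = np.maximum(scores, probs)
    return np.argmax(scores, axis=1)

runs = [np.random.dirichlet(np.ones(6), size=10) for _ in range(100)]  # placeholder scores
print(vote(runs))   # one predicted label per graph
```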

&nbsp;

- VAR helps the model find a more meaningful latent representation for classification as the graph dataset grows larger, at a marginal extra computational cost compared to the RNN alone.


![Classification results](images/table.png)

Figure 8: Classification results.

### Citing

    @article{pineau2019variational,
        title={Variational recurrent neural networks for graph classification},
        author={Pineau, Edouard and de Lara, Nathan},
        journal={Representation Learning on Graphs and Manifolds Workshop (ICLR)},
        year={2019}
    }

### Datasets and references

All datasets can be found here: https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets

[1] V. Sanh, T. Wolf, and S. Ruder. A hierarchical multi-task approach for learning embeddings from semantic tasks. arXiv preprint arXiv:1811.06031, 2018.

[2] J. You, R. Ying, X. Ren, W. Hamilton, and J. Leskovec. GraphRNN: Generating realistic graphs with deep auto-regressive models. Proceedings of the 35th International Conference on Machine Learning, PMLR 80:5708-5717, 2018.

--------------------------------------------------------------------------------
/train_test_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np 2 | import torch 3 | # import torch.nn as nn 4 | import torch.nn.functional as functional 5 | 6 | from torch import optim 7 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 8 | 9 | # REMARK: the functions 'pad_packed_sequence' and 'pack_padded_sequence' make it possible to handle 10 | # sequences of different lengths in recurrent neural networks 11 | 12 | 13 | def train_vrgc_epoch(epoch, args, rnn_embedding, var, rnn_classifier, dataloader_train, 14 | optimizer_rnn, optimizer_var, optimizer_classifier): 15 | """ 16 | Training procedure for the VRGC model (rnn + var + classifier) 17 | 18 | :param epoch: index of the current training epoch 19 | :param args: arguments of the problem 20 | :param rnn_embedding: recurrent embedding 21 | :param var: variational regularizer 22 | :param rnn_classifier: recurrent classifier 23 | :param dataloader_train: train set loader 24 | :param optimizer_rnn: recurrent embedding Adam optimizer 25 | :param optimizer_var: variational regularizer Adam optimizer 26 | :param optimizer_classifier: recurrent classifier Adam optimizer 27 | """ 28 | 29 | rnn_embedding.train() 30 | var.train() 31 | rnn_classifier.train() 32 | 33 | loss_sum, loss_c_sum, accuracy = 0, 0, 0 34 | tot_data = 0 35 | 36 | for batch_idx, data in enumerate(dataloader_train): 37 | x_unsorted = data['x'].float() 38 | y_unsorted = data['y'].float() 39 | 40 | label_unsorted = data['l'].long() 41 | y_len_unsorted = data['len'] 42 | y_len_max = max(y_len_unsorted) 43 | 44 | x_unsorted = x_unsorted[:, :y_len_max, :] 45 | y_unsorted = y_unsorted[:, :y_len_max, :] 46 | 47 | # Sort input 48 | 49 | y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) 50 | y_len = y_len.numpy().tolist() 51 | 52 | x = torch.index_select(x_unsorted, 0, sort_index) 53 | y = torch.index_select(y_unsorted, 0, sort_index) 54 | 55 | labels = torch.index_select(label_unsorted, 0, sort_index) 56 | 57 | if args.cuda: 58 | x = x.cuda() 59 | y = y.cuda() 60 | labels = labels.cuda() 61 | 62 | rnn_embedding.zero_grad() 63 | var.zero_grad() 64 | rnn_classifier.zero_grad() 65 | 66 | rnn_embedding.init_hidden(batch_size=x_unsorted.size(0)) 67 | rnn_classifier.init_hidden(batch_size=x_unsorted.size(0)) 68 | 69 | h = rnn_embedding(x, pack=True, input_len=y_len) 70 | y_pred, z_mu, z_logvar = var(h) 71 | 72 | y_pred = torch.sigmoid(y_pred) 73 | y_pred = pack_padded_sequence(y_pred, y_len, batch_first=True) 74 | y_pred = pad_packed_sequence(y_pred, batch_first=True)[0] 75 
| 76 | z_mu = pack_padded_sequence(z_mu, y_len, batch_first=True) 77 | z_mu = pad_packed_sequence(z_mu, batch_first=True)[0] 78 | z_logvar = pack_padded_sequence(z_logvar, y_len, batch_first=True) 79 | z_logvar = pad_packed_sequence(z_logvar, batch_first=True)[0] 80 | 81 | rnn_classifier.init_hidden(x.size(0)) 82 | pred_labels, var_raw = rnn_classifier(h, pack=True, input_len=y_len) 83 | loss_classifier = functional.cross_entropy(pred_labels, labels) 84 | 85 | loss_bce = args.loss(y_pred, y) 86 | 87 | loss_kl = -0.5 * torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp()) 88 | loss_kl /= y.size(0) * y.size(1) * sum(y_len) 89 | 90 | loss = args.reco_importance * (loss_bce + loss_kl) + loss_classifier 91 | 92 | loss.backward() 93 | 94 | optimizer_var.step() 95 | optimizer_rnn.step() 96 | optimizer_classifier.step() 97 | 98 | accuracy += loss_classifier.item() 99 | tot_data += x_unsorted.size(0) 100 | 101 | if epoch % args.epochs_log == 0: 102 | print('Dataset: {}, Epoch: {}/{}, train bce loss: {:.3f}, train kl loss: {:.3f}, classifier loss: {:.3f}' 103 | .format(args.graph_name, epoch, args.epochs, loss_bce.item(), loss_kl.item(), accuracy)) 104 | 105 | 106 | def test_vrgc_epoch(args, rnn_embedding, var, rnn_classifier, dataloader_test): 107 | """ 108 | Accuracy evaluation at test time 109 | 110 | :param args: arguments of the problem 111 | :param rnn_embedding: recurrent embedding 112 | :param var: variational regularization 113 | :param rnn_classifier: recurrent classifier 114 | :param dataloader_test: test set loader 115 | :return: test accuracy 116 | """ 117 | 118 | rnn_embedding.eval() 119 | var.eval() 120 | rnn_classifier.eval() 121 | 122 | loss_sum, accuracy, tot_data = 0, 0, 0 123 | total_predicted_labels, total_labels = [], [] 124 | 125 | for batch_idx, data in enumerate(dataloader_test): 126 | 127 | x_unsorted = data['x'].float() 128 | label_unsorted = data['l'].long() 129 | y_len_unsorted = data['len'] 130 | y_len_max = max(y_len_unsorted) 131 | x_unsorted = x_unsorted[:, :y_len_max, :] 132 | 133 | # Initialize gradients and LSTM hidden state according to batch size 134 | 135 | rnn_embedding.init_hidden(batch_size=x_unsorted.size(0)) 136 | rnn_classifier.init_hidden(batch_size=x_unsorted.size(0)) 137 | 138 | # Sort input 139 | 140 | y_len, sort_index = torch.sort(y_len_unsorted, 0, descending=True) 141 | y_len = y_len.numpy().tolist() 142 | x = torch.index_select(x_unsorted, 0, sort_index) 143 | labels = torch.index_select(label_unsorted, 0, sort_index) 144 | 145 | if args.cuda: 146 | x = x.cuda() 147 | labels = labels.cuda() 148 | 149 | h = rnn_embedding(x, pack=True, input_len=y_len) 150 | 151 | # Standard GRU classification 152 | 153 | rnn_classifier.init_hidden(x.size(0)) 154 | pred_labels, var_raw = rnn_classifier(h, pack=True, input_len=y_len) 155 | 156 | accuracy += torch.sum((labels == pred_labels.topk(1)[-1].squeeze()).float()).item() 157 | tot_data += x_unsorted.size(0) 158 | 159 | total_predicted_labels.append(pred_labels) 160 | total_labels.append(labels) 161 | 162 | return accuracy / tot_data, total_predicted_labels, total_labels 163 | 164 | 165 | def vote_test(args, rnn_embedding, var, rnn_classifier, dataloader_test, num_iteration=10): 166 | """ 167 | Aggregation of the results at test time 168 | 169 | :param args: arguments of the problem 170 | :param rnn_embedding: recurrent embedding 171 | :param var: variational regularization 172 | :param rnn_classifier: recurrent classifier 173 | :param dataloader_test: test set loader 174 | :param num_iteration: number N of 
times a graph is tested (with random BFS root) 175 | :return: test accuracy 176 | """ 177 | 178 | scores = [] 179 | acc, pred_labels, true_labels = test_vrgc_epoch(args, rnn_embedding, var, rnn_classifier, 180 | dataloader_test) 181 | 182 | vote = torch.cat(pred_labels, dim=0).cpu().data.numpy() 183 | 184 | for _ in np.arange(num_iteration): 185 | acc, pred_labels, true_labels = test_vrgc_epoch(args, rnn_embedding, var, rnn_classifier, dataloader_test) 186 | vote = np.maximum(vote, torch.cat(pred_labels, dim=0).cpu().data.numpy()) 187 | 188 | scores.append(acc) 189 | 190 | predicted_labels = np.argmax(vote, axis=1) 191 | accuracy_vote = np.sum((torch.cat(true_labels, dim=0).cpu().data.numpy() == predicted_labels)) / predicted_labels.shape[0] 192 | 193 | if args.cuda: 194 | del pred_labels 195 | torch.cuda.empty_cache() 196 | 197 | return accuracy_vote, scores, predicted_labels, torch.cat(true_labels), vote 198 | 199 | 200 | def classifier_train(args, dataloader_train, dataloader_test, rnn_embedding, var, rnn_classifier): 201 | """ 202 | 203 | :param args: arguments of the problem 204 | :param dataloader_train: train set loader (90% of the data) 205 | :param dataloader_test: test set loader (10% of the data) 206 | :param rnn_embedding: recurrent embedding 207 | :param var: variational regularizer 208 | :param rnn_classifier: recurrent classifier 209 | :return: all the test accuracies for the 10 folds cross-validation 210 | """ 211 | 212 | epoch = 1 213 | 214 | # Initialize optimizers 215 | 216 | optimizer_rnn = optim.Adam(rnn_embedding.parameters(), lr=args.lr) 217 | optimizer_var = optim.Adam(var.parameters(), lr=args.lr) 218 | optimizer_classifier = optim.Adam(rnn_classifier.parameters(), lr=args.lr) 219 | 220 | # Start main loop 221 | 222 | all_test_losses = [] 223 | 224 | while epoch <= args.epochs: 225 | train_vrgc_epoch(epoch, args, rnn_embedding, var, rnn_classifier, 226 | dataloader_train, optimizer_rnn, optimizer_var, optimizer_classifier) 227 | 228 | if epoch % 50 == 0: 229 | # For the published Github version, we screen the test accuracy every 50 epochs 230 | 231 | accuracy_test, scores, predicted_labels, true_labels, vote = vote_test(args, 232 | rnn_embedding, 233 | var, 234 | rnn_classifier, 235 | dataloader_test, 236 | num_iteration=10) 237 | all_test_losses.append(accuracy_test) 238 | 239 | print('Epoch: {}, Test accuracy: {:.3f}'.format(epoch, accuracy_test)) 240 | 241 | epoch += 1 242 | 243 | return all_test_losses 244 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import torch 4 | import random 5 | from torch.utils import data 6 | from operator import itemgetter 7 | 8 | from sklearn.model_selection import StratifiedKFold 9 | from sklearn.preprocessing import LabelEncoder 10 | 11 | # Some functions of the module data.py are taken from https://github.com/JiaxuanYou/graph-generation 12 | 13 | 14 | def create_loaders(graphs, args): 15 | """ 16 | Returns all train and test loaders for the 10-cross validation classification 17 | 18 | :param graphs: list of graphs in networkx format 19 | :param args: arguments of the problem 20 | :return: train and test loaders for the 10-cross validation classification 21 | """ 22 | random.shuffle(graphs) 23 | 24 | skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0) 25 | labels = np.array([g.graph['label'] for g in graphs]) 26 | le = LabelEncoder() 27 | 
labels = le.fit_transform(labels) 28 | skf.get_n_splits(graphs, labels) 29 | 30 | dataloaders_train, dataloaders_test = [], [] 31 | 32 | for train_index, test_index in skf.split(graphs, labels): 33 | graphs_train = itemgetter(*train_index)(graphs) 34 | graphs_test = itemgetter(*test_index)(graphs) 35 | 36 | dataset_train = GraphSequenceSamplerPytorch(graphs_train, node_dim=args.node_dim) 37 | dataset_test = GraphSequenceSamplerPytorch(graphs_test, node_dim=args.node_dim) 38 | 39 | dataloaders_train.append(torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size)) 40 | dataloaders_test.append(torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size)) 41 | 42 | args.num_class = int(np.max([g.graph['label'] for g in graphs]) - np.min([g.graph['label'] for g in graphs]) + 1) 43 | 44 | args.max_num_node = max([graphs[i].number_of_nodes() for i in range(len(graphs))]) 45 | args.max_num_edge = max([graphs[i].number_of_edges() for i in range(len(graphs))]) 46 | args.min_num_edge = min([graphs[i].number_of_edges() for i in range(len(graphs))]) 47 | 48 | # Show graphs statistics 49 | 50 | print('total graph num: {}, training set: {}'.format(len(graphs), len(graphs_train))) 51 | print('max number node: {}'.format(args.max_num_node)) 52 | print('max/min number edge: {}; {}'.format(args.max_num_edge, args.min_num_edge)) 53 | print('max previous node: {}'.format(args.node_dim)) 54 | 55 | return dataloaders_train, dataloaders_test 56 | 57 | 58 | def graph_load_batch(data_directory, name): 59 | """ 60 | Reads graphs from files in a given directory and transforms them into networkx objects 61 | 62 | :param data_directory: data location 63 | :param name: dataset name (prefix of files to read) 64 | :return: list of networkx graphs 65 | """ 66 | 67 | print('Loading graph dataset: ' + str(name)) 68 | graph = nx.Graph() 69 | 70 | # load data 71 | path = data_directory + name + '/' 72 | data_adj = np.loadtxt(path + name + '_A.txt', delimiter=',').astype(int) 73 | 74 | data_graph_indicator = np.loadtxt(path + name + '_graph_indicator.txt', delimiter=',').astype(int) 75 | data_graph_labels = np.loadtxt(path + name + '_graph_labels.txt', delimiter=',').astype(int) 76 | 77 | data_tuple = list(map(tuple, data_adj)) 78 | 79 | # add edges 80 | graph.add_edges_from(data_tuple) 81 | 82 | graph.remove_nodes_from(list(nx.isolates(graph))) 83 | 84 | # split into graphs 85 | graph_num = data_graph_indicator.max() 86 | node_list = np.arange(data_graph_indicator.shape[0]) + 1 87 | graphs = [] 88 | 89 | for i in range(graph_num): 90 | # find the nodes for each graph 91 | nodes = node_list[data_graph_indicator == i + 1] 92 | graph_sub = graph.subgraph(nodes).copy() 93 | graph_sub.graph['label'] = data_graph_labels[i] 94 | 95 | graphs.append(graph_sub) 96 | 97 | print('Loaded') 98 | return graphs 99 | 100 | 101 | def bfs_seq(graph, root): 102 | """ 103 | Get a BFS transformation of a graph 104 | 105 | :param graph: a networkx graph 106 | :param root: a node index 107 | :return: the BFS-ordered node indices 108 | """ 109 | 110 | dictionary = dict(nx.bfs_successors(graph, root)) 111 | to_visit = [root] 112 | output = [root] 113 | level_seq = [0] 114 | level = 1 115 | while len(to_visit) > 0: 116 | next_level = [] 117 | while len(to_visit) > 0: 118 | current = to_visit.pop(0) 119 | neighbor = dictionary.get(current) 120 | if neighbor is not None: 121 | next_level += neighbor 122 | level_seq += [level] * len(neighbor) 123 | output += next_level 124 | to_visit = next_level 125 | level += 1 126 | return output 
127 | 128 | 129 | def encode_adj(adjacency, max_prev_node=10): 130 | """ 131 | Transforms an adjacency matrix to be passed as an input to the RNN 132 | 133 | :param adjacency: adjacency matrix of a graph 134 | :param max_prev_node: size of the node representation depth (size kept after truncation) 135 | :return: a sequence of truncated node adjacency 136 | """ 137 | 138 | # pick up lower tri 139 | adjacency = np.tril(adjacency, k=-1) 140 | n_nodes = adjacency.shape[0] 141 | adjacency = adjacency[1:n_nodes, 0:n_nodes - 1] 142 | 143 | # use max_prev_node to truncate 144 | # note: now adj is a (n-1)*(n-1) matrix 145 | adj_output = np.zeros((adjacency.shape[0], max_prev_node)) 146 | for i in range(adjacency.shape[0]): 147 | input_start = max(0, i - max_prev_node + 1) 148 | input_end = i + 1 149 | output_start = max_prev_node + input_start - input_end 150 | output_end = max_prev_node 151 | adj_output[i, output_start:output_end] = adjacency[i, input_start:input_end] 152 | adj_output[i, :] = adj_output[i, :][::-1] # reverse order 153 | 154 | return adj_output 155 | 156 | 157 | class GraphSequenceSamplerPytorch(data.Dataset): 158 | def __init__(self, graph_list, node_dim=None): 159 | """ 160 | 161 | :param graph_list: list of Networkx graph objects 162 | :param node_dim: dimensionality of the truncated node dimensionality 163 | """ 164 | 165 | self.adj_all = [] 166 | self.len_all = [] 167 | self.labels = [] 168 | 169 | for graph in graph_list: 170 | self.adj_all.append(np.asarray(nx.to_numpy_matrix(graph))) 171 | self.len_all.append(graph.number_of_nodes()) 172 | self.labels.append(graph.graph['label']) 173 | 174 | self.labels = [l - np.min(self.labels) for l in self.labels] 175 | self.max_num_node = max(self.len_all) 176 | 177 | if node_dim is None: 178 | print('calculating max previous node, total iteration: {}'.format(20000)) 179 | self.node_dim = max(self.calc_max_prev_node(iter=20000)) 180 | print('max previous node: {}'.format(self.node_dim)) 181 | else: 182 | self.node_dim = node_dim 183 | 184 | def __len__(self): 185 | return len(self.adj_all) 186 | 187 | def __getitem__(self, idx): 188 | adj_copy = self.adj_all[idx].copy() 189 | labels_copy = self.labels[idx].copy() 190 | x_batch = np.zeros((self.max_num_node, self.node_dim)) # here zeros are padded for small graph 191 | x_batch[0, :] = 1 # the first input token is all ones 192 | y_batch = np.zeros((self.max_num_node, self.node_dim)) # here zeros are padded for small graph 193 | 194 | len_batch = adj_copy.shape[0] 195 | x_idx = np.random.permutation(adj_copy.shape[0]) 196 | adj_copy = adj_copy[np.ix_(x_idx, x_idx)] 197 | adj_copy_matrix = np.asmatrix(adj_copy) 198 | graph = nx.from_numpy_matrix(adj_copy_matrix) 199 | 200 | # ---- Definition of the ordering of the nodes ---- # 201 | 202 | start_idx = np.random.randint(adj_copy.shape[0]) 203 | x_idx = np.array(bfs_seq(graph, start_idx)) 204 | adj_copy_ = adj_copy[np.ix_(x_idx, x_idx)] 205 | adj_encoded = encode_adj(adj_copy_.copy(), max_prev_node=self.node_dim) 206 | 207 | # get x and y and adj 208 | # for small graph the rest are zero padded 209 | y_batch[0:adj_encoded.shape[0], :] = adj_encoded 210 | x_batch[1:adj_encoded.shape[0] + 1, :] = adj_encoded 211 | 212 | return {'x': x_batch, 'y': y_batch, 'l': labels_copy, 'len': len_batch} 213 | 214 | def calc_max_prev_node(self, iter=20000, topk=10): 215 | max_prev_node = [] 216 | for i in range(iter): 217 | if i % (iter / 5) == 0: 218 | print('iter {} times'.format(i)) 219 | adj_idx = np.random.randint(len(self.adj_all)) 220 | adj_copy = 
self.adj_all[adj_idx].copy() 221 | 222 | x_idx = np.random.permutation(adj_copy.shape[0]) 223 | adj_copy = adj_copy[np.ix_(x_idx, x_idx)] 224 | adj_copy_matrix = np.asmatrix(adj_copy) 225 | G = nx.from_numpy_matrix(adj_copy_matrix) 226 | 227 | # BFS 228 | start_idx = np.random.randint(adj_copy.shape[0]) 229 | x_idx = np.array(bfs_seq(G, start_idx)) 230 | adj_copy = adj_copy[np.ix_(x_idx, x_idx)] 231 | 232 | # encode adj 233 | adj_encoded = encode_adj_flexible(adj_copy.copy()) 234 | max_encoded_len = max([len(adj_encoded[i]) for i in range(len(adj_encoded))]) 235 | max_prev_node.append(max_encoded_len) 236 | max_prev_node = sorted(max_prev_node)[-1 * topk:] 237 | return max_prev_node 238 | 239 | 240 | def encode_adj_flexible(adj): 241 | ''' 242 | Remark: used only if node_dim is not already computed 243 | 244 | :param adj: adj matrix 245 | :return: 246 | ''' 247 | # pick up lower tri 248 | adj = np.tril(adj, k=-1) 249 | n = adj.shape[0] 250 | adj = adj[1:n, 0:n - 1] 251 | 252 | adj_output = [] 253 | input_start = 0 254 | for i in range(adj.shape[0]): 255 | input_end = i + 1 256 | adj_slice = adj[i, input_start:input_end] 257 | adj_output.append(adj_slice) 258 | non_zero = np.nonzero(adj_slice)[0] 259 | input_start = input_end - len(adj_slice) + np.amin(non_zero) 260 | 261 | return adj_output 262 | --------------------------------------------------------------------------------