├── .gitignore
├── README.md
└── source
    ├── acc.svg
    ├── loss.svg
    ├── models
    │   ├── classifier.py
    │   ├── diffpool.py
    │   └── graphsage.py
    ├── notebooks
    │   └── dataset_stats.ipynb
    ├── train.py
    └── utils
        └── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
### directories

data/
**/.idea
**/.ipynb_checkpoints
**/__pycache__


### files

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Graph Convolution + Pooling NN PyTorch Implementation

## Generating a graph classification NN in PyTorch, combining Graph Convolutions and Graph Pooling modules

Work based on:

- [GCNs](https://arxiv.org/pdf/1609.02907.pdf)
- [GraphSAGE](https://arxiv.org/pdf/1706.02216.pdf)
- [DiffPool](https://arxiv.org/pdf/1806.08804.pdf)
- [Graph U-net](https://openreview.net/pdf?id=HJePRoAct7)

--------------------------------------------------------------------------------
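For reference, the coarsening step implemented in `source/models/diffpool.py` (shown further below) can be summarized as follows; this is a minimal restatement of that code in the DiffPool paper's notation, where `embed` and `pool` are the two GraphSAGE sub-modules used in the code:

```latex
% Z: node embeddings; S: soft assignment of the n nodes to output_dim clusters
\begin{aligned}
Z &= \mathrm{GraphSAGE}_{\mathrm{embed}}(X, A), &
S &= \operatorname{softmax}\bigl(\mathrm{GraphSAGE}_{\mathrm{pool}}(X, A)\bigr),\\
X' &= S^{\top} Z, &
A' &= S^{\top} A S.
\end{aligned}
```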
--------------------------------------------------------------------------------
/source/acc.svg:
--------------------------------------------------------------------------------
[SVG figure omitted: matplotlib plot of training/validation accuracy per epoch
(title "acc", x-axis "Epoch"), as written by generate_plot() in utils/utils.py.]

--------------------------------------------------------------------------------
/source/loss.svg:
--------------------------------------------------------------------------------
[SVG figure omitted: matplotlib plot of training/validation loss per epoch
(title "loss", x-axis "Epoch"), as written by generate_plot() in utils/utils.py.]

--------------------------------------------------------------------------------
/source/models/classifier.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F

from models.diffpool import DiffPool
from models.graphsage import GraphSAGE


class Classifier(nn.Module):

    def __init__(self, device="cuda:0"):
        super(Classifier, self).__init__()
        self.device = device
        # Two convolution + pooling stages; the sizes are hard-coded for
        # MUTAG (7 one-hot node features, 2 classes).
        self.sage1 = GraphSAGE(7, 14, device=self.device)
        self.pool1 = DiffPool(14, 14, device=self.device)
        self.sage2 = GraphSAGE(14, 28, device=self.device)
        # The final pooling layer collapses the graph into a single node.
        self.pool2 = DiffPool(28, 1, final_layer=True, device=self.device)
        self.linear1 = nn.Linear(28, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x, a):
        x = self.sage1(x, a)
        x, a = self.pool1(x, a)
        x = self.sage2(x, a)
        x, a = self.pool2(x, a)
        # x is now a single 28-dim graph embedding; classify it.
        y = self.linear1(x.squeeze(0))
        y = self.linear2(y)
        return x, a, y

--------------------------------------------------------------------------------
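A quick smoke test of the classifier on a random graph (not part of the repo; it assumes you run from `source/` so the `models` package resolves, and uses the CPU instead of the default `cuda:0`). The input sizes follow the hard-coded MUTAG dimensions above:

```python
import torch

from models.classifier import Classifier

# Random "graph": 12 nodes, 7 features each, symmetric 0/1 adjacency.
n = 12
x = torch.rand(n, 7)
a = (torch.rand(n, n) > 0.7).float()
a = ((a + a.t()) > 0).float()

model = Classifier(device="cpu")
x_out, a_out, y = model(x, a)
print(x_out.shape)  # torch.Size([1, 28]) - one pooled node, 28 features
print(a_out.shape)  # torch.Size([1, 1])  - coarsened adjacency
print(y.shape)      # torch.Size([2])     - unnormalized class scores
```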
device="cuda:0", final_layer=False): 10 | super(DiffPool, self).__init__() 11 | self.device = device 12 | self.feature_size = feature_size 13 | self.output_dim = output_dim 14 | self.embed = GraphSAGE(self.feature_size, self.feature_size, device=self.device) 15 | self.pool = GraphSAGE(self.feature_size, self.output_dim, device=self.device) 16 | self.final_layer = final_layer 17 | 18 | def forward(self, x, a): 19 | z = self.embed(x, a) 20 | if self.final_layer: 21 | s = torch.ones(x.size(0), self.output_dim, device=self.device) 22 | else: 23 | s = F.softmax(self.pool(x, a), dim=1) 24 | x_new = s.t() @ z 25 | a_new = s.t() @ a @ s 26 | return x_new, a_new 27 | 28 | -------------------------------------------------------------------------------- /source/models/graphsage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class GraphSAGE(nn.Module): 6 | 7 | def __init__(self, input_feat, output_feat, device="cuda:0", normalize=True): 8 | super(GraphSAGE, self).__init__() 9 | self.device = device 10 | self.normalize = normalize 11 | self.input_feat = input_feat 12 | self.output_feat = output_feat 13 | self.linear = nn.Linear(self.input_feat, self. output_feat) 14 | self.layer_norm = nn.LayerNorm(self.output_feat) # elementwise_affine=False 15 | nn.init.xavier_uniform_(self.linear.weight) 16 | 17 | def aggregate_convolutional(self, x, a): 18 | eye = torch.eye(a.shape[0], dtype=torch.float, device=self.device) 19 | a = a + eye 20 | h_hat = a @ x 21 | 22 | return h_hat 23 | 24 | def forward(self, x, a): 25 | h_hat = self.aggregate_convolutional(x, a) 26 | h = F.relu(self.linear(h_hat)) 27 | if self.normalize: 28 | # h = F.normalize(h, p=2, dim=1) # Normalize edge embeddings 29 | h = self.layer_norm(h) # Normalize layerwise (mean=0, std=1) 30 | 31 | return h 32 | -------------------------------------------------------------------------------- /source/notebooks/dataset_stats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import argparse\n", 10 | "\n", 11 | "import torch\n", 12 | "import torch.nn as nn\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "from torch.nn import functional as F\n", 15 | "from torchvision.utils import make_grid\n", 16 | "from torchvision.utils import save_image\n", 17 | "\n", 18 | "import os\n", 19 | "\n", 20 | "import pickle\n", 21 | "from collections import Counter\n", 22 | "import numpy as np" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# ds_name = '../../data/collab.graph' # Without node features\n", 32 | "ds_name = '../../data/mutag.graph' # With node features" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "def load_data(ds_name):\n", 42 | " f = open(ds_name, \"rb\")\n", 43 | " print(\"Found dataset:\", ds_name)\n", 44 | " data = pickle.load(f, encoding=\"latin1\")\n", 45 | " graph_data = data[\"graph\"]\n", 46 | " labels = data[\"labels\"]\n", 47 | " labels = np.array(labels, dtype = np.float)\n", 48 | " return graph_data, labels" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | 
"output_type": "stream", 59 | "text": [ 60 | "Found dataset: ../../data/mutag.graph\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "graphs, labels = load_data(ds_name)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "Dataset: ../../data/mutag.graph \n", 78 | "Number of Graphs: 188 \n", 79 | "Label distribution: Counter({1.0: 125, -1.0: 63})\n", 80 | "\n", 81 | "Mean #nodes: 17.930851063829788 \n", 82 | "Median #nodes: 17.5 \n", 83 | "Max #nodes: 28 \n", 84 | "Min #nodes: 10 \n", 85 | "Total #nodes: 3371\n", 86 | "\n", 87 | "Mean #edges: 2.2076535152773658 \n", 88 | "Median #edges: 2.0 \n", 89 | "Max #edges: 4 \n", 90 | "Min #edges: 1 \n", 91 | "Total #edges: 7442\n", 92 | "\n", 93 | "Mean #features_len: 1.0 \n", 94 | "Median #features_len: 1.0 \n", 95 | "Max #features_len: 1 \n", 96 | "Min #features_len: 1 \n", 97 | "Total #features_len: 3371\n", 98 | "\n", 99 | "Number of nodes with features: 3371\n", 100 | "Features distribution: Counter({(3,): 2395, (7,): 593, (6,): 345, (2,): 23, (4,): 12, (1,): 2, (5,): 1})\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "print (\"Dataset: %s \\nNumber of Graphs: %s \\nLabel distribution: %s\"%(ds_name, len(graphs), Counter(labels)))\n", 106 | "avg_edges = []\n", 107 | "avg_nodes = []\n", 108 | "n_features = 0\n", 109 | "avg_features = []\n", 110 | "features = []\n", 111 | "for gidxs, nodes in graphs.items():\n", 112 | " for n in nodes:\n", 113 | " avg_edges.append(len(nodes[n]['neighbors']))\n", 114 | " if nodes[n]['label'] != '':\n", 115 | " n_features += 1\n", 116 | " avg_features.append(len(nodes[n]['label']))\n", 117 | " features.append(nodes[n]['label'])\n", 118 | " else:\n", 119 | " avg_features.append(0)\n", 120 | " features.append(None)\n", 121 | " avg_nodes.append(len(nodes))\n", 122 | "print(\"\\nMean #nodes: %s \\nMedian #nodes: %s \\nMax #nodes: %s \\nMin #nodes: %s \\nTotal #nodes: %s\"%(np.mean(avg_nodes), np.median(avg_nodes), max(avg_nodes), min(avg_nodes), sum(avg_nodes)))\n", 123 | "print(\"\\nMean #edges: %s \\nMedian #edges: %s \\nMax #edges: %s \\nMin #edges: %s \\nTotal #edges: %s\"%(np.mean(avg_edges), np.median(avg_edges), max(avg_edges), min(avg_edges), sum(avg_edges)))\n", 124 | "print(\"\\nMean #features_len: %s \\nMedian #features_len: %s \\nMax #features_len: %s \\nMin #features_len: %s \\nTotal #features_len: %s\"%(np.mean(avg_features), np.median(avg_features), max(avg_features), min(avg_features), sum(avg_features)))\n", 125 | "print(\"\\nNumber of nodes with features: %s\"%(n_features))\n", 126 | "print(\"Features distribution: %s\"%(Counter(features)))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 6, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "Example of Graph keys (e.g. node_ids):\n", 139 | " dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "graph = graphs[5]\n", 145 | "nodes = graph.keys()\n", 146 | "print(\"Example of Graph keys (e.g. 
node_ids):\\n\", nodes)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 7, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "Example of Node:\n", 159 | " {'neighbors': array([ 1, 13], dtype=uint8), 'label': (3,)}\n" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "node = graph[0]\n", 165 | "print(\"Example of Node:\\n\",node)" 166 | ] 167 | } 168 | ], 169 | "metadata": { 170 | "kernelspec": { 171 | "display_name": "Python 3", 172 | "language": "python", 173 | "name": "python3" 174 | }, 175 | "language_info": { 176 | "codemirror_mode": { 177 | "name": "ipython", 178 | "version": 3 179 | }, 180 | "file_extension": ".py", 181 | "mimetype": "text/x-python", 182 | "name": "python", 183 | "nbconvert_exporter": "python", 184 | "pygments_lexer": "ipython3", 185 | "version": "3.6.4" 186 | } 187 | }, 188 | "nbformat": 4, 189 | "nbformat_minor": 2 190 | } 191 | -------------------------------------------------------------------------------- /source/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import matplotlib.pyplot as plt 11 | from torch.nn import functional as F 12 | from torchvision.utils import make_grid 13 | from torchvision.utils import save_image 14 | 15 | from models.classifier import Classifier 16 | from models.graphsage import GraphSAGE 17 | from models.diffpool import DiffPool 18 | from utils.utils import * 19 | 20 | SEED = 42 21 | NORMALIZE = True 22 | 23 | def compute_accuracy(pred, target, device="cuda:0"): 24 | pred_labels = torch.stack(pred, dim=0).to(device=device) 25 | acc = (pred_labels.long() == target.long()).float().mean() 26 | 27 | return acc 28 | 29 | def run_epoch(classifier, optimizer, criterion, x_data, a_data, t_data, eval=False, device="cuda:0"): 30 | data_len = x_data.size(0) 31 | pred = [] 32 | losses = [] 33 | scores = [] 34 | if eval: 35 | classifier.eval() 36 | else: 37 | classifier.train() 38 | 39 | for i in range(data_len): 40 | optimizer.zero_grad() 41 | x = x_data[i] 42 | a = a_data[i] 43 | t = t_data[i] 44 | 45 | _, _, y = classifier(x, a) 46 | loss = criterion(y.unsqueeze(0), t.long().unsqueeze(0)) 47 | if not eval: 48 | loss.backward() 49 | optimizer.step() 50 | pred.append(y.argmax()) 51 | losses.append(loss) 52 | scores.append(y) 53 | acc_epoch = compute_accuracy(pred, t_data, device=device).item() 54 | loss_epoch = torch.FloatTensor(losses).mean().item() 55 | 56 | return acc_epoch, loss_epoch 57 | 58 | 59 | def main(): 60 | print(ARGS) 61 | start_time = time.time() 62 | 63 | device = torch.device(ARGS.device) 64 | np.random.seed(SEED) 65 | torch.manual_seed(SEED) 66 | if ARGS.device != "cpu": 67 | torch.cuda.manual_seed(SEED) 68 | normalize = NORMALIZE 69 | 70 | dataset_helper = DatasetHelper() 71 | dataset_helper.load_dataset(ARGS.dataset, device=device, seed=SEED, normalize=normalize) 72 | (x_train, a_train, labels_train) = dataset_helper.train 73 | (x_valid, a_valid, labels_valid) = dataset_helper.valid 74 | # feature_size = dataset_helper.feature_size 75 | print("Imported dataset, generated train and validation splits, took: {:.3f}s".format(time.time() - start_time)) 76 | 77 | # x1 = x_train[0] 78 | # a1 = a_train[0] 79 | # t1 = labels_train[0] 80 | # # t1_onehot = to_onehot_labels(t1) 81 | # classifier = 
--------------------------------------------------------------------------------
/source/utils/utils.py:
--------------------------------------------------------------------------------
import os
import os.path
import errno
import pickle

import matplotlib.pyplot as plt

import numpy as np
import torch
from torch.nn import functional as F

KNOWN_DATASETS = {"../data/mutag.graph"}


class DatasetHelper(object):

    def __init__(self):
        self.train = None
        self.valid = None
        self.feature_size = -1
        self.n_graphs = -1
        self.n_nodes = -1
        self.train_size = -1
        self.valid_size = -1
        self.onehot = None
        self.labels_set = None
        self.TRAIN_RATIO = 0.8

    def read_file(self, ds_name):

        if not os.path.isfile(ds_name):
            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), ds_name)
        if ds_name not in KNOWN_DATASETS:
            raise NotImplementedError("Dataset unknown to 'load_dataset()' in 'utils/utils.py'")

        with open(ds_name, "rb") as f:
            print("Found dataset:", ds_name)
            data = pickle.load(f, encoding="latin1")
        graphs = data["graph"]
        labels = np.array(data["labels"], dtype=np.float64)

        return graphs, labels

    def normalize_label(self, l):
        # Map labels from {-1, 1} to {0, 1}
        n = (l + 1) // 2
        if n != 0 and n != 1:
            raise ValueError
        return n

    def load_dataset(self, ds_name, device="cuda:0", seed=42, normalize=True, onehot=True):

        graphs, labels = self.read_file(ds_name)
        self.labels_set = set(labels)

        # Compute shape dimensions n_graphs, n_nodes, feature_size
        self.n_graphs = len(graphs)
        self.n_nodes = -1  # Max number of nodes among all of the graphs
        self.onehot = onehot

        # Find the feature size (scalar or onehot)
        if self.onehot:
            # Size of the onehot feature vector, i.e. the maximum value present in the dataset
            self.feature_size = 0
            # min_value = 1000
            for i in range(self.n_graphs):
                for j in range(len(graphs[i])):
                    for _, d in enumerate(graphs[i][j]['label']):
                        self.feature_size = max(self.feature_size, d)
                        # min_value = min(min_value, d)
        else:
            self.feature_size = len(graphs[0][0]['label'])  # Feature array size for each node

        # Find the max number of nodes
        for gidxs, graph in graphs.items():
            self.n_nodes = max(self.n_nodes, len(graph))
        assert self.n_nodes > 0, "Apparently, there are no nodes in these graphs"

        # Generate train and valid splits
        torch.manual_seed(seed)
        shuffled_idx = torch.randperm(self.n_graphs)
        self.train_size = int(self.n_graphs * self.TRAIN_RATIO)
        self.valid_size = self.n_graphs - self.train_size
        train_idx = shuffled_idx[:self.train_size]
        valid_idx = shuffled_idx[self.train_size:]

        # Generate PyTorch tensors for train
        a_train = torch.zeros((self.train_size, self.n_nodes, self.n_nodes), dtype=torch.float, device=device)
        x_train = torch.zeros((self.train_size, self.n_nodes, self.feature_size), dtype=torch.float, device=device)
        labels_train = torch.LongTensor(self.train_size).to(device=device)
        for i in range(self.train_size):
            idx = train_idx[i].item()
            labels_train[i] = self.normalize_label(labels[idx])
            for j in range(len(graphs[idx])):
                for k in graphs[idx][j]['neighbors']:
                    a_train[i, j, k] = 1
                for k, d in enumerate(graphs[idx][j]['label']):
                    if self.onehot:
                        x_train[i, j, :] = to_onehot(d, self.feature_size, device)
                    else:
                        x_train[i, j, k] = float(d)

        # Generate PyTorch tensors for valid
        a_valid = torch.zeros((self.valid_size, self.n_nodes, self.n_nodes), dtype=torch.float, device=device)
        x_valid = torch.zeros((self.valid_size, self.n_nodes, self.feature_size), dtype=torch.float, device=device)
        labels_valid = torch.LongTensor(self.valid_size).to(device=device)
        for i in range(self.valid_size):
            idx = valid_idx[i].item()
            labels_valid[i] = self.normalize_label(labels[idx])
            for j in range(len(graphs[idx])):
                for k in graphs[idx][j]['neighbors']:
                    a_valid[i, j, k] = 1
                for k, d in enumerate(graphs[idx][j]['label']):
                    if self.onehot:
                        x_valid[i, j, :] = to_onehot(d, self.feature_size, device)
                    else:
                        x_valid[i, j, k] = float(d)

        if normalize:
            x_train = F.normalize(x_train, p=2, dim=1)
            x_valid = F.normalize(x_valid, p=2, dim=1)

        self.train = (x_train, a_train, labels_train)
        self.valid = (x_valid, a_valid, labels_valid)


def to_onehot(x, size, device="cuda:0"):
    t = torch.zeros(size, device=device)
    t[x - 1] = 1  # Feature values start at 1, so value x maps to index x - 1
    return t


def to_onehot_labels(x, device="cuda:0"):
    t = torch.zeros(2, device=device)
    if x.item() == 1:
        t[1] = 1
    else:
        t[0] = 1
    return t


def generate_plot(train, valid, title="Insert Title here"):
    plt.figure()
    plt.plot(valid, label="Validation")
    plt.plot(train, label="Training")
    plt.xlabel("Epoch")
    # plt.ylabel()
    plt.legend()
    plt.title(title)
    # plt.show()
    plt.savefig("./" + title + ".svg", format='svg', dpi=1000)
    return

--------------------------------------------------------------------------------
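Putting the helper to use end to end (a sketch under the same assumptions as `train.py`: run from `source/` with `../data/mutag.graph` present; the expected shapes follow the notebook stats above, i.e. 188 graphs, at most 28 nodes, feature values 1–7):

```python
import torch

from utils.utils import DatasetHelper

helper = DatasetHelper()
helper.load_dataset("../data/mutag.graph", device="cpu", seed=42)
x_train, a_train, labels_train = helper.train

print(x_train.shape)    # torch.Size([150, 28, 7])  (0.8 * 188 graphs)
print(a_train.shape)    # torch.Size([150, 28, 28])
print(labels_train[:5]) # labels mapped from {-1, 1} to {0, 1}
```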