├── README.md └── final.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Understanding Graph Convolutional Networks for Text Classification 2 | Official implementation of the paper presented at the AAAI 2022 Workshop on Deep Learning on Graphs (DLG) 3 | (https://arxiv.org/abs/2203.16060) 4 | 5 |

6 | Han, C.*, Yuan, Z.*, Wang, K., Long, S., & Poon, J. (2022).
Understanding Graph Convolutional Networks for Text Classification.
In Proceedings of the AAAI 2022 Workshop on Deep Learning on Graphs (DLG).
7 |

8 | \* Co-first authors. 9 | 10 | 11 | 12 | ## Quick start with the notebook 13 | You can run the code on your own data with `final.ipynb`; just fill in your dataset as lists of documents and labels: 14 | ```python 15 | original_train_sentences = 16 | original_labels_train = 17 | original_test_sentences = 18 | original_labels_test = 19 | 20 | # example 21 | # original_train_sentences = ['this is sample 1','this is sample 2'] 22 | # original_labels_train = ['positive','negative'] 23 | # original_test_sentences = ['this is sample 1','this is sample 2'] 24 | # original_labels_test = ['positive','negative'] 25 | ``` 26 | Other hyperparameters can also be modified: 27 | ```python 28 | 29 | # EDGE: 0 means only d2w edges, 1 means d2w+w2w, 2 means d2w+w2w+d2d (d = document, w = word) 30 | EDGE = 0 31 | 32 | # NODE: 0 means one-hot as input, 1 means BERT embedding as input 33 | NODE = 0 34 | 35 | NUM_LAYERS = 2 36 | HIDDEN_DIM = 200 37 | DROP_OUT = 0.5 38 | LR = 0.02 39 | WEIGHT_DECAY = 0 40 | EARLY_STOPPING = 10 41 | NUM_EPOCHS = 200 42 | ``` 43 | -------------------------------------------------------------------------------- /final.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "hJPI-IXrBkrP" 7 | }, 8 | "source": [ 9 | "# Dataset Preparation" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "-GyzNkI7W03D" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "original_train_sentences = \n", 21 | "original_labels_train = \n", 22 | "original_test_sentences = \n", 23 | "original_labels_test = \n", 24 | "\n", 25 | "# example \n", 26 | "# original_train_sentences = ['this is sample 1','this is sample 2']\n", 27 | "# original_labels_train = ['postive','negative']\n", 28 | "# original_test_sentences = ['this is sample 1','this is sample 2']\n", 29 | "# original_labels_test = ['postive','negative']\n", 30 | "\n", 31
| "train_size = len(original_train_sentences)\n", 32 | "test_size = len(original_test_sentences)\n", 33 | "sentences = original_train_sentences + original_test_sentences" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": { 39 | "id": "6K9dWTv5I07_" 40 | }, 41 | "source": [ 42 | "# Hyper Parameters" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "EDGE = 2 # 0:d2w 1:d2w+w2w 2:d2w+w2w+d2d\n", 52 | "NODE = 0 # 0:one-hot #1:BERT \n", 53 | "NUM_LAYERS = 2 \n", 54 | "\n", 55 | "HIDDEN_DIM = 200\n", 56 | "DROP_OUT = 0.5\n", 57 | "LR = 0.02\n", 58 | "WEIGHT_DECAY = 0\n", 59 | "EARLY_STOPPING = 10\n", 60 | "NUM_EPOCHS = 200" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "id": "a2W7wKTBfa71" 67 | }, 68 | "source": [ 69 | "# Preprocess" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "id": "hobYcJ5OX5oT" 76 | }, 77 | "source": [ 78 | "## Label Encoding" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "id": "PtWyhXiueMOq" 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "import numpy as np\n", 90 | "from sklearn.preprocessing import LabelEncoder\n", 91 | "\n", 92 | "unique_labels=np.unique(original_labels_train)\n", 93 | "\n", 94 | "num_class = len(unique_labels)\n", 95 | "lEnc = LabelEncoder()\n", 96 | "lEnc.fit(unique_labels)\n", 97 | "\n", 98 | "print(unique_labels)\n", 99 | "print(lEnc.transform(unique_labels))\n", 100 | "\n", 101 | "train_labels = lEnc.transform(original_labels_train)\n", 102 | "test_labels = lEnc.transform(original_labels_test)\n", 103 | "\n", 104 | "import torch\n", 105 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 106 | "\n", 107 | "labels = train_labels.tolist()+test_labels.tolist()\n", 108 | "labels = torch.LongTensor(labels).to(device)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 
113 | "metadata": { 114 | "id": "ZMkEBxr6fMQi" 115 | }, 116 | "source": [ 117 | "## Remove Stopwords and less frequent words, tokenize sentences" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": { 124 | "id": "1xRG94uDfaBV" 125 | }, 126 | "outputs": [], 127 | "source": [ 128 | "from nltk.corpus import stopwords\n", 129 | "from keras.preprocessing.sequence import pad_sequences\n", 130 | "import nltk\n", 131 | "import re\n", 132 | "\n", 133 | "nltk.download('stopwords')\n", 134 | "stop_words = set(stopwords.words('english'))\n", 135 | "remove_limit = 5\n", 136 | "\n", 137 | "\n", 138 | "def clean_str(string):\n", 139 | " string = re.sub(r\"[^A-Za-z0-9(),!?\\'\\`]\", \" \", string)\n", 140 | " string = re.sub(r\"\\'s\", \" \\'s\", string)\n", 141 | " string = re.sub(r\"\\'ve\", \" \\'ve\", string)\n", 142 | " string = re.sub(r\"n\\'t\", \" n\\'t\", string)\n", 143 | " string = re.sub(r\"\\'re\", \" \\'re\", string)\n", 144 | " string = re.sub(r\"\\'d\", \" \\'d\", string)\n", 145 | " string = re.sub(r\"\\'ll\", \" \\'ll\", string)\n", 146 | " string = re.sub(r\",\", \" , \", string)\n", 147 | " string = re.sub(r\"!\", \" ! \", string)\n", 148 | " string = re.sub(r\"\\(\", \" \\( \", string)\n", 149 | " string = re.sub(r\"\\)\", \" \\) \", string)\n", 150 | " string = re.sub(r\"\\?\", \" \\? 
\", string)\n", 151 | " string = re.sub(r\"\\s{2,}\", \" \", string)\n", 152 | " return string.strip().lower()\n", 153 | "\n", 154 | "original_word_freq = {} # to remove rare words\n", 155 | "for sentence in sentences:\n", 156 | " temp = clean_str(sentence)\n", 157 | " word_list = temp.split()\n", 158 | " for word in word_list:\n", 159 | " if word in original_word_freq:\n", 160 | " original_word_freq[word] += 1\n", 161 | " else:\n", 162 | " original_word_freq[word] = 1 \n", 163 | "\n", 164 | "tokenize_sentences = []\n", 165 | "word_list_dict = {}\n", 166 | "for sentence in sentences:\n", 167 | " temp = clean_str(sentence)\n", 168 | " word_list_temp = temp.split()\n", 169 | " doc_words = []\n", 170 | " for word in word_list_temp: \n", 171 | " if word in original_word_freq and word not in stop_words and original_word_freq[word] >= remove_limit:\n", 172 | " doc_words.append(word)\n", 173 | " word_list_dict[word] = 1\n", 174 | " tokenize_sentences.append(doc_words)\n", 175 | "word_list = list(word_list_dict.keys())\n", 176 | "vocab_length = len(word_list)\n", 177 | "\n", 178 | "#word to id dict\n", 179 | "word_id_map = {}\n", 180 | "for i in range(vocab_length):\n", 181 | " word_id_map[word_list[i]] = i " 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": { 188 | "id": "dqLUncB2Pn_L" 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "node_size = train_size + vocab_length + test_size" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": { 198 | "id": "g0o8wcXgrTiD" 199 | }, 200 | "source": [ 201 | "# Model input" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": { 208 | "id": "EZbRV2wYxY1U" 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "import torch\n", 213 | "import torch.nn as nn\n", 214 | "import torch.nn.functional as F\n", 215 | "from tqdm.notebook import tqdm" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": 
{ 221 | "id": "znJ7Grz7fQ2L" 222 | }, 223 | "source": [ 224 | "## Build Graph" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": { 231 | "id": "-BSg1uNgV3_7" 232 | }, 233 | "outputs": [], 234 | "source": [ 235 | "from math import log\n", 236 | "row = []\n", 237 | "col = []\n", 238 | "weight = []" 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": { 244 | "id": "QESQPT88AqsI" 245 | }, 246 | "source": [ 247 | "### word-word: PMI" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": { 254 | "id": "KNlJoLFagXhv" 255 | }, 256 | "outputs": [], 257 | "source": [ 258 | "if EDGE >= 1:\n", 259 | " window_size = 20\n", 260 | " total_W = 0\n", 261 | " word_occurrence = {}\n", 262 | " word_pair_occurrence = {}\n", 263 | "\n", 264 | " def ordered_word_pair(a, b):\n", 265 | " if a > b:\n", 266 | " return b, a\n", 267 | " else:\n", 268 | " return a, b\n", 269 | "\n", 270 | " def update_word_and_word_pair_occurrence(q):\n", 271 | " unique_q = list(set(q))\n", 272 | " for i in unique_q:\n", 273 | " try:\n", 274 | " word_occurrence[i] += 1\n", 275 | " except:\n", 276 | " word_occurrence[i] = 1\n", 277 | " for i in range(len(unique_q)):\n", 278 | " for j in range(i+1, len(unique_q)):\n", 279 | " word1 = unique_q[i]\n", 280 | " word2 = unique_q[j]\n", 281 | " word1, word2 = ordered_word_pair(word1, word2)\n", 282 | " try:\n", 283 | " word_pair_occurrence[(word1, word2)] += 1\n", 284 | " except:\n", 285 | " word_pair_occurrence[(word1, word2)] = 1\n", 286 | "\n", 287 | "\n", 288 | " for ind in tqdm(range(train_size+test_size)):\n", 289 | " words = tokenize_sentences[ind]\n", 290 | "\n", 291 | " q = []\n", 292 | " # push the first (window_size) words into a queue\n", 293 | " for i in range(min(window_size, len(words))):\n", 294 | " q += [word_id_map[words[i]]]\n", 295 | " # update the total number of the sliding windows\n", 296 | " total_W += 1\n", 297 | " # update 
the number of sliding windows that contain each word and word pair\n", 298 | " update_word_and_word_pair_occurrence(q)\n", 299 | "\n", 300 | " now_next_word_index = window_size\n", 301 | " # pop the first word out and let the next word in, keep doing this until the end of the document\n", 302 | " while now_next_word_index=2:\n", 424 | " tokenize_sentences_set = [set(s) for s in tokenize_sentences]\n", 425 | " jaccard_threshold = 0.2\n", 426 | " for i in tqdm(range(len(tokenize_sentences))):\n", 427 | " for j in range(i+1, len(tokenize_sentences)):\n", 428 | " jaccard_w = 1 - nltk.jaccard_distance(tokenize_sentences_set[i], tokenize_sentences_set[j])\n", 429 | " if jaccard_w > jaccard_threshold:\n", 430 | " if i < train_size:\n", 431 | " row.append(i)\n", 432 | " else:\n", 433 | " row.append(i + vocab_length)\n", 434 | " if j < train_size:\n", 435 | " col.append(j)\n", 436 | " else:\n", 437 | " col.append(vocab_length + j)\n", 438 | " weight.append(jaccard_w)\n", 439 | " if j < train_size:\n", 440 | " row.append(j)\n", 441 | " else:\n", 442 | " row.append(j + vocab_length)\n", 443 | " if i < train_size:\n", 444 | " col.append(i)\n", 445 | " else:\n", 446 | " col.append(vocab_length + i)\n", 447 | " weight.append(jaccard_w)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "markdown", 452 | "metadata": { 453 | "id": "uIkGgB2aZDk7" 454 | }, 455 | "source": [ 456 | "### Adjacent matrix" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "metadata": { 463 | "id": "C0O1Ucdhod9a" 464 | }, 465 | "outputs": [], 466 | "source": [ 467 | "import scipy.sparse as sp\n", 468 | "adj = sp.csr_matrix((weight, (row, col)), shape=(node_size, node_size))\n", 469 | "\n", 470 | "# build symmetric adjacency matrix\n", 471 | "adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": null, 477 | "metadata": { 478 | "id": "ivyuexATkQFW" 479 | }, 480 | "outputs": [], 481 | 
"source": [ 482 | "def normalize_adj(adj):\n", 483 | " \"\"\"Symmetrically normalize adjacency matrix.\"\"\"\n", 484 | " adj = sp.coo_matrix(adj)\n", 485 | " rowsum = np.array(adj.sum(1))\n", 486 | " d_inv_sqrt = np.power(rowsum, -0.5).flatten()\n", 487 | " d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.\n", 488 | " d_mat_inv_sqrt = sp.diags(d_inv_sqrt)\n", 489 | " return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo(), d_inv_sqrt\n", 490 | " \n", 491 | "adj, norm_item = normalize_adj(adj + sp.eye(adj.shape[0]))\n", 492 | "\n", 493 | "\n", 494 | "def sparse_mx_to_torch_sparse_tensor(sparse_mx):\n", 495 | " \"\"\"Convert a scipy sparse matrix to a torch sparse tensor.\"\"\"\n", 496 | " sparse_mx = sparse_mx.tocoo().astype(np.float32)\n", 497 | " indices = torch.from_numpy(\n", 498 | " np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))\n", 499 | " values = torch.from_numpy(sparse_mx.data)\n", 500 | " shape = torch.Size(sparse_mx.shape)\n", 501 | " return torch.sparse.FloatTensor(indices, values, shape).to(device)\n", 502 | "\n", 503 | "adj = sparse_mx_to_torch_sparse_tensor(adj)" 504 | ] 505 | }, 506 | { 507 | "cell_type": "markdown", 508 | "metadata": { 509 | "id": "pMgbhTstMSUA" 510 | }, 511 | "source": [ 512 | "## Features" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": null, 518 | "metadata": { 519 | "id": "mP9dqCskOrXT" 520 | }, 521 | "outputs": [], 522 | "source": [ 523 | "if NODE == 0:\n", 524 | " features = np.arange(node_size)\n", 525 | " features = torch.FloatTensor(features).to(device)\n", 526 | "else:\n", 527 | " \n", 528 | " from flair.embeddings import TransformerDocumentEmbeddings, TransformerWordEmbeddings\n", 529 | " from flair.data import Sentence\n", 530 | " doc_embedding = TransformerDocumentEmbeddings('bert-base-uncased', fine_tune=False)\n", 531 | " word_embedding = TransformerWordEmbeddings('bert-base-uncased', layers='-1',subtoken_pooling=\"mean\")\n", 532 | "\n", 533 | " sent_embs = []\n", 534 | " 
word_embs = {}\n", 535 | "\n", 536 | " for ind in tqdm(range(train_size+test_size)):\n", 537 | " sent = tokenize_sentences[ind]\n", 538 | " sentence = Sentence(\" \".join(sent[:512]),use_tokenizer=False)\n", 539 | " doc_embedding.embed(sentence)\n", 540 | " sent_embs.append(sentence.get_embedding().tolist())\n", 541 | " words = Sentence(\" \".join(sent[:512]),use_tokenizer=False)\n", 542 | " word_embedding.embed(words)\n", 543 | " for token in words:\n", 544 | " word = token.text\n", 545 | " embedding = token.embedding.tolist()\n", 546 | " if word not in word_embs:\n", 547 | " word_embs[word] = embedding\n", 548 | " else:\n", 549 | " word_embs[word] = np.minimum(word_embs[word], embedding)\n", 550 | "\n", 551 | " word_embs_list = []\n", 552 | " for word in word_list:\n", 553 | " word_embs_list.append(word_embs[word])\n", 554 | "\n", 555 | " features = sent_embs[:train_size] + word_embs_list + sent_embs[train_size:]\n", 556 | "\n", 557 | " import scipy.sparse as sp\n", 558 | " def preprocess_features(features):\n", 559 | " \"\"\"Row-normalize feature matrix and convert to tuple representation\"\"\"\n", 560 | " rowsum = np.array(features.sum(1))\n", 561 | " r_inv = np.power(rowsum, -1).flatten()\n", 562 | " r_inv[np.isinf(r_inv)] = 0.\n", 563 | " r_mat_inv = sp.diags(r_inv)\n", 564 | " features = r_mat_inv.dot(features)\n", 565 | " return features\n", 566 | "\n", 567 | " features = preprocess_features(sp.csr_matrix(features)).todense()\n", 568 | " features = torch.FloatTensor(features).to(device)" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "metadata": { 574 | "id": "pdx6RrUvjbF0" 575 | }, 576 | "source": [ 577 | "# Model" 578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": { 583 | "id": "39Kj8NQujiDH" 584 | }, 585 | "source": [ 586 | "## GCN Layer" 587 | ] 588 | }, 589 | { 590 | "cell_type": "code", 591 | "execution_count": null, 592 | "metadata": { 593 | "id": "jNVkA-h7b3sP" 594 | }, 595 | "outputs": [], 596 | "source": [ 597 
| "import math\n", 598 | "\n", 599 | "import torch\n", 600 | "\n", 601 | "from torch.nn.parameter import Parameter\n", 602 | "from torch.nn.modules.module import Module\n", 603 | "\n", 604 | "\n", 605 | "class GraphConvolution(Module):\n", 606 | " \"\"\"\n", 607 | " Simple GCN layer, similar to https://arxiv.org/abs/1609.02907\n", 608 | " \"\"\"\n", 609 | "\n", 610 | " def __init__(self, in_features, out_features, drop_out = 0, activation=None, bias=True):\n", 611 | " super(GraphConvolution, self).__init__()\n", 612 | " self.in_features = in_features\n", 613 | " self.out_features = out_features\n", 614 | " self.weight = Parameter(torch.FloatTensor(in_features, out_features))\n", 615 | " if bias:\n", 616 | " self.bias = Parameter(torch.zeros(1, out_features))\n", 617 | " else:\n", 618 | " self.register_parameter('bias', None)\n", 619 | " self.reset_parameters(in_features, out_features)\n", 620 | " self.dropout = torch.nn.Dropout(drop_out)\n", 621 | " self.activation = activation\n", 622 | "\n", 623 | " def reset_parameters(self,in_features, out_features):\n", 624 | " stdv = np.sqrt(6.0/(in_features+out_features))\n", 625 | " # stdv = 1. 
/ math.sqrt(self.weight.size(1))\n", 626 | " self.weight.data.uniform_(-stdv, stdv)\n", 627 | " # if self.bias is not None:\n", 628 | " # torch.nn.init.zeros_(self.bias)\n", 629 | " # self.bias.data.uniform_(-stdv, stdv)\n", 630 | "\n", 631 | "\n", 632 | " def forward(self, input, adj, feature_less = False):\n", 633 | " if feature_less:\n", 634 | " support = self.weight\n", 635 | " support = self.dropout(support)\n", 636 | " else:\n", 637 | " input = self.dropout(input)\n", 638 | " support = torch.mm(input, self.weight)\n", 639 | " output = torch.spmm(adj, support)\n", 640 | " if self.bias is not None:\n", 641 | " output = output + self.bias\n", 642 | " if self.activation is not None:\n", 643 | " output = self.activation(output)\n", 644 | " return output\n", 645 | "\n", 646 | " def __repr__(self):\n", 647 | " return self.__class__.__name__ + ' (' \\\n", 648 | " + str(self.in_features) + ' -> ' \\\n", 649 | " + str(self.out_features) + ')'" 650 | ] 651 | }, 652 | { 653 | "cell_type": "markdown", 654 | "metadata": { 655 | "id": "k57M4sz4s4Md" 656 | }, 657 | "source": [ 658 | "## GCN Model" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": null, 664 | "metadata": { 665 | "id": "aJ-ZQuMzs5tZ" 666 | }, 667 | "outputs": [], 668 | "source": [ 669 | "import torch.nn as nn\n", 670 | "import torch.nn.functional as F\n", 671 | "\n", 672 | "class GCN(nn.Module):\n", 673 | " def __init__(self, nfeat, nhid, nclass, dropout, n_layers = 2):\n", 674 | " super(GCN, self).__init__()\n", 675 | " self.n_layers = n_layers\n", 676 | " self.gc_list = []\n", 677 | " if n_layers >= 2:\n", 678 | " self.gc1 = GraphConvolution(nfeat, nhid, dropout, activation = nn.ReLU())\n", 679 | " self.gc_list = nn.ModuleList([GraphConvolution(nhid, nhid, dropout, activation = nn.ReLU()) for _ in range(self.n_layers-2)])\n", 680 | " self.gcf = GraphConvolution(nhid, nclass, dropout)\n", 681 | " else:\n", 682 | " self.gc1 = GraphConvolution(nfeat, nclass, dropout)\n", 683 | "\n", 
684 | " def forward(self, x, adj):\n", 685 | " if self.n_layers>=2:\n", 686 | " x = self.gc1(x, adj, feature_less = True)\n", 687 | " for i in range(self.n_layers-2):\n", 688 | " x = self.gc_list[i](x,adj)\n", 689 | " x = self.gcf(x,adj)\n", 690 | " else:\n", 691 | " x = self.gc1(x, adj, feature_less = True)\n", 692 | " return x" 693 | ] 694 | }, 695 | { 696 | "cell_type": "code", 697 | "execution_count": null, 698 | "metadata": { 699 | "id": "qmhOG1yG--Ji" 700 | }, 701 | "outputs": [], 702 | "source": [ 703 | "def cal_accuracy(predictions,labels):\n", 704 | " pred = torch.argmax(predictions,-1).cpu().tolist()\n", 705 | " lab = labels.cpu().tolist()\n", 706 | " cor = 0\n", 707 | " for i in range(len(pred)):\n", 708 | " if pred[i] == lab[i]:\n", 709 | " cor += 1\n", 710 | " return cor/len(pred)" 711 | ] 712 | }, 713 | { 714 | "cell_type": "markdown", 715 | "metadata": { 716 | "id": "zEE4JxeUthCb" 717 | }, 718 | "source": [ 719 | "# Training" 720 | ] 721 | }, 722 | { 723 | "cell_type": "markdown", 724 | "metadata": { 725 | "id": "bIxII4QoticA" 726 | }, 727 | "source": [ 728 | "## Initialize model" 729 | ] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "execution_count": null, 734 | "metadata": { 735 | "id": "hdNsgxMG-Wwu" 736 | }, 737 | "outputs": [], 738 | "source": [ 739 | "import torch.optim as optim\n", 740 | "\n", 741 | "\n", 742 | "criterion = nn.CrossEntropyLoss()\n", 743 | "\n", 744 | "model = GCN(nfeat=node_size, nhid=HIDDEN_DIM, nclass=num_class, dropout=DROP_OUT,n_layers=NUM_LAYERS).to(device)\n", 745 | "optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)" 746 | ] 747 | }, 748 | { 749 | "cell_type": "markdown", 750 | "metadata": { 751 | "id": "T98r4qZuuFyn" 752 | }, 753 | "source": [ 754 | "## Training and Validating" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": null, 760 | "metadata": { 761 | "id": "Bv9br9pgGw9R" 762 | }, 763 | "outputs": [], 764 | "source": [ 765 | "def 
generate_train_val(train_pro=0.9):\n", 766 | " real_train_size = int(train_pro*train_size)\n", 767 | " val_size = train_size-real_train_size\n", 768 | "\n", 769 | " idx_train = np.random.choice(train_size, real_train_size,replace=False)\n", 770 | " idx_train.sort()\n", 771 | " idx_val = []\n", 772 | " pointer = 0\n", 773 | " for v in range(train_size):\n", 774 | " if pointer EARLY_STOPPING and np.min(val_loss[-EARLY_STOPPING:]) > np.min(val_loss[:-EARLY_STOPPING]) :\n", 821 | " if show_result:\n", 822 | " print(\"Early Stopping...\")\n", 823 | " break\n", 824 | "\n", 825 | "train_model()" 826 | ] 827 | }, 828 | { 829 | "cell_type": "markdown", 830 | "metadata": { 831 | "id": "OQwlWq6dyYJm" 832 | }, 833 | "source": [ 834 | "## Evaluation" 835 | ] 836 | }, 837 | { 838 | "cell_type": "code", 839 | "execution_count": null, 840 | "metadata": { 841 | "id": "jmPNukmk40gd" 842 | }, 843 | "outputs": [], 844 | "source": [ 845 | "from sklearn.metrics import f1_score, accuracy_score\n", 846 | "def test():\n", 847 | " model.eval()\n", 848 | " output = model(features, adj)\n", 849 | " predictions = torch.argmax(output[idx_test],-1).cpu().tolist()\n", 850 | " acc = accuracy_score(test_labels,predictions)\n", 851 | " f11 = f1_score(test_labels,predictions, average='macro')\n", 852 | " f12 = f1_score(test_labels,predictions, average = 'weighted')\n", 853 | " return acc, f11, f12\n", 854 | "\n", 855 | "print(test())" 856 | ] 857 | }, 858 | { 859 | "cell_type": "markdown", 860 | "metadata": { 861 | "id": "LOFsVlv4hTgc" 862 | }, 863 | "source": [ 864 | "# Test 10 times" 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "execution_count": null, 870 | "metadata": { 871 | "id": "ydMqrCkehVPW" 872 | }, 873 | "outputs": [], 874 | "source": [ 875 | "test_acc_list = []\n", 876 | "test_f11_list = []\n", 877 | "test_f12_list = []\n", 878 | "\n", 879 | "for t in range(10):\n", 880 | " model = GCN(nfeat=node_size, nhid=HIDDEN_DIM, nclass=num_class, 
dropout=DROP_OUT,n_layers=NUM_LAYERS).to(device)\n", 881 | " optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n", 882 | " idx_train, idx_val, idx_test = generate_train_val()\n", 883 | " train_model(show_result=False)\n", 884 | " acc, f11, f12 = test()\n", 885 | " test_acc_list.append(acc)\n", 886 | " test_f11_list.append(f11)\n", 887 | " test_f12_list.append(f12)\n", 888 | "\n", 889 | "\n", 890 | "print(\"Accuracy:\",np.round(np.mean(test_acc_list),4))\n", 891 | "print(\"Macro F1:\",np.round(np.mean(test_f11_list),4))\n", 892 | "print(\"Weighted F1:\",np.round(np.mean(test_f12_list),4))" 893 | ] 894 | } 895 | ], 896 | "metadata": { 897 | "accelerator": "GPU", 898 | "colab": { 899 | "collapsed_sections": [], 900 | "machine_shape": "hm", 901 | "name": "final code.ipynb", 902 | "provenance": [], 903 | "toc_visible": true 904 | }, 905 | "kernelspec": { 906 | "display_name": "Python 3", 907 | "name": "python3" 908 | } 909 | }, 910 | "nbformat": 4, 911 | "nbformat_minor": 0 912 | } 913 | --------------------------------------------------------------------------------