├── README.md
└── final.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | # Understanding Graph Convolutional Networks for Text Classification
2 | Official implementation for the AAAI 2022 Workshop on Deep Learning on Graphs (DLG)
3 | (https://arxiv.org/abs/2203.16060)
4 |
5 |
8 | *Co-first author
9 |
10 |
11 |
12 | ## Easy running using .ipynb file
13 | You can run the code on your own data using `final.ipynb`; remember to fill in your dataset as lists of documents and labels:
14 | ```python
15 | original_train_sentences =
16 | original_labels_train =
17 | original_test_sentences =
18 | original_labels_test =
19 |
20 | # example
21 | # original_train_sentences = ['this is sample 1','this is sample 2']
22 | # original_labels_train = ['positive','negative']
23 | # original_test_sentences = ['this is sample 1','this is sample 2']
24 | # original_labels_test = ['positive','negative']
25 | ```
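If you just want to try the pipeline end to end, the four lists can also be filled from a public corpus. Below is a minimal sketch using 20 Newsgroups via scikit-learn; this loader is only an illustration and is not part of the notebook:
```python
from sklearn.datasets import fetch_20newsgroups

# strip headers/footers/quotes so only the message body is classified
train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))
test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))

original_train_sentences = train.data
original_labels_train = [train.target_names[t] for t in train.target]
original_test_sentences = test.data
original_labels_test = [test.target_names[t] for t in test.target]
```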
26 | Some other parameters can also be modified:
27 | ```python
28 |
29 | # EDGE: 0 = document-word (d2w) edges only, 1 = d2w + word-word (w2w), 2 = d2w + w2w + document-document (d2d)
30 | EDGE = 0
31 |
32 | # NODE: 0 = one-hot node features, 1 = BERT embeddings as node features
33 | NODE = 0
34 |
35 | NUM_LAYERS = 2
36 | HIDDEN_DIM = 200
37 | DROP_OUT = 0.5
38 | LR = 0.02
39 | WEIGHT_DECAY = 0
40 | EARLY_STOPPING = 10
41 | NUM_EPOCHS = 200
42 | ```
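When `EDGE >= 1`, word-word edges are weighted by pointwise mutual information (PMI) computed over fixed-size sliding windows (window size 20 in the notebook). A minimal sketch of that weighting, assuming the window counts gathered in the notebook's graph-building cells (`total_W`, `word_occurrence`, `word_pair_occurrence`):
```python
from math import log

def pmi(i, j, total_W, word_occurrence, word_pair_occurrence):
    # probabilities estimated from sliding-window counts
    p_ij = word_pair_occurrence[(i, j)] / total_W
    p_i = word_occurrence[i] / total_W
    p_j = word_occurrence[j] / total_W
    return log(p_ij / (p_i * p_j))

# only word pairs with positive PMI are typically kept as w2w edges
```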
43 |
--------------------------------------------------------------------------------
/final.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "hJPI-IXrBkrP"
7 | },
8 | "source": [
9 | "# Dataset Preparation"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "-GyzNkI7W03D"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "original_train_sentences = \n",
21 | "original_labels_train = \n",
22 | "original_test_sentences = \n",
23 | "original_labels_test = \n",
24 | "\n",
25 | "# example \n",
26 | "# original_train_sentences = ['this is sample 1','this is sample 2']\n",
27 | "# original_labels_train = ['postive','negative']\n",
28 | "# original_test_sentences = ['this is sample 1','this is sample 2']\n",
29 | "# original_labels_test = ['postive','negative']\n",
30 | "\n",
31 | "train_size = len(original_train_sentences)\n",
32 | "test_size = len(original_test_sentences)\n",
33 | "sentences = original_train_sentences + original_test_sentences"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {
39 | "id": "6K9dWTv5I07_"
40 | },
41 | "source": [
42 | "# Hyper Parameters"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "EDGE = 2 # 0:d2w 1:d2w+w2w 2:d2w+w2w+d2d\n",
52 | "NODE = 0 # 0:one-hot #1:BERT \n",
53 | "NUM_LAYERS = 2 \n",
54 | "\n",
55 | "HIDDEN_DIM = 200\n",
56 | "DROP_OUT = 0.5\n",
57 | "LR = 0.02\n",
58 | "WEIGHT_DECAY = 0\n",
59 | "EARLY_STOPPING = 10\n",
60 | "NUM_EPOCHS = 200"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {
66 | "id": "a2W7wKTBfa71"
67 | },
68 | "source": [
69 | "# Preprocess"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {
75 | "id": "hobYcJ5OX5oT"
76 | },
77 | "source": [
78 | "## Label Encoding"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {
85 | "id": "PtWyhXiueMOq"
86 | },
87 | "outputs": [],
88 | "source": [
89 | "import numpy as np\n",
90 | "from sklearn.preprocessing import LabelEncoder\n",
91 | "\n",
92 | "unique_labels=np.unique(original_labels_train)\n",
93 | "\n",
94 | "num_class = len(unique_labels)\n",
95 | "lEnc = LabelEncoder()\n",
96 | "lEnc.fit(unique_labels)\n",
97 | "\n",
98 | "print(unique_labels)\n",
99 | "print(lEnc.transform(unique_labels))\n",
100 | "\n",
101 | "train_labels = lEnc.transform(original_labels_train)\n",
102 | "test_labels = lEnc.transform(original_labels_test)\n",
103 | "\n",
104 | "import torch\n",
105 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
106 | "\n",
107 | "labels = train_labels.tolist()+test_labels.tolist()\n",
108 | "labels = torch.LongTensor(labels).to(device)"
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "metadata": {
114 | "id": "ZMkEBxr6fMQi"
115 | },
116 | "source": [
117 | "## Remove Stopwords and less frequent words, tokenize sentences"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "metadata": {
124 | "id": "1xRG94uDfaBV"
125 | },
126 | "outputs": [],
127 | "source": [
128 | "from nltk.corpus import stopwords\n",
129 | "from keras.preprocessing.sequence import pad_sequences\n",
130 | "import nltk\n",
131 | "import re\n",
132 | "\n",
133 | "nltk.download('stopwords')\n",
134 | "stop_words = set(stopwords.words('english'))\n",
135 | "remove_limit = 5\n",
136 | "\n",
137 | "\n",
138 | "def clean_str(string):\n",
139 | " string = re.sub(r\"[^A-Za-z0-9(),!?\\'\\`]\", \" \", string)\n",
140 | " string = re.sub(r\"\\'s\", \" \\'s\", string)\n",
141 | " string = re.sub(r\"\\'ve\", \" \\'ve\", string)\n",
142 | " string = re.sub(r\"n\\'t\", \" n\\'t\", string)\n",
143 | " string = re.sub(r\"\\'re\", \" \\'re\", string)\n",
144 | " string = re.sub(r\"\\'d\", \" \\'d\", string)\n",
145 | " string = re.sub(r\"\\'ll\", \" \\'ll\", string)\n",
146 | " string = re.sub(r\",\", \" , \", string)\n",
147 | " string = re.sub(r\"!\", \" ! \", string)\n",
148 | " string = re.sub(r\"\\(\", \" \\( \", string)\n",
149 | " string = re.sub(r\"\\)\", \" \\) \", string)\n",
150 | " string = re.sub(r\"\\?\", \" \\? \", string)\n",
151 | " string = re.sub(r\"\\s{2,}\", \" \", string)\n",
152 | " return string.strip().lower()\n",
153 | "\n",
154 | "original_word_freq = {} # to remove rare words\n",
155 | "for sentence in sentences:\n",
156 | " temp = clean_str(sentence)\n",
157 | " word_list = temp.split()\n",
158 | " for word in word_list:\n",
159 | " if word in original_word_freq:\n",
160 | " original_word_freq[word] += 1\n",
161 | " else:\n",
162 | " original_word_freq[word] = 1 \n",
163 | "\n",
164 | "tokenize_sentences = []\n",
165 | "word_list_dict = {}\n",
166 | "for sentence in sentences:\n",
167 | " temp = clean_str(sentence)\n",
168 | " word_list_temp = temp.split()\n",
169 | " doc_words = []\n",
170 | " for word in word_list_temp: \n",
171 | " if word in original_word_freq and word not in stop_words and original_word_freq[word] >= remove_limit:\n",
172 | " doc_words.append(word)\n",
173 | " word_list_dict[word] = 1\n",
174 | " tokenize_sentences.append(doc_words)\n",
175 | "word_list = list(word_list_dict.keys())\n",
176 | "vocab_length = len(word_list)\n",
177 | "\n",
178 | "#word to id dict\n",
179 | "word_id_map = {}\n",
180 | "for i in range(vocab_length):\n",
181 | " word_id_map[word_list[i]] = i "
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {
188 | "id": "dqLUncB2Pn_L"
189 | },
190 | "outputs": [],
191 | "source": [
192 | "node_size = train_size + vocab_length + test_size"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "metadata": {
198 | "id": "g0o8wcXgrTiD"
199 | },
200 | "source": [
201 | "# Model input"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {
208 | "id": "EZbRV2wYxY1U"
209 | },
210 | "outputs": [],
211 | "source": [
212 | "import torch\n",
213 | "import torch.nn as nn\n",
214 | "import torch.nn.functional as F\n",
215 | "from tqdm.notebook import tqdm"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {
221 | "id": "znJ7Grz7fQ2L"
222 | },
223 | "source": [
224 | "## Build Graph"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "metadata": {
231 | "id": "-BSg1uNgV3_7"
232 | },
233 | "outputs": [],
234 | "source": [
235 | "from math import log\n",
236 | "row = []\n",
237 | "col = []\n",
238 | "weight = []"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {
244 | "id": "QESQPT88AqsI"
245 | },
246 | "source": [
247 | "### word-word: PMI"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {
254 | "id": "KNlJoLFagXhv"
255 | },
256 | "outputs": [],
257 | "source": [
258 | "if EDGE >= 1:\n",
259 | " window_size = 20\n",
260 | " total_W = 0\n",
261 | " word_occurrence = {}\n",
262 | " word_pair_occurrence = {}\n",
263 | "\n",
264 | " def ordered_word_pair(a, b):\n",
265 | " if a > b:\n",
266 | " return b, a\n",
267 | " else:\n",
268 | " return a, b\n",
269 | "\n",
270 | " def update_word_and_word_pair_occurrence(q):\n",
271 | " unique_q = list(set(q))\n",
272 | " for i in unique_q:\n",
273 | " try:\n",
274 | " word_occurrence[i] += 1\n",
275 | " except:\n",
276 | " word_occurrence[i] = 1\n",
277 | " for i in range(len(unique_q)):\n",
278 | " for j in range(i+1, len(unique_q)):\n",
279 | " word1 = unique_q[i]\n",
280 | " word2 = unique_q[j]\n",
281 | " word1, word2 = ordered_word_pair(word1, word2)\n",
282 | " try:\n",
283 | " word_pair_occurrence[(word1, word2)] += 1\n",
284 | " except:\n",
285 | " word_pair_occurrence[(word1, word2)] = 1\n",
286 | "\n",
287 | "\n",
288 | " for ind in tqdm(range(train_size+test_size)):\n",
289 | " words = tokenize_sentences[ind]\n",
290 | "\n",
291 | " q = []\n",
292 | " # push the first (window_size) words into a queue\n",
293 | " for i in range(min(window_size, len(words))):\n",
294 | " q += [word_id_map[words[i]]]\n",
295 | " # update the total number of the sliding windows\n",
296 | " total_W += 1\n",
297 | " # update the number of sliding windows that contain each word and word pair\n",
298 | " update_word_and_word_pair_occurrence(q)\n",
299 | "\n",
300 | " now_next_word_index = window_size\n",
301 | " # pop the first word out and let the next word in, keep doing this until the end of the document\n",
302 | " while now_next_word_index=2:\n",
424 | " tokenize_sentences_set = [set(s) for s in tokenize_sentences]\n",
425 | " jaccard_threshold = 0.2\n",
426 | " for i in tqdm(range(len(tokenize_sentences))):\n",
427 | " for j in range(i+1, len(tokenize_sentences)):\n",
428 | " jaccard_w = 1 - nltk.jaccard_distance(tokenize_sentences_set[i], tokenize_sentences_set[j])\n",
429 | " if jaccard_w > jaccard_threshold:\n",
430 | " if i < train_size:\n",
431 | " row.append(i)\n",
432 | " else:\n",
433 | " row.append(i + vocab_length)\n",
434 | " if j < train_size:\n",
435 | " col.append(j)\n",
436 | " else:\n",
437 | " col.append(vocab_length + j)\n",
438 | " weight.append(jaccard_w)\n",
439 | " if j < train_size:\n",
440 | " row.append(j)\n",
441 | " else:\n",
442 | " row.append(j + vocab_length)\n",
443 | " if i < train_size:\n",
444 | " col.append(i)\n",
445 | " else:\n",
446 | " col.append(vocab_length + i)\n",
447 | " weight.append(jaccard_w)"
448 | ]
449 | },
450 | {
451 | "cell_type": "markdown",
452 | "metadata": {
453 | "id": "uIkGgB2aZDk7"
454 | },
455 | "source": [
456 | "### Adjacent matrix"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": null,
462 | "metadata": {
463 | "id": "C0O1Ucdhod9a"
464 | },
465 | "outputs": [],
466 | "source": [
467 | "import scipy.sparse as sp\n",
468 | "adj = sp.csr_matrix((weight, (row, col)), shape=(node_size, node_size))\n",
469 | "\n",
470 | "# build symmetric adjacency matrix\n",
471 | "adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)"
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "execution_count": null,
477 | "metadata": {
478 | "id": "ivyuexATkQFW"
479 | },
480 | "outputs": [],
481 | "source": [
482 | "def normalize_adj(adj):\n",
483 | " \"\"\"Symmetrically normalize adjacency matrix.\"\"\"\n",
484 | " adj = sp.coo_matrix(adj)\n",
485 | " rowsum = np.array(adj.sum(1))\n",
486 | " d_inv_sqrt = np.power(rowsum, -0.5).flatten()\n",
487 | " d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.\n",
488 | " d_mat_inv_sqrt = sp.diags(d_inv_sqrt)\n",
489 | " return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo(), d_inv_sqrt\n",
490 | " \n",
491 | "adj, norm_item = normalize_adj(adj + sp.eye(adj.shape[0]))\n",
492 | "\n",
493 | "\n",
494 | "def sparse_mx_to_torch_sparse_tensor(sparse_mx):\n",
495 | " \"\"\"Convert a scipy sparse matrix to a torch sparse tensor.\"\"\"\n",
496 | " sparse_mx = sparse_mx.tocoo().astype(np.float32)\n",
497 | " indices = torch.from_numpy(\n",
498 | " np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))\n",
499 | " values = torch.from_numpy(sparse_mx.data)\n",
500 | " shape = torch.Size(sparse_mx.shape)\n",
501 | " return torch.sparse.FloatTensor(indices, values, shape).to(device)\n",
502 | "\n",
503 | "adj = sparse_mx_to_torch_sparse_tensor(adj)"
504 | ]
505 | },
506 | {
507 | "cell_type": "markdown",
508 | "metadata": {
509 | "id": "pMgbhTstMSUA"
510 | },
511 | "source": [
512 | "## Features"
513 | ]
514 | },
515 | {
516 | "cell_type": "code",
517 | "execution_count": null,
518 | "metadata": {
519 | "id": "mP9dqCskOrXT"
520 | },
521 | "outputs": [],
522 | "source": [
523 | "if NODE == 0:\n",
524 | " features = np.arange(node_size)\n",
525 | " features = torch.FloatTensor(features).to(device)\n",
526 | "else:\n",
527 | " \n",
528 | " from flair.embeddings import TransformerDocumentEmbeddings, TransformerWordEmbeddings\n",
529 | " from flair.data import Sentence\n",
530 | " doc_embedding = TransformerDocumentEmbeddings('bert-base-uncased', fine_tune=False)\n",
531 | " word_embedding = TransformerWordEmbeddings('bert-base-uncased', layers='-1',subtoken_pooling=\"mean\")\n",
532 | "\n",
533 | " sent_embs = []\n",
534 | " word_embs = {}\n",
535 | "\n",
536 | " for ind in tqdm(range(train_size+test_size)):\n",
537 | " sent = tokenize_sentences[ind]\n",
538 | " sentence = Sentence(\" \".join(sent[:512]),use_tokenizer=False)\n",
539 | " doc_embedding.embed(sentence)\n",
540 | " sent_embs.append(sentence.get_embedding().tolist())\n",
541 | " words = Sentence(\" \".join(sent[:512]),use_tokenizer=False)\n",
542 | " word_embedding.embed(words)\n",
543 | " for token in words:\n",
544 | " word = token.text\n",
545 | " embedding = token.embedding.tolist()\n",
546 | " if word not in word_embs:\n",
547 | " word_embs[word] = embedding\n",
548 | " else:\n",
549 | " word_embs[word] = np.minimum(word_embs[word], embedding)\n",
550 | "\n",
551 | " word_embs_list = []\n",
552 | " for word in word_list:\n",
553 | " word_embs_list.append(word_embs[word])\n",
554 | "\n",
555 | " features = sent_embs[:train_size] + word_embs_list + sent_embs[train_size:]\n",
556 | "\n",
557 | " import scipy.sparse as sp\n",
558 | " def preprocess_features(features):\n",
559 | " \"\"\"Row-normalize feature matrix and convert to tuple representation\"\"\"\n",
560 | " rowsum = np.array(features.sum(1))\n",
561 | " r_inv = np.power(rowsum, -1).flatten()\n",
562 | " r_inv[np.isinf(r_inv)] = 0.\n",
563 | " r_mat_inv = sp.diags(r_inv)\n",
564 | " features = r_mat_inv.dot(features)\n",
565 | " return features\n",
566 | "\n",
567 | " features = preprocess_features(sp.csr_matrix(features)).todense()\n",
568 | " features = torch.FloatTensor(features).to(device)"
569 | ]
570 | },
571 | {
572 | "cell_type": "markdown",
573 | "metadata": {
574 | "id": "pdx6RrUvjbF0"
575 | },
576 | "source": [
577 | "# Model"
578 | ]
579 | },
580 | {
581 | "cell_type": "markdown",
582 | "metadata": {
583 | "id": "39Kj8NQujiDH"
584 | },
585 | "source": [
586 | "## GCN Layer"
587 | ]
588 | },
589 | {
590 | "cell_type": "code",
591 | "execution_count": null,
592 | "metadata": {
593 | "id": "jNVkA-h7b3sP"
594 | },
595 | "outputs": [],
596 | "source": [
597 | "import math\n",
598 | "\n",
599 | "import torch\n",
600 | "\n",
601 | "from torch.nn.parameter import Parameter\n",
602 | "from torch.nn.modules.module import Module\n",
603 | "\n",
604 | "\n",
605 | "class GraphConvolution(Module):\n",
606 | " \"\"\"\n",
607 | " Simple GCN layer, similar to https://arxiv.org/abs/1609.02907\n",
608 | " \"\"\"\n",
609 | "\n",
610 | " def __init__(self, in_features, out_features, drop_out = 0, activation=None, bias=True):\n",
611 | " super(GraphConvolution, self).__init__()\n",
612 | " self.in_features = in_features\n",
613 | " self.out_features = out_features\n",
614 | " self.weight = Parameter(torch.FloatTensor(in_features, out_features))\n",
615 | " if bias:\n",
616 | " self.bias = Parameter(torch.zeros(1, out_features))\n",
617 | " else:\n",
618 | " self.register_parameter('bias', None)\n",
619 | " self.reset_parameters(in_features, out_features)\n",
620 | " self.dropout = torch.nn.Dropout(drop_out)\n",
621 | " self.activation = activation\n",
622 | "\n",
623 | " def reset_parameters(self,in_features, out_features):\n",
624 | " stdv = np.sqrt(6.0/(in_features+out_features))\n",
625 | " # stdv = 1. / math.sqrt(self.weight.size(1))\n",
626 | " self.weight.data.uniform_(-stdv, stdv)\n",
627 | " # if self.bias is not None:\n",
628 | " # torch.nn.init.zeros_(self.bias)\n",
629 | " # self.bias.data.uniform_(-stdv, stdv)\n",
630 | "\n",
631 | "\n",
632 | " def forward(self, input, adj, feature_less = False):\n",
633 | " if feature_less:\n",
634 | " support = self.weight\n",
635 | " support = self.dropout(support)\n",
636 | " else:\n",
637 | " input = self.dropout(input)\n",
638 | " support = torch.mm(input, self.weight)\n",
639 | " output = torch.spmm(adj, support)\n",
640 | " if self.bias is not None:\n",
641 | " output = output + self.bias\n",
642 | " if self.activation is not None:\n",
643 | " output = self.activation(output)\n",
644 | " return output\n",
645 | "\n",
646 | " def __repr__(self):\n",
647 | " return self.__class__.__name__ + ' (' \\\n",
648 | " + str(self.in_features) + ' -> ' \\\n",
649 | " + str(self.out_features) + ')'"
650 | ]
651 | },
652 | {
653 | "cell_type": "markdown",
654 | "metadata": {
655 | "id": "k57M4sz4s4Md"
656 | },
657 | "source": [
658 | "## GCN Model"
659 | ]
660 | },
661 | {
662 | "cell_type": "code",
663 | "execution_count": null,
664 | "metadata": {
665 | "id": "aJ-ZQuMzs5tZ"
666 | },
667 | "outputs": [],
668 | "source": [
669 | "import torch.nn as nn\n",
670 | "import torch.nn.functional as F\n",
671 | "\n",
672 | "class GCN(nn.Module):\n",
673 | " def __init__(self, nfeat, nhid, nclass, dropout, n_layers = 2):\n",
674 | " super(GCN, self).__init__()\n",
675 | " self.n_layers = n_layers\n",
676 | " self.gc_list = []\n",
677 | " if n_layers >= 2:\n",
678 | " self.gc1 = GraphConvolution(nfeat, nhid, dropout, activation = nn.ReLU())\n",
679 | " self.gc_list = nn.ModuleList([GraphConvolution(nhid, nhid, dropout, activation = nn.ReLU()) for _ in range(self.n_layers-2)])\n",
680 | " self.gcf = GraphConvolution(nhid, nclass, dropout)\n",
681 | " else:\n",
682 | " self.gc1 = GraphConvolution(nfeat, nclass, dropout)\n",
683 | "\n",
684 | " def forward(self, x, adj):\n",
685 | " if self.n_layers>=2:\n",
686 | " x = self.gc1(x, adj, feature_less = True)\n",
687 | " for i in range(self.n_layers-2):\n",
688 | " x = self.gc_list[i](x,adj)\n",
689 | " x = self.gcf(x,adj)\n",
690 | " else:\n",
691 | " x = self.gc1(x, adj, feature_less = True)\n",
692 | " return x"
693 | ]
694 | },
695 | {
696 | "cell_type": "code",
697 | "execution_count": null,
698 | "metadata": {
699 | "id": "qmhOG1yG--Ji"
700 | },
701 | "outputs": [],
702 | "source": [
703 | "def cal_accuracy(predictions,labels):\n",
704 | " pred = torch.argmax(predictions,-1).cpu().tolist()\n",
705 | " lab = labels.cpu().tolist()\n",
706 | " cor = 0\n",
707 | " for i in range(len(pred)):\n",
708 | " if pred[i] == lab[i]:\n",
709 | " cor += 1\n",
710 | " return cor/len(pred)"
711 | ]
712 | },
713 | {
714 | "cell_type": "markdown",
715 | "metadata": {
716 | "id": "zEE4JxeUthCb"
717 | },
718 | "source": [
719 | "# Training"
720 | ]
721 | },
722 | {
723 | "cell_type": "markdown",
724 | "metadata": {
725 | "id": "bIxII4QoticA"
726 | },
727 | "source": [
728 | "## Initialize model"
729 | ]
730 | },
731 | {
732 | "cell_type": "code",
733 | "execution_count": null,
734 | "metadata": {
735 | "id": "hdNsgxMG-Wwu"
736 | },
737 | "outputs": [],
738 | "source": [
739 | "import torch.optim as optim\n",
740 | "\n",
741 | "\n",
742 | "criterion = nn.CrossEntropyLoss()\n",
743 | "\n",
744 | "model = GCN(nfeat=node_size, nhid=HIDDEN_DIM, nclass=num_class, dropout=DROP_OUT,n_layers=NUM_LAYERS).to(device)\n",
745 | "optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)"
746 | ]
747 | },
748 | {
749 | "cell_type": "markdown",
750 | "metadata": {
751 | "id": "T98r4qZuuFyn"
752 | },
753 | "source": [
754 | "## Training and Validating"
755 | ]
756 | },
757 | {
758 | "cell_type": "code",
759 | "execution_count": null,
760 | "metadata": {
761 | "id": "Bv9br9pgGw9R"
762 | },
763 | "outputs": [],
764 | "source": [
765 | "def generate_train_val(train_pro=0.9):\n",
766 | " real_train_size = int(train_pro*train_size)\n",
767 | " val_size = train_size-real_train_size\n",
768 | "\n",
769 | " idx_train = np.random.choice(train_size, real_train_size,replace=False)\n",
770 | " idx_train.sort()\n",
771 | " idx_val = []\n",
772 | " pointer = 0\n",
773 | " for v in range(train_size):\n",
774 | " if pointer EARLY_STOPPING and np.min(val_loss[-EARLY_STOPPING:]) > np.min(val_loss[:-EARLY_STOPPING]) :\n",
821 | " if show_result:\n",
822 | " print(\"Early Stopping...\")\n",
823 | " break\n",
824 | "\n",
825 | "train_model()"
826 | ]
827 | },
828 | {
829 | "cell_type": "markdown",
830 | "metadata": {
831 | "id": "OQwlWq6dyYJm"
832 | },
833 | "source": [
834 | "## Evaluation"
835 | ]
836 | },
837 | {
838 | "cell_type": "code",
839 | "execution_count": null,
840 | "metadata": {
841 | "id": "jmPNukmk40gd"
842 | },
843 | "outputs": [],
844 | "source": [
845 | "from sklearn.metrics import f1_score, accuracy_score\n",
846 | "def test():\n",
847 | " model.eval()\n",
848 | " output = model(features, adj)\n",
849 | " predictions = torch.argmax(output[idx_test],-1).cpu().tolist()\n",
850 | " acc = accuracy_score(test_labels,predictions)\n",
851 | " f11 = f1_score(test_labels,predictions, average='macro')\n",
852 | " f12 = f1_score(test_labels,predictions, average = 'weighted')\n",
853 | " return acc, f11, f12\n",
854 | "\n",
855 | "print(test())"
856 | ]
857 | },
858 | {
859 | "cell_type": "markdown",
860 | "metadata": {
861 | "id": "LOFsVlv4hTgc"
862 | },
863 | "source": [
864 | "# Test 10 times"
865 | ]
866 | },
867 | {
868 | "cell_type": "code",
869 | "execution_count": null,
870 | "metadata": {
871 | "id": "ydMqrCkehVPW"
872 | },
873 | "outputs": [],
874 | "source": [
875 | "test_acc_list = []\n",
876 | "test_f11_list = []\n",
877 | "test_f12_list = []\n",
878 | "\n",
879 | "for t in range(10):\n",
880 | " model = GCN(nfeat=node_size, nhid=HIDDEN_DIM, nclass=num_class, dropout=DROP_OUT,n_layers=NUM_LAYERS).to(device)\n",
881 | " optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
882 | " idx_train, idx_val, idx_test = generate_train_val()\n",
883 | " train_model(show_result=False)\n",
884 | " acc, f11, f12 = test()\n",
885 | " test_acc_list.append(acc)\n",
886 | " test_f11_list.append(f11)\n",
887 | " test_f12_list.append(f12)\n",
888 | "\n",
889 | "\n",
890 | "print(\"Accuracy:\",np.round(np.mean(test_acc_list),4))\n",
891 | "print(\"Macro F1:\",np.round(np.mean(test_f11_list),4))\n",
892 | "print(\"Weighted F1:\",np.round(np.mean(test_f12_list),4))"
893 | ]
894 | }
895 | ],
896 | "metadata": {
897 | "accelerator": "GPU",
898 | "colab": {
899 | "collapsed_sections": [],
900 | "machine_shape": "hm",
901 | "name": "final code.ipynb",
902 | "provenance": [],
903 | "toc_visible": true
904 | },
905 | "kernelspec": {
906 | "display_name": "Python 3",
907 | "name": "python3"
908 | }
909 | },
910 | "nbformat": 4,
911 | "nbformat_minor": 0
912 | }
913 |
--------------------------------------------------------------------------------