├── LICENSE ├── README.md ├── r52-test-all-terms.txt ├── r52-train-all-terms.txt ├── r8-test-all-terms.txt ├── r8-train-all-terms.txt └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Text-Level-GNN 2 | An implementation of the paper: Text Level Graph Neural Network for Text Classification (https://arxiv.org/pdf/1910.02356.pdf) 3 | 4 | ## Features: 5 | - Dynamic edge weights instead of static edge weights 6 | - All documents share one big graph instead of each document having its own graph structure 7 | - Public edge sharing (achieved by computing edge statistics during dataset construction and masking during training; a mechanism roughly described in the paper but without much further detail; a sketch of the idea is given after the Environment section below) 8 | - Flexible argument controls and early stopping features 9 | - Detailed explanations about intermediate operations 10 | - The number of parameters in this model is close to the number of parameters reported in the paper 11 | 12 | ## File structure: 13 | ``` 14 | +---embeddings\ 15 | | +---glove.6B.50d.txt 16 | | +---glove.6B.100d.txt 17 | | +---glove.6B.200d.txt 18 | | +---glove.6B.300d.txt 19 | +---train.py 20 | +---r52-test-all-terms.txt 21 | +---r52-train-all-terms.txt 22 | +---r8-test-all-terms.txt 23 | +---r8-train-all-terms.txt 24 | ``` 25 | 26 | Since the original links no longer work, I provide the original links and the corresponding dataset files in this repository for anyone who is also looking for the R8 and R52 datasets. 27 | 28 | https://www.cs.umb.edu/~smimarog/textmining/datasets/r8-train-all-terms.txt => r8-train-all-terms.txt 29 | https://www.cs.umb.edu/~smimarog/textmining/datasets/r8-test-all-terms.txt => r8-test-all-terms.txt 30 | https://www.cs.umb.edu/~smimarog/textmining/datasets/r52-train-all-terms.txt => r52-train-all-terms.txt 31 | https://www.cs.umb.edu/~smimarog/textmining/datasets/r52-test-all-terms.txt => 32 | r52-test-all-terms.txt 33 | 34 | ## Environment: 35 | - Python 3.7.4 36 | - PyTorch 1.5.1 + CUDA 10.1 37 | - Pandas 1.0.5 38 | - Numpy 1.19.0 39 | 40 | Runs successfully on RTX 2070, RTX 2080 Ti and RTX 3090. However, memory consumption is quite large, so a smaller batch size / shorter MAX_LENGTH / smaller embedding_size is required on the RTX 2070.
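To make the public edge sharing feature more concrete, below is a minimal, self-contained sketch of the idea (illustrative only: the helper name, the toy `docs` list and `vocab_size` are made up for this example). `train.py` implements the same mechanism in `TextLevelGNNDatasetClass.build_public_edge_mask` and `TextLevelGNN.forward`: edge occurrences are counted over a window of size p while the dataset is built, and edges seen fewer than `min_freq` times are masked so that they all share a single trainable weight instead of having their own entry in the |V| x |V| edge weight matrix.

```python
import torch

def build_public_edge_mask(docs, vocab_size, p=2, min_freq=2):
    # Count how often each directed edge (from_node, to_node) appears within a
    # +/-p window over the training documents; edges seen fewer than min_freq
    # times are flagged as "public" and will share one trainable weight.
    edge_stat = torch.zeros(vocab_size, vocab_size)
    for doc in docs:  # each doc is a list of token indices
        for i, from_node in enumerate(doc):
            for j in range(max(0, i - p), min(len(doc), i + p + 1)):
                edge_stat[from_node, doc[j]] += 1
    return edge_stat < min_freq  # True = uncommon ("public") edge

# Toy example: two tiny "documents" over a vocabulary of 6 token ids.
docs = [[2, 3, 4, 2, 3], [2, 3, 5]]
public_edge_mask = build_public_edge_mask(docs, vocab_size=6, p=2, min_freq=2)

# During training, edges flagged as public do not use their own entry of the
# (|V|, |V|) edge weight matrix; they all fall back to one shared parameter.
edge_weights = torch.nn.Parameter(torch.randn(6, 6))     # per-edge weights
public_edge_weight = torch.nn.Parameter(torch.randn(1))  # shared "public" weight
effective_weights = torch.where(public_edge_mask, public_edge_weight, edge_weights)
print(effective_weights.shape)  # torch.Size([6, 6])
```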
41 | 42 | ## Usage: 43 | - Linux: 44 | - OMP_NUM_THREADS=1 python train.py --cuda=0 --embedding_size=300 --p=3 --min_freq=2 --max_length=70 --dropout=0 --epoch=300 45 | - Windows: 46 | - python train.py --cuda=0 --embedding_size=300 --p=3 --min_freq=2 --max_length=70 --dropout=0 --epoch=300 47 | 48 | ## Result: 49 | I only tested the model on the R8 dataset and was unable to reach the figures reported in the paper despite trying some hyperparameter tuning. The closest run that I could get is: 50 | Train Accuracy|Validation Accuracy|Test Accuracy 51 | :---:|:---:|:---: 52 | 99.91%|95.7%|96.2% 53 | 54 | with `embedding_size=300`, `p=3`, `70<=max_length<=150` and `dropout=0`. 55 | As the experiment settings in the paper are not clearly stated, I assumed they also used a learning rate decay mechanism. I also added a warm-up mechanism to pretrain the model, but the model actually converges quite fast and does not even need the warm-up. 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset, DataLoader, random_split 6 | import numpy as np 7 | from time import time 8 | import argparse 9 | 10 | 11 | class GloveTokenizer: 12 | def __init__(self, filename, unk='<unk>', pad='<pad>'): 13 | self.filename = filename 14 | self.unk = unk 15 | self.pad = pad 16 | self.stoi = dict() 17 | self.itos = dict() 18 | self.embedding_matrix = list() 19 | with open(filename, 'r', encoding='utf8') as f: # Read tokenizer file 20 | for i, line in enumerate(f): 21 | values = line.split() 22 | self.stoi[values[0]] = i 23 | self.itos[i] = values[0] 24 | self.embedding_matrix.append([float(v) for v in values[1:]]) 25 | if self.unk is not None: # Add unk token into the tokenizer 26 | i += 1 27 | self.stoi[self.unk] = i 28 | self.itos[i] = self.unk 29 | self.embedding_matrix.append(np.random.rand(len(self.embedding_matrix[0]))) 30 | if self.pad is not None: # Add pad token into the tokenizer 31 | i += 1 32 | self.stoi[self.pad] = i 33 | self.itos[i] = self.pad 34 | self.embedding_matrix.append(np.zeros(len(self.embedding_matrix[0]))) 35 | self.embedding_matrix = np.array(self.embedding_matrix).astype(np.float32) # Convert it from double to float for efficiency 36 | 37 | def encode(self, sentence): 38 | if type(sentence) == str: 39 | sentence = sentence.split(' ') 40 | elif len(sentence): # Convertible to list 41 | sentence = list(sentence) 42 | else: 43 | raise TypeError('sentence should be either a str or a list of str!') 44 | encoded_sentence = [] 45 | for word in sentence: 46 | encoded_sentence.append(self.stoi.get(word, self.stoi[self.unk])) 47 | return encoded_sentence 48 | 49 | def decode(self, encoded_sentence): 50 | try: 51 | encoded_sentence = list(encoded_sentence) 52 | except Exception as e: 53 | print(e) 54 | raise TypeError('encoded_sentence should be a data type that is convertible to list type!') 55 | sentence = [] 56 | for encoded_word in encoded_sentence: 57 | sentence.append(self.itos[encoded_word]) 58 | return sentence 59 | 60 | def embedding(self, encoded_sentence): 61 | return self.embedding_matrix[np.array(encoded_sentence)] 62 | 63 | 64 | class TextLevelGNNDataset(Dataset): # For instantiating train, validation and test dataset 65 | def __init__(self, node_sets, neighbor_sets, public_edge_mask,
labels): 66 | super(TextLevelGNNDataset, self).__init__() 67 | self.node_sets = node_sets 68 | self.neighbor_sets = neighbor_sets 69 | self.public_edge_mask = public_edge_mask 70 | self.labels = labels 71 | 72 | def __getitem__(self, i): 73 | return torch.LongTensor(self.node_sets[i]), \ 74 | torch.nn.utils.rnn.pad_sequence([torch.LongTensor(neighbor) for neighbor in self.neighbor_sets[i]], batch_first=True, padding_value=1), \ 75 | self.public_edge_mask[torch.LongTensor(self.node_sets[i]).unsqueeze(-1).repeat(1, torch.nn.utils.rnn.pad_sequence([torch.LongTensor(neighbor) for neighbor in self.neighbor_sets[i]], batch_first=True, padding_value=1).shape[-1]), torch.nn.utils.rnn.pad_sequence([torch.LongTensor(neighbor) for neighbor in self.neighbor_sets[i]], batch_first=True, padding_value=1)], \ 76 | torch.FloatTensor(self.labels[i]) 77 | 78 | def __len__(self): 79 | return len(self.labels) 80 | 81 | 82 | class TextLevelGNNDatasetClass: # This class is used to achieve parameter sharing among datasets 83 | def __init__(self, train_filename, test_filename, tokenizer, MAX_LENGTH=10, p=2, min_freq=2, train_validation_split=0.8): 84 | self.train_filename = train_filename 85 | self.test_filename = test_filename 86 | self.tokenizer = tokenizer 87 | self.MAX_LENGTH = MAX_LENGTH 88 | self.p = p 89 | self.min_freq = min_freq 90 | self.train_validation_split = train_validation_split 91 | 92 | self.train_data = pd.read_csv(self.train_filename, sep='\t', header=None) 93 | self.test_data = pd.read_csv(self.test_filename, sep='\t', header=None) 94 | 95 | self.stoi = {'<unk>': 0, '<pad>': 1} # Re-index 96 | self.itos = {0: '<unk>', 1: '<pad>'} # Re-index 97 | self.vocab_count = len(self.stoi) 98 | self.embedding_matrix = None 99 | self.label_dict = dict(zip(self.train_data[0].unique(), pd.get_dummies(self.train_data[0].unique()).values.tolist())) 100 | 101 | self.train_dataset, self.validation_dataset = random_split(self.train_data.to_numpy(), [int(len(self.train_data) * train_validation_split), len(self.train_data) - int(len(self.train_data) * train_validation_split)]) 102 | self.test_dataset = self.test_data.to_numpy() 103 | 104 | self.build_vocab() # Based on train_dataset only.
Updates self.stoi, self.itos, self.vocab_count and self.embedding_matrix 105 | 106 | self.train_dataset, self.validation_dataset, self.test_dataset, self.edge_stat, self.public_edge_mask = self.prepare_dataset() 107 | 108 | def build_vocab(self): 109 | vocab_list = [sentence.split(' ') for _, sentence in self.train_dataset] 110 | unique_vocab = [] 111 | for vocab in vocab_list: 112 | unique_vocab.extend(vocab) 113 | unique_vocab = list(set(unique_vocab)) 114 | for vocab in unique_vocab: 115 | if vocab in self.tokenizer.stoi.keys(): 116 | self.stoi[vocab] = self.vocab_count 117 | self.itos[self.vocab_count] = vocab 118 | self.vocab_count += 1 119 | self.embedding_matrix = self.tokenizer.embedding(self.tokenizer.encode(list(self.stoi.keys()))) 120 | 121 | def prepare_dataset(self): # will also build self.edge_stat and self.public_edge_mask 122 | # preparing self.train_dataset 123 | node_sets = [[self.stoi.get(vocab, 0) for vocab in sentence.strip().split(' ')][:self.MAX_LENGTH] for _, sentence in self.train_dataset] # Only retrieve the first MAX_LENGTH words in each document 124 | neighbor_sets = [create_neighbor_set(node_set, p=self.p) for node_set in node_sets] 125 | labels = [self.label_dict[label] for label, _ in self.train_dataset] 126 | 127 | # Construct edge statistics and public edge mask 128 | edge_stat, public_edge_mask = self.build_public_edge_mask(node_sets, neighbor_sets, min_freq=self.min_freq) 129 | 130 | train_dataset = TextLevelGNNDataset(node_sets, neighbor_sets, public_edge_mask, labels) 131 | 132 | # preparing self.validation_dataset 133 | node_sets = [[self.stoi.get(vocab, 0) for vocab in sentence.strip().split(' ')][:self.MAX_LENGTH] for _, sentence in self.validation_dataset] # Only retrieve the first MAX_LENGTH words in each document 134 | neighbor_sets = [create_neighbor_set(node_set, p=self.p) for node_set in node_sets] 135 | labels = [self.label_dict[label] for label, _ in self.validation_dataset] 136 | validation_dataset = TextLevelGNNDataset(node_sets, neighbor_sets, public_edge_mask, labels) 137 | 138 | # preparing self.test_dataset 139 | node_sets = [[self.stoi.get(vocab, 0) for vocab in sentence.strip().split(' ')][:self.MAX_LENGTH] for _, sentence in self.test_dataset] # Only retrieve the first MAX_LENGTH words in each document 140 | neighbor_sets = [create_neighbor_set(node_set, p=self.p) for node_set in node_sets] 141 | labels = [self.label_dict[label] for label, _ in self.test_dataset] 142 | test_dataset = TextLevelGNNDataset(node_sets, neighbor_sets, public_edge_mask, labels) 143 | 144 | return train_dataset, validation_dataset, test_dataset, edge_stat, public_edge_mask 145 | 146 | def build_public_edge_mask(self, node_sets, neighbor_sets, min_freq=2): 147 | edge_stat = torch.zeros(self.vocab_count, self.vocab_count) 148 | for node_set, neighbor_set in zip(node_sets, neighbor_sets): 149 | for from_node, neighbor in zip(node_set, neighbor_set): # Pair each node with its own neighbor list 150 | for to_node in neighbor: 151 | edge_stat[from_node, to_node] += 1 # Count each observed (from_node, to_node) edge 152 | public_edge_mask = edge_stat < min_freq # mark True at uncommon edges 153 | return edge_stat, public_edge_mask 154 | 155 | 156 | def create_neighbor_set(node_set, p=2): 157 | if type(node_set[0]) != int: 158 | raise ValueError('node_set should be a 1D list!') 159 | if p < 0: 160 | raise ValueError('p should be an integer >= 0!') 161 | sequence_length = len(node_set) 162 | neighbor_set = [] 163 | for i in range(sequence_length): 164 | neighbor = [] 165 | for j in range(-p, p+1): 166 | if 0 <= i + j < sequence_length: 167 | neighbor.append(node_set[i+j]) 168 |
neighbor_set.append(neighbor) 169 | return neighbor_set 170 | 171 | 172 | def pad_custom_sequence(sequences): 173 | ''' 174 | To pad different sequences into padded tensors for training. The main purpose of this function is to separate the different sequences, pad each of them in its own way and return the padded sequences. 175 | Input: 176 | sequences : A list of length batch_size, where each element is a 4-tuple of (node_sets, neighbor_sets, public_edge_mask, label) as returned by TextLevelGNNDataset.__getitem__. 177 | After separation, the four sequences are: 178 | sequences: [node_sets_sequence, neighbor_sets_sequence, public_edge_mask_sequence, label_sequence] 179 | Return: 180 | node_sets_sequence : The padded node sets sequence (works with batch_size >= 1). 181 | neighbor_sets_sequence : The padded neighbor sets sequence (works with batch_size >= 1). 182 | public_edge_mask_sequence : The padded public edge mask sequence (works with batch_size >= 1). 183 | label_sequence : The padded label sequence (works with batch_size >= 1). 184 | ''' 185 | node_sets_sequence = [] 186 | neighbor_sets_sequence = [] 187 | public_edge_mask_sequence = [] 188 | label_sequence = [] 189 | for node_sets, neighbor_sets, public_edge_mask, label in sequences: 190 | node_sets_sequence.append(node_sets) 191 | neighbor_sets_sequence.append(neighbor_sets) 192 | public_edge_mask_sequence.append(public_edge_mask) 193 | label_sequence.append(label) 194 | node_sets_sequence = torch.nn.utils.rnn.pad_sequence(node_sets_sequence, batch_first=True, padding_value=1) 195 | neighbor_sets_sequence, _ = padding_tensor(neighbor_sets_sequence) 196 | public_edge_mask_sequence, _ = padding_tensor(public_edge_mask_sequence) 197 | label_sequence = torch.nn.utils.rnn.pad_sequence(label_sequence, batch_first=True, padding_value=1) 198 | return node_sets_sequence, neighbor_sets_sequence, public_edge_mask_sequence, label_sequence 199 | 200 | 201 | def padding_tensor(sequences, padding_idx=1): 202 | ''' 203 | To pad tensors of different shapes to the same shape, e.g. padding [torch.rand(2, 3), torch.rand(3, 5)] to a shape (2, 3, 5), where the 0th dimension is batch_size and the 1st and 2nd dimensions are padded.
204 | Input: 205 | sequences : A list of tensors 206 | padding_idx : The index that corresponds to the padding token 207 | Return: 208 | out_tensor : The padded tensor 209 | mask : A boolean torch tensor where entries equal to padding_idx (i.e. 1, representing '<pad>') are marked as True 210 | ''' 211 | num = len(sequences) 212 | max_len_0 = max([s.shape[0] for s in sequences]) 213 | max_len_1 = max([s.shape[1] for s in sequences]) 214 | out_dims = (num, max_len_0, max_len_1) 215 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_idx) 216 | for i, tensor in enumerate(sequences): 217 | len_0 = tensor.size(0) 218 | len_1 = tensor.size(1) 219 | out_tensor[i, :len_0, :len_1] = tensor 220 | mask = out_tensor == padding_idx # Marking all places with padding_idx as mask 221 | return out_tensor, mask 222 | 223 | 224 | class MessagePassing(nn.Module): 225 | def __init__(self, vertice_count, input_size, out_size, dropout_rate=0, padding_idx=1): 226 | super(MessagePassing, self).__init__() 227 | self.vertice_count = vertice_count # |V| 228 | self.input_size = input_size # d 229 | self.out_size = out_size # c 230 | self.dropout_rate = dropout_rate 231 | self.padding_idx = padding_idx 232 | self.information_rate = nn.Parameter(torch.rand(self.vertice_count, 1)) # (|V|, 1), which means it is a column vector 233 | self.linear = nn.Linear(self.input_size, self.out_size) # (d, c) 234 | self.dropout = nn.Dropout(self.dropout_rate) 235 | 236 | def forward(self, node_sets, embedded_node, edge_weight, embedded_neighbor_node): 237 | # node_sets: (batch_size, l) 238 | # embedded_node: (batch_size, l, d) 239 | # edge_weight: (batch_size, max_sentence_length, max_neighbor_count) 240 | # embedded_neighbor_node: (batch_size, max_sentence_length, max_neighbor_count, d) 241 | 242 | tmp_tensor = (edge_weight.view(-1, 1) * embedded_neighbor_node.view(-1, self.input_size)).view(embedded_neighbor_node.shape) # (batch_size, max_sentence_length, max_neighbor_count, d) 243 | tmp_tensor = tmp_tensor.masked_fill(tmp_tensor == 0, -1e18) # (batch_size, max_sentence_length, max_neighbor_count, d), mask for M such that masked places are marked as -1e18 244 | tmp_tensor = self.dropout(tmp_tensor) 245 | M = tmp_tensor.max(dim=2)[0] # (batch_size, max_sentence_length, d), which is the same shape as embedded_node (batch_size, l, d) 246 | information_rate = self.information_rate[node_sets] # (batch_size, l, 1) 247 | information_rate = information_rate.masked_fill((node_sets == self.padding_idx).unsqueeze(-1), 1) # (batch_size, l, 1), Fill the information rate of the padding index with 1, such that new e_n = (1-i_r) * M + i_r * e_n = (1-1) * 0 + 1 * e_n = e_n (no update) 248 | embedded_node = (1 - information_rate) * M + information_rate * embedded_node # (batch_size, l, d) 249 | sum_embedded_node = embedded_node.sum(dim=1) # (batch_size, d) 250 | x = F.relu(self.linear(sum_embedded_node)) # (batch_size, c) 251 | # x = self.dropout(x) # If dropout with p=0.5 were placed here, it would be equivalent to wiping out 4 of the 8 class scores, which does not make sense. If a dropout layer is placed here, it works best with p=0 (disabled), followed by p=0.05, ..., p=0.5 (worst; it does not even converge).
252 | y = F.softmax(x, dim=1) # (batch_size, c), softmax along the c dimension 253 | return y 254 | 255 | 256 | class TextLevelGNN(nn.Module): 257 | def __init__(self, pretrained_embeddings, out_size=8, dropout_rate=0, padding_idx=1): 258 | super(TextLevelGNN, self).__init__() 259 | self.out_size = out_size # c 260 | self.padding_idx = padding_idx 261 | self.weight_matrix = nn.Parameter(torch.randn(pretrained_embeddings.shape[0], pretrained_embeddings.shape[0])) # (|V|, |V|) 262 | self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings, freeze=False, padding_idx=self.padding_idx) # (|V|, d) 263 | self.message_passing = MessagePassing(vertice_count=pretrained_embeddings.shape[0], input_size=pretrained_embeddings.shape[1], out_size=self.out_size, dropout_rate=dropout_rate, padding_idx=self.padding_idx) # input_size: (d,); out_size: (c,) 264 | self.public_edge_weight = nn.Parameter(torch.randn(1, 1)) # (1, 1) 265 | 266 | def forward(self, node_sets, neighbor_sets, public_edge_mask): 267 | # node_sets: (batch_size, l) 268 | # neighbor_sets: (batch_size, max_sentence_length, max_neighbor_count) 269 | # neighbor_sets_mask: (batch_size, max_sentence_length, max_neighbor_count) (not needed) 270 | # public_edge_mask: (batch_size, max_sentence_length, max_neighbor_count) 271 | 272 | embedded_node = self.embedding(node_sets) # (batch_size, l, d) 273 | edge_weight = self.weight_matrix[node_sets.unsqueeze(2).repeat(1, 1, neighbor_sets.shape[-1]), neighbor_sets] # (batch_size, max_sentence_length, max_neighbor_count), neighbor_sets.shape[-1]: e.g. 5 when p=2, 7 when p=3. node_sets is first expanded to the same shape as neighbor_sets so that a single indexing operation replaces 32*100 separate queries, which speeds things up 274 | a = edge_weight * ~public_edge_mask # (batch_size, max_sentence_length, max_neighbor_count) 275 | b = self.public_edge_weight.unsqueeze(2).expand(1, public_edge_mask.shape[-2], public_edge_mask.shape[-1]) * public_edge_mask # (batch_size, max_sentence_length, max_neighbor_count) 276 | edge_weight = a + b # (batch_size, max_sentence_length, max_neighbor_count) 277 | embedded_neighbor_node = self.embedding(neighbor_sets) # (batch_size, max_sentence_length, max_neighbor_count, d) 278 | 279 | # Apply mask to edge_weight, to mask and cut off any relationships to the padding nodes 280 | edge_weight = edge_weight.masked_fill((node_sets.unsqueeze(2).repeat(1, 1, neighbor_sets.shape[-1]) == self.padding_idx) | (neighbor_sets == self.padding_idx), 0) # (batch_size, max_sentence_length, max_neighbor_count) 281 | x = self.message_passing(node_sets, embedded_node, edge_weight, embedded_neighbor_node) # (batch_size, c) 282 | return x 283 | 284 | 285 | parser = argparse.ArgumentParser() 286 | parser.add_argument('--cuda', default='0', type=str, required=False, 287 | help='Which CUDA device to use') 288 | parser.add_argument('--embedding_size', default=300, type=int, required=False, 289 | help='Dimension of the GloVe word embeddings to use (50, 100, 200 or 300); this also determines which embedding file is loaded') 290 | parser.add_argument('--p', default=3, type=int, required=False, 291 | help='The window size') 292 | parser.add_argument('--min_freq', default=2, type=int, required=False, 293 | help='The minimum number of occurrences for an edge to be considered meaningful. Edges occurring fewer times than this are mapped to a globally shared ("public") edge weight.
It corresponds to the parameter k in the original paper.') 294 | parser.add_argument('--max_length', default=70, type=int, required=False, 295 | help='The max length of each document to be processed') 296 | parser.add_argument('--dropout', default=0, type=float, required=False, 297 | help='Dropout rate') 298 | parser.add_argument('--lr', default=1e-3, type=float, required=False, 299 | help='Initial learning rate') 300 | parser.add_argument('--lr_decay_factor', default=0.9, type=float, required=False, 301 | help='Multiplicative factor of learning rate decays') 302 | parser.add_argument('--lr_decay_every', default=5, type=int, required=False, 303 | help='Decaying learning rate every ? epochs') 304 | parser.add_argument('--weight_decay', default=1e-4, type=float, required=False, 305 | help='Weight decay (L2 penalty)') 306 | parser.add_argument('--warm_up_epoch', default=0, type=int, required=False, 307 | help='Pretraining for ? epochs before early stopping to be in effect') 308 | parser.add_argument('--early_stopping_patience', default=10, type=int, required=False, 309 | help='Waiting for ? more epochs after the best epoch to see any further improvements') 310 | parser.add_argument('--early_stopping_criteria', default='loss', type=str, required=False, 311 | choices=['accuracy', 'loss'], 312 | help='Early stopping according to validation accuracy or validation loss') 313 | parser.add_argument("--epoch", default=300, type=int, required=False, 314 | help='Number of epochs to train') 315 | args = parser.parse_args() 316 | 317 | tokenizer = GloveTokenizer(f'embeddings/glove.6B.{args.embedding_size}d.txt') 318 | dataset = TextLevelGNNDatasetClass(train_filename='r8-train-all-terms.txt', 319 | test_filename='r8-test-all-terms.txt', 320 | train_validation_split=0.8, 321 | tokenizer=tokenizer, 322 | p=args.p, 323 | min_freq=args.min_freq, 324 | MAX_LENGTH=args.max_length) 325 | train_loader = DataLoader(dataset.train_dataset, batch_size=32, shuffle=True, collate_fn=pad_custom_sequence) 326 | validation_loader = DataLoader(dataset.validation_dataset, batch_size=32, shuffle=True, collate_fn=pad_custom_sequence) 327 | test_loader = DataLoader(dataset.test_dataset, batch_size=32, shuffle=True, collate_fn=pad_custom_sequence) 328 | 329 | device = torch.device(f'cuda:{args.cuda}') if torch.cuda.is_available() else torch.device('cpu') 330 | model = TextLevelGNN(pretrained_embeddings=torch.tensor(dataset.embedding_matrix), dropout_rate=args.dropout).to(device) 331 | criterion = nn.BCELoss() 332 | 333 | lr = args.lr 334 | lr_decay_factor = args.lr_decay_factor 335 | lr_decay_every = args.lr_decay_every 336 | weight_decay = args.weight_decay 337 | 338 | warm_up_epoch = args.warm_up_epoch 339 | early_stopping_patience = args.early_stopping_patience 340 | early_stopping_criteria = args.early_stopping_criteria 341 | best_epoch = 0 # Initialize 342 | 343 | training = {} 344 | validation = {} 345 | testing = {} 346 | training['accuracy'] = [] 347 | training['loss'] = [] 348 | validation['accuracy'] = [] 349 | validation['loss'] = [] 350 | testing['accuracy'] = [] 351 | testing['loss'] = [] 352 | 353 | for epoch in range(args.epoch): 354 | model.train() 355 | train_loss = 0 356 | train_correct_items = 0 357 | previous_epoch_timestamp = time() 358 | 359 | if epoch % lr_decay_every == 0: # Update optimizer for every lr_decay_every epochs 360 | if epoch != 0: # When it is the first epoch, disable the lr_decay_factor 361 | lr *= lr_decay_factor 362 | optimizer = torch.optim.Adam(model.parameters(), lr=lr, 
weight_decay=weight_decay) 363 | 364 | for i, (node_sets, neighbor_sets, public_edge_masks, labels) in enumerate(train_loader): 365 | # print('Finished batch:', i) 366 | node_sets = node_sets.to(device) 367 | neighbor_sets = neighbor_sets.to(device) 368 | public_edge_masks = public_edge_masks.to(device) 369 | labels = labels.to(device) 370 | prediction = model(node_sets, neighbor_sets, public_edge_masks) 371 | loss = criterion(prediction, labels).to(device) 372 | optimizer.zero_grad() 373 | loss.backward() 374 | optimizer.step() 375 | train_loss += loss.item() 376 | train_correct_items += (prediction.argmax(dim=1) == labels.argmax(dim=1)).sum().item() 377 | train_accuracy = train_correct_items / len(dataset.train_dataset) 378 | 379 | model.eval() 380 | validation_loss = 0 381 | validation_correct_items = 0 382 | for i, (node_sets, neighbor_sets, public_edge_masks, labels) in enumerate(validation_loader): 383 | node_sets = node_sets.to(device) 384 | neighbor_sets = neighbor_sets.to(device) 385 | public_edge_masks = public_edge_masks.to(device) 386 | labels = labels.to(device) 387 | prediction = model(node_sets, neighbor_sets, public_edge_masks) 388 | loss = criterion(prediction, labels).to(device) 389 | validation_loss += loss.item() 390 | validation_correct_items += (prediction.argmax(dim=1) == labels.argmax(dim=1)).sum().item() 391 | validation_accuracy = validation_correct_items / len(dataset.validation_dataset) 392 | 393 | # model.eval() 394 | test_loss = 0 395 | test_correct_items = 0 396 | for i, (node_sets, neighbor_sets, public_edge_masks, labels) in enumerate(test_loader): 397 | node_sets = node_sets.to(device) 398 | neighbor_sets = neighbor_sets.to(device) 399 | public_edge_masks = public_edge_masks.to(device) 400 | labels = labels.to(device) 401 | prediction = model(node_sets, neighbor_sets, public_edge_masks) 402 | loss = criterion(prediction, labels).to(device) 403 | test_loss += loss.item() 404 | test_correct_items += (prediction.argmax(dim=1) == labels.argmax(dim=1)).sum().item() 405 | test_accuracy = test_correct_items / len(dataset.test_dataset) 406 | print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {validation_loss:.4f}, Testing Loss: {test_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Validation Accuracy: {validation_accuracy:.4f}, Testing Accuracy: {test_accuracy:.4f}, Time Used: {time()-previous_epoch_timestamp:.2f}s') 407 | training['accuracy'].append(train_accuracy) 408 | training['loss'].append(train_loss) 409 | validation['accuracy'].append(validation_accuracy) 410 | validation['loss'].append(validation_loss) 411 | testing['accuracy'].append(test_accuracy) 412 | testing['loss'].append(test_loss) 413 | 414 | # add warmup mechanism for warm_up_epoch epochs 415 | if epoch >= warm_up_epoch: 416 | best_epoch = warm_up_epoch 417 | # early stopping 418 | if early_stopping_criteria == 'accuracy': 419 | if validation['accuracy'][epoch] > validation['accuracy'][best_epoch]: 420 | best_epoch = epoch 421 | elif epoch >= best_epoch + early_stopping_patience: 422 | print(f'Early stopping... (No further increase in validation accuracy) for consecutive {early_stopping_patience} epochs.') 423 | break 424 | if early_stopping_criteria == 'loss': 425 | if validation['loss'][epoch] < validation['loss'][best_epoch]: 426 | best_epoch = epoch 427 | elif epoch >= best_epoch + early_stopping_patience: 428 | print(f'Early stopping... 
(No further decrease in validation loss) for consecutive {early_stopping_patience} epochs.') 429 | break 430 | elif epoch + 1 == warm_up_epoch: 431 | print('--- Warm up finished ---') 432 | 433 | df = pd.concat([pd.DataFrame(training), pd.DataFrame(validation), pd.DataFrame(testing)], axis=1) 434 | df.columns = ['Training Accuracy', 'Training Loss', 'Validation Accuracy', 'Validation Loss', 'Testing Accuracy', 'Testing Loss'] 435 | df.to_csv(f'embedding_size={args.embedding_size},p={args.p},min_freq={args.min_freq},max_length={args.max_length},dropout={args.dropout},lr={args.lr},lr_decay_factor={args.lr_decay_factor},lr_decay_every={args.lr_decay_every},weight_decay={args.weight_decay},warm_up_epoch={args.warm_up_epoch},early_stopping_patience={args.early_stopping_patience},early_stopping_criteria={args.early_stopping_criteria},epoch={args.epoch}.csv') # Logging 436 | 437 | # import matplotlib.pyplot as plt 438 | 439 | # plt.plot(training['loss'], label='Training Loss') 440 | # plt.plot(validation['loss'], label='Validation Loss') 441 | # plt.plot(testing['loss'], label='Testing Loss') 442 | # plt.legend() 443 | # plt.xlabel('Epoch') 444 | # plt.ylabel('Loss') 445 | # plt.show() 446 | 447 | # plt.plot(training['accuracy'], label='Training Accuracy') 448 | # plt.plot(validation['accuracy'], label='Validation Accuracy') 449 | # plt.plot(testing['accuracy'], label='Testing Accuracy') 450 | # plt.legend() 451 | # plt.xlabel('Epoch') 452 | # plt.ylabel('Accuracy') 453 | # plt.show() --------------------------------------------------------------------------------
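As a final illustration of the p-window neighborhood used throughout `train.py`, here is a minimal sketch (the token ids and the value of `p` are made up for this example) of what `create_neighbor_set` produces for a toy encoded document; the loop mirrors the function defined above:

```python
# Illustrative only: toy encoded document of length 4 with made-up token ids.
node_set = [7, 12, 5, 9]
p = 2  # window size, as passed via --p

neighbor_set = []
for i in range(len(node_set)):  # same logic as create_neighbor_set in train.py
    neighbor_set.append([node_set[i + j] for j in range(-p, p + 1)
                         if 0 <= i + j < len(node_set)])

print(neighbor_set)
# [[7, 12, 5], [7, 12, 5, 9], [7, 12, 5, 9], [12, 5, 9]]
```

Each word keeps itself plus up to p words on either side as neighbors, and these neighbor lists are what get padded (with index 1, the `<pad>` token) and fed to the model as `neighbor_sets`.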