├── .gitignore
├── README.md
├── __pycache__
│   ├── model.cpython-37.pyc
│   ├── train_eval.cpython-37.pyc
│   └── utils.cpython-37.pyc
├── build_graph.py
├── data_processor.py
├── model.py
├── run.py
├── train_eval.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
/.idea
/datasets
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pyg-TextGCN

An *efficient* and *simplified* re-implementation of TextGCN with PyTorch Geometric. Forked from [PyTorch_TextGCN](https://github.com/chengsen/PyTorch_TextGCN); the datasets can be found in that repo.

## Requirements

This project was built with:

- Python 3.7.0
- PyTorch 1.9.0
- scikit-learn 0.24.1
- torch-geometric 1.7.2
- numpy 1.19.5
- pandas 1.1.5

## Quick Start

Process the data first: `python data_processor.py` (already done)

Generate the graph: `python build_graph.py` (already done)

Train the model: `python run.py`

## References

[Yao et al.: Graph Convolutional Networks for Text Classification](https://arxiv.org/abs/1809.05679)

[Pytorch_geometric](https://github.com/rusty1s/pytorch_geometric)

[PyTorch_TextGCN](https://github.com/chengsen/PyTorch_TextGCN)
--------------------------------------------------------------------------------
/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A11en0/Pyg-TextGCN/4cfcef50af7e6c92a8881494a3cafccabc363bcd/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/train_eval.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A11en0/Pyg-TextGCN/4cfcef50af7e6c92a8881494a3cafccabc363bcd/__pycache__/train_eval.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/A11en0/Pyg-TextGCN/4cfcef50af7e6c92a8881494a3cafccabc363bcd/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
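The Quick Start steps above can also be driven from Python. A minimal sketch, assuming the raw corpus and label files from the upstream PyTorch_TextGCN repo are already placed under `datasets/text_dataset/`:

```python
# Minimal end-to-end pipeline sketch (not part of the original repo).
from data_processor import CorpusProcess
from build_graph import BuildGraph
from run import run

CorpusProcess("mr")   # clean the raw corpus -> datasets/text_dataset/clean_corpus/mr.txt
BuildGraph("mr")      # build the document-word graph -> datasets/graph/mr.txt
run("mr", 1)          # train and evaluate TextGCN with one random seed
```
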
/build_graph.py:
--------------------------------------------------------------------------------
import os
from collections import Counter

import networkx as nx

import itertools
import math
from collections import defaultdict
from time import time

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
from tqdm import tqdm

# from utils import print_graph_detail


def get_window(content_lst, window_size):
    """
    Slide a fixed-size window over every document and collect window counts.
    :param content_lst:
    :param window_size:
    :return:
    """
    word_window_freq = defaultdict(int)  # w(i): number of windows containing word i
    word_pair_count = defaultdict(int)  # w(i, j): number of windows containing the pair (i, j)
    windows_len = 0
    for words in tqdm(content_lst, desc="Split by window"):
        windows = list()

        if isinstance(words, str):
            words = words.split()
        length = len(words)

        if length <= window_size:
            windows.append(words)
        else:
            for j in range(length - window_size + 1):
                window = words[j: j + window_size]
                windows.append(list(set(window)))

        for window in windows:
            for word in window:
                word_window_freq[word] += 1

            for word_pair in itertools.combinations(window, 2):
                word_pair_count[word_pair] += 1

        windows_len += len(windows)
    return word_window_freq, word_pair_count, windows_len


def cal_pmi(W_ij, W, word_freq_1, word_freq_2):
    p_i = word_freq_1 / W
    p_j = word_freq_2 / W
    p_i_j = W_ij / W
    pmi = math.log(p_i_j / (p_i * p_j))

    return pmi


def count_pmi(windows_len, word_pair_count, word_window_freq, threshold):
    word_pmi_lst = list()
    for word_pair, W_i_j in tqdm(word_pair_count.items(), desc="Calculate pmi between words"):
        word_freq_1 = word_window_freq[word_pair[0]]
        word_freq_2 = word_window_freq[word_pair[1]]

        pmi = cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2)
        if pmi <= threshold:
            continue
        word_pmi_lst.append([word_pair[0], word_pair[1], pmi])
    return word_pmi_lst


def get_pmi_edge(content_lst, window_size=20, threshold=0.):
    if isinstance(content_lst, str):
        content_lst = list(open(content_lst, "r"))
        print("pmi read file len:", len(content_lst))

    pmi_start = time()
    word_window_freq, word_pair_count, windows_len = get_window(content_lst,
                                                                window_size=window_size)

    pmi_edge_lst = count_pmi(windows_len, word_pair_count, word_window_freq, threshold)
    print("Total number of edges between words:", len(pmi_edge_lst))
    pmi_time = time() - pmi_start
    return pmi_edge_lst, pmi_time


class BuildGraph:
    def __init__(self, dataset):
        clean_corpus_path = "datasets/text_dataset/clean_corpus"
        self.graph_path = "datasets/graph"
        if not os.path.exists(self.graph_path):
            os.makedirs(self.graph_path)

        self.word2id = dict()  # word -> id mapping
        self.dataset = dataset
        print(f"\n=== Dataset:{dataset}===")

        self.g = nx.Graph()

        self.content = f"{clean_corpus_path}/{dataset}.txt"

        self.get_tfidf_edge()
        self.get_pmi_edge()
        self.save()

    def get_pmi_edge(self):
        pmi_edge_lst, self.pmi_time = get_pmi_edge(self.content, window_size=20, threshold=0.0)
        print("pmi time:", self.pmi_time)

        for edge_item in pmi_edge_lst:
            word_indx1 = self.node_num + self.word2id[edge_item[0]]
            word_indx2 = self.node_num + self.word2id[edge_item[1]]
            if word_indx1 == word_indx2:
                continue
            self.g.add_edge(word_indx1, word_indx2, weight=edge_item[2])

        # print_graph_detail(self.g)

    def get_tfidf_edge(self):
        # get the (sparse) tf-idf weight matrix and the vocabulary
        tfidf_vec = self.get_tfidf_vec()

        count_lst = list()  # number of word nodes linked to each document
        for ind, row in tqdm(enumerate(tfidf_vec),
                             desc="generate tfidf edge"):
            count = 0
            for col_ind, value in zip(row.indices, row.data):
                word_ind = self.node_num + col_ind
                self.g.add_edge(ind, word_ind, weight=value)
                count += 1
            count_lst.append(count)

        # print_graph_detail(self.g)

    def get_tfidf_vec(self):
        """
        Fit tf-idf on the cleaned corpus and return the document-word matrix;
        the fitted vocabulary order defines the word node ids.
        :param content_lst:
        :return:
        """
        start = time()
        text_tfidf = Pipeline([
            ("vect", CountVectorizer(min_df=1,
                                     max_df=1.0,
                                     token_pattern=r"\S+",
                                     )),
            ("tfidf", TfidfTransformer(norm=None,
                                       use_idf=True,
                                       smooth_idf=False,
                                       sublinear_tf=False
                                       ))
        ])

        tfidf_vec = text_tfidf.fit_transform(open(self.content, "r"))

        self.tfidf_time = time() - start
        print("tfidf time:", self.tfidf_time)
        print("tfidf_vec shape:", tfidf_vec.shape)
        print("tfidf_vec type:", type(tfidf_vec))

        self.node_num = tfidf_vec.shape[0]

        # build the word -> id mapping
        vocab_lst = text_tfidf["vect"].get_feature_names()
        print("vocab_lst len:", len(vocab_lst))
        for ind, word in enumerate(vocab_lst):
            self.word2id[word] = ind

        self.vocab_lst = vocab_lst

        return tfidf_vec

    def save(self):
        print("total time:", self.pmi_time + self.tfidf_time)
        nx.write_weighted_edgelist(self.g,
                                   f"{self.graph_path}/{self.dataset}.txt")

        print("\n")


if __name__ == '__main__':
    BuildGraph("mr")
    # BuildGraph("ohsumed")
    # BuildGraph("R52")
    # BuildGraph("R8")
    # BuildGraph("20ng")
--------------------------------------------------------------------------------
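The window-based PMI edges built by `get_pmi_edge` can be sanity-checked on a toy corpus. A small sketch (the two example sentences are made up, not part of the repo):

```python
# Toy check of the PMI edge construction in build_graph.py.
from build_graph import get_pmi_edge

toy_docs = ["the cat sat on the mat", "the dog sat on the log"]
edge_lst, elapsed = get_pmi_edge(toy_docs, window_size=3, threshold=0.0)
for w1, w2, pmi in edge_lst:
    print(f"{w1} -- {w2}: pmi = {pmi:.3f}")
```

Only word pairs with positive PMI survive `threshold=0.0`, which mirrors how `BuildGraph` adds word-word edges.
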
/data_processor.py:
--------------------------------------------------------------------------------
import os
import re
from collections import Counter
from collections import defaultdict
import numpy as np

from tqdm import tqdm


class StringProcess(object):
    def __init__(self):
        self.other_char = re.compile(r"[^A-Za-z0-9(),!?\'\`]", flags=0)
        self.num = re.compile(r"[+-]?\d+\.?\d*", flags=0)
        # self.url = re.compile(r"[a-z]*[:.]+\S+|\n|\s+", flags=0)
        self.url = re.compile(
            r"(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", flags=0)
        self.stop_words = None
        self.nlp = None

    def clean_str(self, string):
        string = re.sub(self.other_char, " ", string)
        string = re.sub(r"\'s", " \'s", string)
        string = re.sub(r"\'ve", " \'ve", string)
        string = re.sub(r"n\'t", " n\'t", string)
        string = re.sub(r"\'re", " \'re", string)
        string = re.sub(r"\'d", " \'d", string)
        string = re.sub(r"\'ll", " \'ll", string)
        string = re.sub(r",", " , ", string)
        string = re.sub(r"!", " ! ", string)
        string = re.sub(r"\(", " \( ", string)
        string = re.sub(r"\)", " \) ", string)
        string = re.sub(r"\?", " \? ", string)
        string = re.sub(r"\s{2,}", " ", string)

        return string.strip().lower()

    def norm_str(self, string):
        string = re.sub(self.other_char, " ", string)

        if self.nlp is None:
            from spacy.lang.en import English
            self.nlp = English()

        new_doc = list()
        doc = self.nlp(string)
        for token in doc:
            if token.is_space or token.is_punct:
                continue
            if token.is_digit:
                token = "[num]"
            else:
                token = token.text

            new_doc.append(token)

        return " ".join(new_doc).lower()

    def clean_str_sst(self, string):
        """
        Tokenization/string cleaning for the SST yelp_dataset
        Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
        """
        string = re.sub(self.other_char, " ", string)
        string = re.sub(r"\s{2,}", " ", string)
        return string.strip().lower()

    def remove_stopword(self, string):
        if self.stop_words is None:
            from nltk.corpus import stopwords
            self.stop_words = set(stopwords.words('english'))

        if type(string) is str:
            string = string.split()

        new_string = list()
        for word in string:
            if word in self.stop_words:
                continue
            new_string.append(word)

        return " ".join(new_string)

    def replace_num(self, string):
        result = re.sub(self.num, '', string)
        return result

    def replace_urls(self, string):
        result = re.sub(self.url, '', string)
        result = ' '.join(re.split(' +|\n+', result)).strip()
        return result


def remove_less_word(lines_str, word_st):
    return " ".join([word for word in lines_str.split() if word in word_st])


class CorpusProcess:
    def __init__(self, dataset, encoding=None):
        # keep the data root consistent with build_graph.py and utils.py
        corpus_path = "datasets/text_dataset/corpus"
        clean_corpus_path = "datasets/text_dataset/clean_corpus"
        if not os.path.exists(clean_corpus_path):
            os.makedirs(clean_corpus_path)

        self.dataset = dataset
        self.corpus_name = f"{corpus_path}/{dataset}.txt"
        self.save_name = f"{clean_corpus_path}/{dataset}.txt"
        self.context_dct = defaultdict(dict)

        self.encoding = encoding
        self.clean_text()

    def clean_text(self):
        sp = StringProcess()
        word_lst = list()
        with open(self.corpus_name, mode="rb", encoding=self.encoding) as fin:
            for indx, item in tqdm(enumerate(fin), desc="clean the text"):
                data = item.strip().decode('latin1')
                data = sp.clean_str(data)
                if self.dataset not in {"mr"}:
                    data = sp.remove_stopword(data)
                word_lst.extend(data.split())

        word_st = set()
        if self.dataset not in {"mr"}:
            for word, value in Counter(word_lst).items():
                if value < 5:
                    continue
                word_st.add(word)
        else:
            word_st = set(word_lst)

        doc_len_lst = list()
        with open(self.save_name, mode='w') as fout:
            with open(self.corpus_name, mode="rb", encoding=self.encoding) as fin:
                for line in tqdm(fin):
                    lines_str = line.strip().decode('latin1')
                    lines_str = sp.clean_str(lines_str)
                    if self.dataset not in {"mr"}:
                        lines_str = sp.remove_stopword(lines_str)
                    lines_str = remove_less_word(lines_str, word_st)

                    fout.write(lines_str)
                    fout.write(" \n")

                    doc_len_lst.append(len(lines_str.split()))

        print("Average length:", np.mean(doc_len_lst))
        print("doc count:", len(doc_len_lst))
        print("Total number of words:", len(word_st))


def main():
    CorpusProcess("R52")
    # CorpusProcess("20ng")
    # CorpusProcess("mr")
    # CorpusProcess("ohsumed")
    # CorpusProcess("R8")
    # pass


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
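A quick illustration of the cleaning rules in `StringProcess.clean_str` (the sample sentence is made up):

```python
from data_processor import StringProcess

sp = StringProcess()
print(sp.clean_str("The film isn't bad, it's surprisingly good!"))
# roughly: "the film is n't bad , it 's surprisingly good !"
```

Contractions are split into separate tokens and punctuation is padded with spaces, matching the preprocessing used in the original TextGCN datasets.
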
", string) 33 | string = re.sub(r"\s{2,}", " ", string) 34 | 35 | return string.strip().lower() 36 | 37 | def norm_str(self, string): 38 | string = re.sub(self.other_char, " ", string) 39 | 40 | if self.nlp is None: 41 | from spacy.lang.en import English 42 | self.nlp = English() 43 | 44 | new_doc = list() 45 | doc = self.nlp(string) 46 | for token in doc: 47 | if token.is_space or token.is_punct: 48 | continue 49 | if token.is_digit: 50 | token = "[num]" 51 | else: 52 | token = token.text 53 | 54 | new_doc.append(token) 55 | 56 | return " ".join(new_doc).lower() 57 | 58 | def lean_str_sst(self, string): 59 | """ 60 | Tokenization/string cleaning for the SST yelp_dataset 61 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 62 | """ 63 | string = re.sub(self.other_char, " ", string) 64 | string = re.sub(r"\s{2,}", " ", string) 65 | return string.strip().lower() 66 | 67 | def remove_stopword(self, string): 68 | if self.stop_words is None: 69 | from nltk.corpus import stopwords 70 | self.stop_words = set(stopwords.words('english')) 71 | 72 | if type(string) is str: 73 | string = string.split() 74 | 75 | new_string = list() 76 | for word in string: 77 | if word in self.stop_words: 78 | continue 79 | new_string.append(word) 80 | 81 | return " ".join(new_string) 82 | 83 | def replace_num(self, string): 84 | result = re.sub(self.num, '', string) 85 | return result 86 | 87 | def replace_urls(self, string): 88 | result = re.sub(self.url, '', string) 89 | result = ' '.join(re.split(' +|\n+', result)).strip() 90 | return result 91 | 92 | 93 | def remove_less_word(lines_str, word_st): 94 | return " ".join([word for word in lines_str.split() if word in word_st]) 95 | 96 | 97 | class CorpusProcess: 98 | def __init__(self, dataset, encoding=None): 99 | corpus_path = "data/text_dataset/corpus" 100 | clean_corpus_path = "data/text_dataset/clean_corpus" 101 | if not os.path.exists(clean_corpus_path): 102 | os.makedirs(clean_corpus_path) 103 | 104 | self.dataset = dataset 105 | self.corpus_name = f"{corpus_path}/{dataset}.txt" 106 | self.save_name = f"{clean_corpus_path}/{dataset}.txt" 107 | self.context_dct = defaultdict(dict) 108 | 109 | self.encoding = encoding 110 | self.clean_text() 111 | 112 | def clean_text(self): 113 | sp = StringProcess() 114 | word_lst = list() 115 | with open(self.corpus_name, mode="rb", encoding=self.encoding) as fin: 116 | for indx, item in tqdm(enumerate(fin), desc="clean the text"): 117 | data = item.strip().decode('latin1') 118 | data = sp.clean_str(data) 119 | if self.dataset not in {"mr"}: 120 | data = sp.remove_stopword(data) 121 | word_lst.extend(data.split()) 122 | 123 | word_st = set() 124 | if self.dataset not in {"mr"}: 125 | for word, value in Counter(word_lst).items(): 126 | if value < 5: 127 | continue 128 | word_st.add(word) 129 | else: 130 | word_st = set(word_lst) 131 | 132 | doc_len_lst = list() 133 | with open(self.save_name, mode='w') as fout: 134 | with open(self.corpus_name, mode="rb", encoding=self.encoding) as fin: 135 | for line in tqdm(fin): 136 | lines_str = line.strip().decode('latin1') 137 | lines_str = sp.clean_str(lines_str) 138 | if self.dataset not in {"mr"}: 139 | lines_str = sp.remove_stopword(lines_str) 140 | lines_str = remove_less_word(lines_str, word_st) 141 | 142 | fout.write(lines_str) 143 | fout.write(" \n") 144 | 145 | doc_len_lst.append(len(lines_str.split())) 146 | 147 | print("Average length:", np.mean(doc_len_lst)) 148 | print("doc count:", len(doc_len_lst)) 149 | print("Total number of 
words:", len(word_st)) 150 | 151 | 152 | def main(): 153 | CorpusProcess("R52") 154 | # CorpusProcess("20ng") 155 | # CorpusProcess("mr") 156 | # CorpusProcess("ohsumed") 157 | # CorpusProcess("R8") 158 | # pass 159 | 160 | 161 | if __name__ == '__main__': 162 | main() 163 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import GCNConv, ChebConv, GATConv # noqa 3 | import torch.nn.functional as F 4 | 5 | class GCN(torch.nn.Module): 6 | def __init__(self, nfeat, nhid, nclass, dropout): 7 | super(GCN, self).__init__() 8 | self.conv1 = GCNConv(nfeat, nhid, cached=True, normalize=True) 9 | self.conv2 = GCNConv(nhid, nclass, cached=True, normalize=True) 10 | # self.conv1 = ChebConv(nfeat, nhid, K=2) 11 | # self.conv2 = ChebConv(nhid, nclass, K=2) 12 | self.dropout = dropout 13 | 14 | def forward(self, data): 15 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 16 | x = F.relu(self.conv1(x, edge_index, edge_weight)) 17 | x = torch.dropout(x, self.dropout, train=self.training) 18 | x = self.conv2(x, edge_index, edge_weight) 19 | return F.log_softmax(x, dim=1) 20 | 21 | class GAT(torch.nn.Module): 22 | def __init__(self, nfeat, nhid, nclass, dropout): 23 | super(GAT, self).__init__() 24 | self.conv1 = GATConv(nfeat, nhid, heads=8, dropout=dropout) 25 | self.conv2 = GATConv(nhid * 8, nclass, heads=1, concat=False, 26 | dropout=dropout) 27 | self.dropout = dropout 28 | 29 | def forward(self, data): 30 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr 31 | x = F.dropout(x, p=self.dropout, training=self.training) 32 | x = F.elu(self.conv1(x, edge_index)) 33 | x = F.dropout(x, p=self.dropout, training=self.training) 34 | x = self.conv2(x, edge_index) 35 | return F.log_softmax(x, dim=-1) -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | import argparse 4 | from model import GAT, GCN 5 | from train_eval import TextGCNTrainer 6 | from utils import return_seed, LoadData 7 | 8 | parser = argparse.ArgumentParser(description='TextGCN') 9 | parser.add_argument('--model', type=str, default='TextGCN', help='choose a model') 10 | args = parser.parse_args() 11 | 12 | 13 | def run(dataset, times): 14 | args.dataset = dataset 15 | args.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 16 | # args.device = torch.device('cpu') 17 | args.nhid = 100 18 | args.max_epoch = 200 19 | args.dropout = 0.5 20 | args.val_ratio = 0.1 21 | args.early_stopping = 10 22 | args.lr = 0.02 23 | model = GCN 24 | # model = GAT 25 | print(args) 26 | 27 | predata = LoadData(args) 28 | seed_lst = list() 29 | for ind, seed in enumerate(return_seed(times)): 30 | # print(f"\n\n==> {ind}, seed:{seed}") 31 | args.seed = seed 32 | seed_lst.append(seed) 33 | 34 | framework = TextGCNTrainer(model=model, args=args, pre_data=predata) 35 | framework.fit() 36 | 37 | # framework.test() 38 | # del framework 39 | # gc.collect() 40 | # 41 | # if torch.cuda.is_available(): 42 | # torch.cuda.empty_cache() 43 | 44 | # print("==> seed set:") 45 | # print(seed_lst) 46 | 47 | if __name__ == '__main__': 48 | run("mr", 1) 49 | # run("R8", 1) 50 | 51 | -------------------------------------------------------------------------------- /train_eval.py: 
/train_eval.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from time import time
from sklearn import metrics
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

from utils import get_time_dif


class TextGCNTrainer:
    def __init__(self, args, model, pre_data):
        self.args = args
        self.model = model
        self.device = args.device
        self.max_epoch = self.args.max_epoch
        self.set_seed()
        self.dataset = args.dataset
        self.predata = pre_data
        self.earlystopping = EarlyStopping(args.early_stopping)
        # the models in model.py return log-probabilities (log_softmax),
        # so NLLLoss is the matching criterion
        self.criterion = torch.nn.NLLLoss()

    def set_seed(self):
        torch.manual_seed(self.args.seed)
        np.random.seed(self.args.seed)

    def fit(self):
        self.prepare_data()
        self.model = self.model(nfeat=self.predata.nfeat_dim,
                                nhid=self.args.nhid,
                                nclass=self.nclass,
                                dropout=self.args.dropout)
        self.model = self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        print(self.model)
        self.model_param = sum(param.numel() for param in self.model.parameters())
        print('model parameters:', self.model_param)
        self.convert_tensor()
        self.train()
        self.test()

    def prepare_data(self):
        self.target = self.predata.target
        self.nclass = self.predata.nclass
        self.data = self.predata.graph

        self.train_lst, self.val_lst = train_test_split(self.predata.train_lst,
                                                        test_size=self.args.val_ratio,
                                                        shuffle=True,
                                                        random_state=self.args.seed)
        self.test_lst = self.predata.test_lst

    def convert_tensor(self):
        self.model = self.model.to(self.device)
        self.data = self.data.to(self.device)
        self.target = torch.tensor(self.target).long().to(self.device)
        self.train_lst = torch.tensor(self.train_lst).long().to(self.device)
        self.val_lst = torch.tensor(self.val_lst).long().to(self.device)

    def train(self):
        start_time = time()
        for epoch in range(self.max_epoch):
            self.model.train()
            self.optimizer.zero_grad()

            logits = self.model.forward(self.data)
            loss = self.criterion(logits[self.train_lst],
                                  self.target[self.train_lst])

            loss.backward()
            self.optimizer.step()
            pred = torch.max(logits[self.train_lst].data, 1)[1].cpu().numpy()
            target = self.target[self.train_lst].data.cpu().numpy()
            train_acc = accuracy_score(target, pred)  # sklearn metrics expect (y_true, y_pred)
            val_loss, val_acc, val_f1 = self.val(self.val_lst)
            time_dif = get_time_dif(start_time)
            msg = 'Epoch: {:>2}, Train Loss: {:>6.3}, Train Acc: {:>6.2%}, Val Loss: {:>6.3}, Val Acc: {:>6.2%}, Time: {}'
            print(msg.format(epoch, loss.item(), train_acc, val_loss, val_acc, time_dif))
            if self.earlystopping(val_loss):
                break

    @torch.no_grad()
    def val(self, x, test=False):
        self.model.eval()
        with torch.no_grad():
            logits = self.model.forward(self.data)
            loss = self.criterion(logits[x],
                                  self.target[x])

            pred = torch.max(logits[x].data, 1)[1].cpu().numpy()
            target = self.target[x].data.cpu().numpy()
            acc = accuracy_score(target, pred)
            f1 = f1_score(target, pred, average='macro')
            if test:
                report = metrics.classification_report(target, pred, digits=4)
                # report = metrics.classification_report(target, pred, target_names=config.class_list, digits=4)
                confusion = metrics.confusion_matrix(target, pred)
                return acc, report, confusion
            return loss.item(), acc, f1

    @torch.no_grad()
    def test(self):
        self.test_lst = torch.tensor(self.test_lst).long().to(self.device)
        acc, report, confusion = self.val(self.test_lst, test=True)
        msg = '\nTest Acc: {:>6.2%}'
        print(msg.format(acc))
        print("Precision, Recall and F1-Score...")
        print(report)
        print("Confusion Matrix...")
        print(confusion)


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.model_path = "hdd_data/prepare_dataset/model/model.pt"

    def __call__(self, val_loss, model=None):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            # self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                return True
        else:
            self.best_score = score
            # self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when validation loss decreases.'''
        if self.verbose:
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.model_path)
        self.val_loss_min = val_loss

    def load_model(self):
        return torch.load(self.model_path)
--------------------------------------------------------------------------------
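A tiny sketch of how the `EarlyStopping` helper above behaves (the loss values are made up):

```python
from train_eval import EarlyStopping

stopper = EarlyStopping(patience=3, verbose=True)
for epoch, val_loss in enumerate([0.90, 0.80, 0.82, 0.81, 0.83, 0.84]):
    if stopper(val_loss):
        print("early stop at epoch", epoch)
        break
```

The counter resets whenever the validation loss improves; after `patience` consecutive non-improving epochs the call returns `True` and the training loop in `train()` breaks.
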
/utils.py:
--------------------------------------------------------------------------------
import time
import torch
import random
import pandas as pd
import numpy as np
from datetime import timedelta
from torch_geometric.data import Data


class LoadData:
    def __init__(self, args):
        # print("prepare data")
        self.graph_path = "datasets/graph"
        self.args = args
        self.nodes = set()

        # node
        edges = []
        edge_weight = []
        with open(f"{self.graph_path}/{args.dataset}.txt", "r") as f:
            for line in f.readlines():
                val = line.split()
                if val[0] not in self.nodes:
                    self.nodes.add(val[0])
                if val[1] not in self.nodes:
                    self.nodes.add(val[1])
                edges.append([int(val[0]), int(val[1])])
                edge_weight.append(float(val[2]))

        edge_index = torch.LongTensor(edges).t().contiguous()
        edge_weight = torch.FloatTensor(edge_weight)

        # feature
        self.nfeat_dim = len(self.nodes)
        row = list(range(self.nfeat_dim))
        col = list(range(self.nfeat_dim))
        value = [1.] * self.nfeat_dim
        shape = (self.nfeat_dim, self.nfeat_dim)
        indices = torch.from_numpy(
            np.vstack((row, col)).astype(np.int64))
        values = torch.FloatTensor(value)
        shape = torch.Size(shape)

        # self.features = th.sparse.FloatTensor(indices, values, shape).to_dense()
        features = torch.sparse.FloatTensor(indices, values, shape)
        self.graph = Data(x=features, edge_index=edge_index, edge_attr=edge_weight)

        # target
        target_fn = f"datasets/text_dataset/{self.args.dataset}.txt"
        target = np.array(pd.read_csv(target_fn,
                                      sep="\t",
                                      header=None)[2])
        target2id = {label: indx for indx, label in enumerate(set(target))}
        self.target = [target2id[label] for label in target]
        self.nclass = len(target2id)

        # train val test split
        self.train_lst, self.test_lst = get_train_test(target_fn)


def get_train_test(target_fn):
    train_lst = list()
    test_lst = list()
    with read_file(target_fn, mode="r") as fin:
        for indx, item in enumerate(fin):
            if item.split("\t")[1] in {"train", "training", "20news-bydate-train"}:
                train_lst.append(indx)
            else:
                test_lst.append(indx)

    return train_lst, test_lst


def read_file(path, mode='r', encoding=None):
    if mode not in {"r", "rb"}:
        raise ValueError("only read")
    return open(path, mode=mode, encoding=encoding)


def return_seed(nums=10):
    seed = random.sample(range(0, 100000), nums)
    return seed


def get_time_dif(start_time):
    """Return the elapsed time since start_time."""
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))
--------------------------------------------------------------------------------
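For reference, `LoadData` and `get_train_test` expect the label file `datasets/text_dataset/<dataset>.txt` to be tab-separated with three columns per document (document name, train/test split, label), in the same order as the cleaned corpus. A toy check of `get_train_test` (the file contents are made up):

```python
from utils import get_train_test

with open("toy_labels.txt", "w") as f:
    f.write("doc_0\ttrain\tpos\n")
    f.write("doc_1\ttest\tneg\n")
    f.write("doc_2\ttrain\tneg\n")

train_lst, test_lst = get_train_test("toy_labels.txt")
print(train_lst, test_lst)  # [0, 2] [1]
```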